gomog/internal/protocol/tcp/server.go

552 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tcp
import (
"bufio"
"bytes"
"encoding/binary"
"encoding/json"
"io"
"log"
"net"
"sync"
"git.kingecg.top/kingecg/gomog/internal/engine"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// 操作码常量
const (
OP_REPLY = 1
OP_UPDATE = 4
OP_INSERT = 8
OP_QUERY = 2004
OP_GETMORE = 2006
OP_DELETE = 2007
OP_MSG = 2013
)
// MessageHeader 消息头16 字节)
type MessageHeader struct {
Length uint32 // 消息总长度
RequestID uint32 // 请求 ID
ResponseTo uint32 // 响应到的请求 ID
OpCode uint32 // 操作码
}
// TCPServer TCP 服务器
type TCPServer struct {
listener net.Listener
handler *MessageHandler
wg sync.WaitGroup
done chan struct{}
}
// MessageHandler 消息处理器
type MessageHandler struct {
store *engine.MemoryStore
crud *engine.CRUDHandler
agg *engine.AggregationEngine
}
// NewMessageHandler 创建消息处理器
func NewMessageHandler(store *engine.MemoryStore, crud *engine.CRUDHandler, agg *engine.AggregationEngine) *MessageHandler {
return &MessageHandler{
store: store,
crud: crud,
agg: agg,
}
}
// NewTCPServer 创建 TCP 服务器
func NewTCPServer(addr string, handler *MessageHandler) (*TCPServer, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
return &TCPServer{
listener: ln,
handler: handler,
done: make(chan struct{}),
}, nil
}
// Start 启动服务器
func (s *TCPServer) Start() error {
go s.acceptLoop()
return nil
}
// acceptLoop 接受连接循环
func (s *TCPServer) acceptLoop() {
for {
select {
case <-s.done:
return
default:
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.done:
return
default:
log.Printf("Accept error: %v", err)
continue
}
}
s.wg.Add(1)
go s.handleConnection(conn)
}
}
}
// handleConnection 处理连接
func (s *TCPServer) handleConnection(conn net.Conn) {
defer s.wg.Done()
defer conn.Close()
reader := bufio.NewReader(conn)
for {
select {
case <-s.done:
return
default:
// 读取消息头
header, err := readHeader(reader)
if err != nil {
if err != io.EOF {
log.Printf("Read header error: %v", err)
}
return
}
// 读取消息体
bodySize := header.Length - 16
body := make([]byte, bodySize)
if _, err := io.ReadFull(reader, body); err != nil {
log.Printf("Read body error: %v", err)
return
}
// 处理消息
response, err := s.handler.HandleMessage(header.OpCode, body, header.RequestID)
if err != nil {
sendErrorResponse(conn, header.RequestID, err)
continue
}
// 发送响应
if err := writeResponse(conn, header.RequestID, response); err != nil {
log.Printf("Write response error: %v", err)
return
}
}
}
}
// readHeader 读取消息头
func readHeader(r *bufio.Reader) (*MessageHeader, error) {
header := &MessageHeader{}
// 读取 16 字节消息头
buf := make([]byte, 16)
if _, err := io.ReadFull(r, buf); err != nil {
return nil, err
}
reader := bytes.NewReader(buf)
// 小端序读取
binary.Read(reader, binary.LittleEndian, &header.Length)
binary.Read(reader, binary.LittleEndian, &header.RequestID)
binary.Read(reader, binary.LittleEndian, &header.ResponseTo)
binary.Read(reader, binary.LittleEndian, &header.OpCode)
return header, nil
}
// writeResponse 写入响应
func writeResponse(conn net.Conn, requestID uint32, response interface{}) error {
// 序列化响应
data, err := json.Marshal(response)
if err != nil {
return err
}
// 构建响应消息
msgLength := uint32(16 + len(data))
header := &MessageHeader{
Length: msgLength,
RequestID: 0, // 服务器生成的请求 ID
ResponseTo: requestID,
OpCode: OP_REPLY,
}
// 写入消息头
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, header.Length)
binary.Write(buf, binary.LittleEndian, header.RequestID)
binary.Write(buf, binary.LittleEndian, header.ResponseTo)
binary.Write(buf, binary.LittleEndian, header.OpCode)
// 写入消息体
buf.Write(data)
_, err = conn.Write(buf.Bytes())
return err
}
// sendErrorResponse 发送错误响应
func sendErrorResponse(conn net.Conn, requestID uint32, err error) {
response := map[string]interface{}{
"ok": 0,
"errmsg": err.Error(),
}
writeResponse(conn, requestID, response)
}
// HandleMessage 处理消息
func (h *MessageHandler) HandleMessage(opCode uint32, body []byte, requestID uint32) (interface{}, error) {
switch opCode {
case OP_INSERT:
return h.handleInsert(body)
case OP_QUERY:
return h.handleQuery(body)
case OP_UPDATE:
return h.handleUpdate(body)
case OP_DELETE:
return h.handleDelete(body)
case OP_MSG:
return h.handleMsg(body)
default:
return nil, ErrUnknownOpCode
}
}
// handleInsert 处理插入消息
func (h *MessageHandler) handleInsert(body []byte) (interface{}, error) {
var req struct {
Collection string `json:"collection"`
Documents []map[string]interface{} `json:"documents"`
Ordered bool `json:"ordered"`
BypassValidation bool `json:"bypassDocumentValidation"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
// 执行插入
result, err := h.crud.Insert(nil, req.Collection, req.Documents)
if err != nil {
return nil, err
}
return map[string]interface{}{
"ok": 1,
"n": result.N,
"insertedIds": result.InsertedIDs,
}, nil
}
// handleQuery 处理查询消息
func (h *MessageHandler) handleQuery(body []byte) (interface{}, error) {
var req struct {
Collection string `json:"collection"`
Filter types.Filter `json:"filter"`
Projection types.Projection `json:"projection"`
Sort types.Sort `json:"sort"`
Skip int `json:"skip"`
Limit int `json:"limit"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
// 执行查询
docs, err := h.store.Find(req.Collection, req.Filter)
if err != nil {
return nil, err
}
// 应用限制和跳过
if req.Skip > 0 && req.Skip < len(docs) {
docs = docs[req.Skip:]
}
if req.Limit > 0 && req.Limit < len(docs) {
docs = docs[:req.Limit]
}
return map[string]interface{}{
"ok": 1,
"cursor": map[string]interface{}{
"firstBatch": docs,
"id": 0,
"ns": req.Collection,
},
}, nil
}
// handleUpdate 处理更新消息
func (h *MessageHandler) handleUpdate(body []byte) (interface{}, error) {
var req struct {
Collection string `json:"collection"`
Updates []types.UpdateOperation `json:"updates"`
Ordered bool `json:"ordered"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
totalMatched := 0
totalModified := 0
for _, op := range req.Updates {
matched, modified, err := h.store.Update(req.Collection, op.Q, op.U)
if err != nil {
return nil, err
}
totalMatched += matched
totalModified += modified
}
return map[string]interface{}{
"ok": 1,
"n": totalMatched,
"nModified": totalModified,
}, nil
}
// handleDelete 处理删除消息
func (h *MessageHandler) handleDelete(body []byte) (interface{}, error) {
var req struct {
Collection string `json:"collection"`
Deletes []types.DeleteOperation `json:"deletes"`
Ordered bool `json:"ordered"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
totalDeleted := 0
for _, op := range req.Deletes {
deleted, err := h.store.Delete(req.Collection, op.Q)
if err != nil {
return nil, err
}
totalDeleted += deleted
if op.Limit == 1 && deleted > 0 {
break
}
}
return map[string]interface{}{
"ok": 1,
"n": totalDeleted,
"deletedCount": totalDeleted,
}, nil
}
// handleMsg 处理 OP_MSG 消息MongoDB 3.6+ 通用消息格式)
func (h *MessageHandler) handleMsg(body []byte) (interface{}, error) {
// 解析 OP_MSG 格式
// 简化实现:假设 body 是 JSON 格式的通用请求
var req struct {
Operation string `json:"operation"`
Collection string `json:"collection"`
Params map[string]interface{} `json:"params"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
switch req.Operation {
case "find":
return h.handleFindMsg(req.Collection, req.Params)
case "insert":
return h.handleInsertMsg(req.Collection, req.Params)
case "update":
return h.handleUpdateMsg(req.Collection, req.Params)
case "delete":
return h.handleDeleteMsg(req.Collection, req.Params)
case "aggregate":
return h.handleAggregateMsg(req.Collection, req.Params)
default:
return nil, ErrUnknownOperation
}
}
// handleFindMsg 处理 find 消息
func (h *MessageHandler) handleFindMsg(collection string, params map[string]interface{}) (interface{}, error) {
filter, _ := params["filter"].(types.Filter)
docs, err := h.store.Find(collection, filter)
if err != nil {
return nil, err
}
return map[string]interface{}{
"ok": 1,
"cursor": map[string]interface{}{
"firstBatch": docs,
"id": 0,
"ns": collection,
},
}, nil
}
// handleInsertMsg 处理 insert 消息
func (h *MessageHandler) handleInsertMsg(collection string, params map[string]interface{}) (interface{}, error) {
documents, ok := params["documents"].([]map[string]interface{})
if !ok {
return nil, ErrInvalidDocuments
}
result, err := h.crud.Insert(nil, collection, documents)
if err != nil {
return nil, err
}
return map[string]interface{}{
"ok": 1,
"n": result.N,
"insertedIds": result.InsertedIDs,
}, nil
}
// handleUpdateMsg 处理 update 消息
func (h *MessageHandler) handleUpdateMsg(collection string, params map[string]interface{}) (interface{}, error) {
updatesRaw, ok := params["updates"].([]interface{})
if !ok {
return nil, ErrInvalidUpdates
}
// 转换 updates
updates := make([]types.UpdateOperation, 0, len(updatesRaw))
for _, u := range updatesRaw {
if updateMap, ok := u.(map[string]interface{}); ok {
q, _ := updateMap["q"].(types.Filter)
uData, _ := updateMap["u"].(types.Update)
upsert, _ := updateMap["upsert"].(bool)
multi, _ := updateMap["multi"].(bool)
updates = append(updates, types.UpdateOperation{
Q: q,
U: uData,
Upsert: upsert,
Multi: multi,
})
}
}
totalMatched := 0
totalModified := 0
for _, op := range updates {
matched, modified, err := h.store.Update(collection, op.Q, op.U)
if err != nil {
return nil, err
}
totalMatched += matched
totalModified += modified
}
return map[string]interface{}{
"ok": 1,
"n": totalMatched,
"nModified": totalModified,
}, nil
}
// handleDeleteMsg 处理 delete 消息
func (h *MessageHandler) handleDeleteMsg(collection string, params map[string]interface{}) (interface{}, error) {
deletesRaw, ok := params["deletes"].([]interface{})
if !ok {
return nil, ErrInvalidDeletes
}
deletes := make([]types.DeleteOperation, 0, len(deletesRaw))
for _, d := range deletesRaw {
if deleteMap, ok := d.(map[string]interface{}); ok {
q, _ := deleteMap["q"].(types.Filter)
limit := 0
if l, ok := deleteMap["limit"].(float64); ok {
limit = int(l)
}
deletes = append(deletes, types.DeleteOperation{
Q: q,
Limit: limit,
})
}
}
totalDeleted := 0
for _, op := range deletes {
deleted, err := h.store.Delete(collection, op.Q)
if err != nil {
return nil, err
}
totalDeleted += deleted
if op.Limit == 1 && deleted > 0 {
break
}
}
return map[string]interface{}{
"ok": 1,
"n": totalDeleted,
"deletedCount": totalDeleted,
}, nil
}
// handleAggregateMsg 处理 aggregate 消息
func (h *MessageHandler) handleAggregateMsg(collection string, params map[string]interface{}) (interface{}, error) {
pipelineRaw, ok := params["pipeline"].([]interface{})
if !ok {
return nil, ErrInvalidPipeline
}
// 转换 pipeline
pipeline := make([]types.AggregateStage, 0, len(pipelineRaw))
for _, stage := range pipelineRaw {
if stageMap, ok := stage.(map[string]interface{}); ok {
for stageName, spec := range stageMap {
pipeline = append(pipeline, types.AggregateStage{
Stage: stageName,
Spec: spec,
})
break
}
}
}
results, err := h.agg.Execute(collection, pipeline)
if err != nil {
return nil, err
}
return map[string]interface{}{
"ok": 1,
"result": results,
}, nil
}
// Stop 停止服务器
func (s *TCPServer) Stop() error {
close(s.done)
s.listener.Close()
s.wg.Wait()
return nil
}