gomog/internal/engine/aggregate.go

787 lines
18 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package engine
import (
	"fmt"
	"reflect"
	"sort"
	"strings"
	"time"

	"git.kingecg.top/kingecg/gomog/pkg/errors"
	"git.kingecg.top/kingecg/gomog/pkg/types"
)
// AggregationEngine evaluates aggregation pipelines over collections held in
// a MemoryStore.
type AggregationEngine struct {
	store *MemoryStore // backing store the engine reads collections from
}

// NewAggregationEngine returns an AggregationEngine that reads its input
// documents from store.
func NewAggregationEngine(store *MemoryStore) *AggregationEngine {
	engine := new(AggregationEngine)
	engine.store = store
	return engine
}
// Execute runs pipeline over every document of collection and returns the
// documents produced by the final stage.
//
// An error loading the collection is returned as-is. Stages run in order, and
// the first stage that fails aborts the pipeline; its error is wrapped with
// errors.ErrAggregationError.
func (e *AggregationEngine) Execute(collection string, pipeline []types.AggregateStage) ([]types.Document, error) {
	current, err := e.store.GetAllDocuments(collection)
	if err != nil {
		return nil, err
	}
	for i := range pipeline {
		next, stageErr := e.executeStage(pipeline[i], current)
		if stageErr != nil {
			return nil, errors.Wrap(stageErr, errors.ErrAggregationError, "aggregation failed")
		}
		current = next
	}
	return current, nil
}
// executeStage dispatches a single pipeline stage to its implementation based
// on the stage name. "$set" is an alias of "$addFields". Stage names that are
// not recognised are skipped: docs is returned unchanged with no error.
func (e *AggregationEngine) executeStage(stage types.AggregateStage, docs []types.Document) ([]types.Document, error) {
	spec := stage.Spec
	switch name := stage.Stage; name {
	case "$match":
		return e.executeMatch(spec, docs)
	case "$group":
		return e.executeGroup(spec, docs)
	case "$sort":
		return e.executeSort(spec, docs)
	case "$project":
		return e.executeProject(spec, docs)
	case "$limit":
		return e.executeLimit(spec, docs)
	case "$skip":
		return e.executeSkip(spec, docs)
	case "$unwind":
		return e.executeUnwind(spec, docs)
	case "$lookup":
		return e.executeLookup(spec, docs)
	case "$count":
		return e.executeCount(spec, docs)
	case "$addFields", "$set":
		return e.executeAddFields(spec, docs)
	case "$unset":
		return e.executeUnset(spec, docs)
	case "$facet":
		return e.executeFacet(spec, docs)
	case "$sample":
		return e.executeSample(spec, docs)
	case "$bucket":
		return e.executeBucket(spec, docs)
	case "$replaceRoot":
		return e.executeReplaceRoot(spec, docs)
	case "$replaceWith":
		return e.executeReplaceWith(spec, docs)
	case "$graphLookup":
		return e.executeGraphLookup(spec, docs)
	case "$setWindowFields":
		return e.executeSetWindowFields(spec, docs)
	}
	// Unknown stage: pass the documents through untouched.
	return docs, nil
}
// executeMatch runs a $match stage: it keeps the documents whose Data satisfies
// MatchFilter with the given criteria. spec may be a types.Filter or a plain
// map[string]interface{}; any other spec makes the stage a no-op. When nothing
// matches, the returned slice is nil.
func (e *AggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) {
	criteria, isMap := spec.(map[string]interface{})
	if !isMap {
		asFilter, isFilter := spec.(types.Filter)
		if !isFilter {
			return docs, nil
		}
		criteria = asFilter
	}

	var matched []types.Document
	for i := range docs {
		if MatchFilter(docs[i].Data, criteria) {
			matched = append(matched, docs[i])
		}
	}
	return matched, nil
}
// executeGroup runs a $group stage.
//
// spec must be a map. Its "_id" entry names the grouping field as a "$path"
// string, and every other entry is an accumulator handled by aggregateGroup.
// Documents whose key (see getGroupKey) is "" share a single group, and that
// group's output has no _id set. Other groups get _id set to their string key.
//
// Groups are emitted in the order their first member appears in docs, so the
// output is deterministic for a given input. Ranging over the grouping map
// instead would randomise it. A non-map spec leaves docs unchanged.
func (e *AggregationEngine) executeGroup(spec interface{}, docs []types.Document) ([]types.Document, error) {
	groupSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	// Grouping field; a non-string _id yields "" and therefore one group.
	idField, _ := groupSpec["_id"].(string)

	groups := make(map[string][]types.Document)
	var order []string // group keys in first-seen order
	for _, doc := range docs {
		key := e.getGroupKey(doc, idField)
		if _, seen := groups[key]; !seen {
			order = append(order, key)
		}
		groups[key] = append(groups[key], doc)
	}

	results := make([]types.Document, 0, len(order))
	for _, key := range order {
		aggregated := e.aggregateGroup(groupSpec, groups[key])
		if key != "" {
			aggregated["_id"] = key
		}
		results = append(results, types.Document{
			ID:   key,
			Data: aggregated,
		})
	}
	return results, nil
}
// getGroupKey returns the string key used to bucket doc in a $group stage.
//
// field is the $group "_id" spec; only "$path" field references are
// supported. An empty field, a field without the "$" prefix, or a nil value at
// the path all yield "", so such documents share one group.
//
// String values are used verbatim. Every other value is rendered with fmt's %v
// verb. This keeps negative integers distinct, puts int 30 and float64 30 in
// the same group, and gives bools and composite values their own keys instead
// of merging them into the "" group.
func (e *AggregationEngine) getGroupKey(doc types.Document, field string) string {
	if !strings.HasPrefix(field, "$") {
		return ""
	}
	value := getNestedValue(doc.Data, field[1:]) // strip the "$" prefix
	switch v := value.(type) {
	case nil:
		return ""
	case string:
		return v
	default:
		return fmt.Sprintf("%v", v)
	}
}
// aggregateGroup 聚合一组文档
func (e *AggregationEngine) aggregateGroup(groupSpec map[string]interface{}, docs []types.Document) map[string]interface{} {
result := make(map[string]interface{})
for field, expr := range groupSpec {
if field == "_id" {
continue
}
// 处理聚合操作符
if exprMap, ok := expr.(map[string]interface{}); ok {
for op, operand := range exprMap {
switch op {
case "$sum":
result[field] = e.sum(docs, operand)
case "$avg":
result[field] = e.avg(docs, operand)
case "$min":
result[field] = e.min(docs, operand)
case "$max":
result[field] = e.max(docs, operand)
case "$count":
result[field] = len(docs)
case "$first":
if len(docs) > 0 {
result[field] = e.getFieldValue(docs[0], operand)
}
case "$last":
if len(docs) > 0 {
result[field] = e.getFieldValue(docs[len(docs)-1], operand)
}
case "$push":
values := make([]interface{}, 0, len(docs))
for _, doc := range docs {
values = append(values, e.getFieldValue(doc, operand))
}
result[field] = values
case "$addToSet":
set := make(map[interface{}]bool)
for _, doc := range docs {
v := e.getFieldValue(doc, operand)
set[v] = true
}
values := make([]interface{}, 0, len(set))
for v := range set {
values = append(values, v)
}
result[field] = values
}
}
}
}
return result
}
// sum adds up toFloat64 of the value that field resolves to (via
// getFieldValue) in each document, in input order. An empty docs yields 0.
func (e *AggregationEngine) sum(docs []types.Document, field interface{}) float64 {
	var total float64
	for i := range docs {
		total += toFloat64(e.getFieldValue(docs[i], field))
	}
	return total
}
// avg returns the mean of toFloat64 of the values that field resolves to
// across docs.
//
// Documents where the value is nil (missing or null) are skipped and do not
// count towards the divisor, as MongoDB's $avg ignores them. Previously a
// missing field dragged the mean towards 0. If no document supplies a value,
// the result is 0.
func (e *AggregationEngine) avg(docs []types.Document, field interface{}) float64 {
	total, count := 0.0, 0
	for _, doc := range docs {
		value := e.getFieldValue(doc, field)
		if value == nil {
			continue
		}
		total += toFloat64(value)
		count++
	}
	if count == 0 {
		return 0
	}
	return total / float64(count)
}
// min returns the smallest toFloat64 of the values that field resolves to
// across docs.
//
// Documents where the value is nil (missing or null) are skipped, as MongoDB's
// $min ignores nulls. Otherwise [5, missing] would yield 0. If no document
// supplies a value, the result is 0.
func (e *AggregationEngine) min(docs []types.Document, field interface{}) float64 {
	lowest, found := 0.0, false
	for _, doc := range docs {
		value := e.getFieldValue(doc, field)
		if value == nil {
			continue
		}
		if n := toFloat64(value); !found || n < lowest {
			lowest, found = n, true
		}
	}
	return lowest
}
// max returns the largest toFloat64 of the values that field resolves to
// across docs.
//
// Documents where the value is nil (missing or null) are skipped, as MongoDB's
// $max ignores nulls. Otherwise [-5, missing] would yield 0. If no document
// supplies a value, the result is 0.
func (e *AggregationEngine) max(docs []types.Document, field interface{}) float64 {
	highest, found := 0.0, false
	for _, doc := range docs {
		value := e.getFieldValue(doc, field)
		if value == nil {
			continue
		}
		if n := toFloat64(value); !found || n > highest {
			highest, found = n, true
		}
	}
	return highest
}
// getFieldValue 获取字段值
func (e *AggregationEngine) getFieldValue(doc types.Document, field interface{}) interface{} {
switch f := field.(type) {
case string:
if len(f) > 0 && f[0] == '$' {
return getNestedValue(doc.Data, f[1:])
}
return f
default:
return field
}
}
// executeSort runs a $sort stage and returns a sorted copy of docs; the input
// slice is not reordered.
//
// spec maps field paths to a direction. A negative number (int, int64 or
// float64) sorts that field descending; anything else sorts ascending. Values
// are ordered with compareValues.
//
// Because a Go map does not keep the key order of the original sort document,
// fields are applied in ascending lexical order of their names. The order is
// fixed once, up front; ranging over the map inside the comparator would change
// field priority between comparisons and make the sort order arbitrary. The
// sort is stable, so documents that compare equal keep their input order. A
// non-map spec leaves docs unchanged.
func (e *AggregationEngine) executeSort(spec interface{}, docs []types.Document) ([]types.Document, error) {
	sortSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	type sortKey struct {
		field string
		dir   int
	}
	keys := make([]sortKey, 0, len(sortSpec))
	for field, direction := range sortSpec {
		dir := 1
		switch d := direction.(type) {
		case int:
			dir = d
		case int64:
			dir = int(d)
		case float64:
			dir = int(d)
		}
		keys = append(keys, sortKey{field: field, dir: dir})
	}
	sort.Slice(keys, func(i, j int) bool { return keys[i].field < keys[j].field })

	sorted := make([]types.Document, len(docs))
	copy(sorted, docs)
	sort.SliceStable(sorted, func(i, j int) bool {
		for _, k := range keys {
			cmp := compareValues(getNestedValue(sorted[i].Data, k.field), getNestedValue(sorted[j].Data, k.field))
			if cmp == 0 {
				continue
			}
			if k.dir < 0 {
				return cmp > 0
			}
			return cmp < 0
		}
		return false
	})
	return sorted, nil
}
// compareDocs is a less-than function for sorting. It reports whether a should
// sort before b according to sortFields, which maps field paths to a direction
// (negative means descending). Values are compared with compareValues.
//
// Fields are consulted in ascending lexical order of their names. Ranging over
// the map directly would pick a different field priority on each call, which
// makes the comparator inconsistent and the resulting sort order arbitrary.
func (e *AggregationEngine) compareDocs(a, b types.Document, sortFields map[string]int) bool {
	fields := make([]string, 0, len(sortFields))
	for field := range sortFields {
		fields = append(fields, field)
	}
	if len(fields) > 1 {
		sort.Strings(fields)
	}
	for _, field := range fields {
		cmp := compareValues(getNestedValue(a.Data, field), getNestedValue(b.Data, field))
		if cmp == 0 {
			continue
		}
		if sortFields[field] < 0 {
			return cmp > 0
		}
		return cmp < 0
	}
	return false
}
// compareValues 比较两个值
func compareValues(a, b interface{}) int {
if a == nil && b == nil {
return 0
}
if a == nil {
return -1
}
if b == nil {
return 1
}
// 数值比较
numA := toFloat64(a)
numB := toFloat64(b)
if numA < numB {
return -1
} else if numA > numB {
return 1
}
return 0
}
// executeProject runs a $project stage: each document's Data is replaced by
// projectDocument(Data, spec), keeping only the document ID. A spec that is not
// a map[string]interface{} leaves docs unchanged. An empty input yields a nil
// slice.
func (e *AggregationEngine) executeProject(spec interface{}, docs []types.Document) ([]types.Document, error) {
	projectSpec, isMap := spec.(map[string]interface{})
	if !isMap {
		return docs, nil
	}

	var projected []types.Document
	for i := range docs {
		shaped := e.projectDocument(docs[i].Data, projectSpec)
		projected = append(projected, types.Document{ID: docs[i].ID, Data: shaped})
	}
	return projected, nil
}
// projectDocument 投影文档
func (e *AggregationEngine) projectDocument(data map[string]interface{}, spec map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for field, include := range spec {
if field == "_id" {
// 特殊处理 _id
if isFalse(include) {
// 排除 _id
} else {
result["_id"] = data["_id"]
}
continue
}
if isTrue(include) {
// 包含字段
result[field] = getNestedValue(data, field)
} else if isFalse(include) {
// 排除字段(在包含模式下不处理)
continue
} else {
// 表达式
result[field] = e.evaluateExpression(data, include)
}
}
return result
}
// evaluateExpression 评估表达式
func (e *AggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} {
// 处理 types.Filter 类型(转换为 map[string]interface{}
if filter, ok := expr.(types.Filter); ok {
expr = map[string]interface{}(filter)
}
// 处理字段引用(以 $ 开头的字符串)
if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' {
fieldName := fieldStr[1:] // 移除 $ 前缀
return getNestedValue(data, fieldName)
}
if exprMap, ok := expr.(map[string]interface{}); ok {
for op, operand := range exprMap {
switch op {
case "$concat":
return e.concat(operand, data)
case "$substr", "$substring":
return e.substr(operand, data)
case "$toUpper":
str := e.getFieldValueStr(types.Document{Data: data}, operand)
return strings.ToUpper(str)
case "$toLower":
str := e.getFieldValueStr(types.Document{Data: data}, operand)
return strings.ToLower(str)
case "$add":
return e.add(operand, data)
case "$multiply":
return e.multiply(operand, data)
case "$divide":
return e.divide(operand, data)
case "$subtract":
return e.subtract(operand, data)
case "$abs":
return e.abs(operand, data)
case "$ceil":
return e.ceil(operand, data)
case "$floor":
return e.floor(operand, data)
case "$round":
return e.round(operand, data)
case "$sqrt":
return e.sqrt(operand, data)
case "$pow":
return e.pow(operand, data)
case "$size":
arr := getNestedValue(data, operand.(string))
if a, ok := arr.([]interface{}); ok {
return len(a)
}
return 0
case "$ifNull":
return e.ifNull(operand, data)
case "$cond":
return e.cond(operand, data)
case "$switch":
return e.switchExpr(operand, data)
case "$trim":
return e.trim(operand, data)
case "$ltrim":
return e.ltrim(operand, data)
case "$rtrim":
return e.rtrim(operand, data)
case "$split":
return e.split(operand, data)
case "$replaceAll":
return e.replaceAll(operand, data)
case "$strcasecmp":
return e.strcasecmp(operand, data)
case "$filter":
return e.filter(operand, data)
case "$map":
return e.mapArr(operand, data)
case "$concatArrays":
return e.concatArrays(operand, data)
case "$slice":
return e.slice(operand, data)
case "$mergeObjects":
return e.mergeObjects(operand, data)
case "$objectToArray":
return e.objectToArray(operand, data)
case "$year":
return e.year(operand, data)
case "$month":
return e.month(operand, data)
case "$dayOfMonth":
return e.dayOfMonth(operand, data)
case "$hour":
return e.hour(operand, data)
case "$minute":
return e.minute(operand, data)
case "$second":
return e.second(operand, data)
case "$dateToString":
return e.dateToString(operand, data)
case "$dateAdd":
return e.dateAdd(operand, data)
case "$dateDiff":
return e.dateDiff(operand, data)
case "$week":
return float64(e.week(operand, data))
case "$isoWeek":
return float64(e.isoWeek(operand, data))
case "$dayOfYear":
return float64(e.dayOfYear(operand, data))
case "$isoDayOfWeek":
return float64(e.isoDayOfWeek(operand, data))
case "$now":
return e.now().Format(time.RFC3339)
case "$gt":
return e.compareGt(operand, data)
case "$gte":
return e.compareGte(operand, data)
case "$lt":
return e.compareLt(operand, data)
case "$lte":
return e.compareLte(operand, data)
case "$eq":
return e.compareEq(operand, data)
case "$ne":
return e.compareNe(operand, data)
}
}
}
return expr
}
// executeLimit runs a $limit stage, returning at most the first n documents.
// n comes from spec as an int, int64 or float64 (truncated). A non-positive or
// unrecognised limit, or one at least len(docs), returns docs unchanged.
func (e *AggregationEngine) executeLimit(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var n int
	if v, ok := spec.(int); ok {
		n = v
	} else if v, ok := spec.(int64); ok {
		n = int(v)
	} else if v, ok := spec.(float64); ok {
		n = int(v)
	}

	if n > 0 && n < len(docs) {
		return docs[:n], nil
	}
	return docs, nil
}
// executeSkip runs a $skip stage, dropping the first n documents. n comes from
// spec as an int, int64 or float64 (truncated). A non-positive or unrecognised
// value returns docs unchanged. Skipping everything yields an empty, non-nil
// slice.
func (e *AggregationEngine) executeSkip(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var n int
	if v, ok := spec.(int); ok {
		n = v
	} else if v, ok := spec.(int64); ok {
		n = int(v)
	} else if v, ok := spec.(float64); ok {
		n = int(v)
	}

	switch {
	case n <= 0:
		return docs, nil
	case n >= len(docs):
		return []types.Document{}, nil
	default:
		return docs[n:], nil
	}
}
// executeUnwind runs an $unwind stage. It emits one output document per
// element of the array at the given path, each with the field replaced by that
// element.
//
// spec is either the "$path" string or a map with:
//   - "path" (string, must start with "$");
//   - "preserveNullAndEmptyArrays" (bool): when true, documents whose value is
//     missing, null or an empty array are passed through instead of dropped;
//   - "includeArrayIndex" (string): field that receives the element's
//     zero-based index as int64, or nil for passed-through documents.
//
// Following MongoDB, a value that is neither nil nor an array is treated as a
// one-element array, so the document passes through; previously it was
// dropped. Slices of any element type are unwound, not only []interface{}.
// []byte is treated as a scalar. Without a usable "$" path the stage is a
// no-op.
func (e *AggregationEngine) executeUnwind(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var path, indexField string
	var preserveNull bool
	switch s := spec.(type) {
	case string:
		path = s
	case map[string]interface{}:
		path, _ = s["path"].(string)
		preserveNull, _ = s["preserveNullAndEmptyArrays"].(bool)
		indexField, _ = s["includeArrayIndex"].(string)
	}
	if !strings.HasPrefix(path, "$") {
		return docs, nil
	}
	fieldPath := path[1:]

	// elementsOf returns the elements of v when v is a slice (other than
	// []byte), and false otherwise.
	elementsOf := func(v interface{}) ([]interface{}, bool) {
		if arr, ok := v.([]interface{}); ok {
			return arr, true
		}
		if _, isBytes := v.([]byte); isBytes {
			return nil, false
		}
		rv := reflect.ValueOf(v)
		if rv.Kind() != reflect.Slice {
			return nil, false
		}
		elems := make([]interface{}, rv.Len())
		for i := range elems {
			elems[i] = rv.Index(i).Interface()
		}
		return elems, true
	}

	// passThrough returns doc unchanged, or a copy carrying a nil index when
	// includeArrayIndex was requested.
	passThrough := func(doc types.Document) types.Document {
		if indexField == "" {
			return doc
		}
		data := deepCopyMap(doc.Data)
		if data == nil {
			data = make(map[string]interface{})
		}
		setNestedValue(data, indexField, nil)
		return types.Document{ID: doc.ID, Data: data}
	}

	var results []types.Document
	for _, doc := range docs {
		value := getNestedValue(doc.Data, fieldPath)
		if value == nil {
			if preserveNull {
				results = append(results, passThrough(doc))
			}
			continue
		}

		elems, isArray := elementsOf(value)
		if !isArray {
			// Non-null scalar: behaves like a one-element array.
			results = append(results, passThrough(doc))
			continue
		}
		if len(elems) == 0 {
			if preserveNull {
				results = append(results, passThrough(doc))
			}
			continue
		}

		for i, item := range elems {
			newData := deepCopyMap(doc.Data)
			setNestedValue(newData, fieldPath, item)
			if indexField != "" {
				setNestedValue(newData, indexField, int64(i))
			}
			results = append(results, types.Document{
				ID:   doc.ID,
				Data: newData,
			})
		}
	}
	return results, nil
}
// executeLookup runs a $lookup stage (an equality join against another
// collection).
//
// spec keys:
//   - "from": the foreign collection;
//   - "localField" and "foreignField": the fields whose values are matched;
//   - "as": the output field.
//
// For every input document, each foreign document whose foreignField value
// satisfies compareEq with the input's localField value is collected. The
// matches are stored under "as" on a deep copy of the input document.
//
// The stored value is always a non-nil []interface{} of copied foreign
// documents. "No match" therefore yields an empty array rather than null, and
// later stages that expect []interface{} (such as $unwind and $size) can
// consume it. The copies keep results from aliasing data held by the store.
//
// A non-map spec, or a missing "from"/"as", leaves docs unchanged. So does a
// failure to read the foreign collection; that error is deliberately ignored.
func (e *AggregationEngine) executeLookup(spec interface{}, docs []types.Document) ([]types.Document, error) {
	lookupSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}
	from, _ := lookupSpec["from"].(string)
	localField, _ := lookupSpec["localField"].(string)
	foreignField, _ := lookupSpec["foreignField"].(string)
	as, _ := lookupSpec["as"].(string)
	if from == "" || as == "" {
		return docs, nil
	}

	foreignDocs, err := e.store.GetAllDocuments(from)
	if err != nil {
		return docs, nil // best effort: skip the join if the collection can't be read
	}

	var results []types.Document
	for _, doc := range docs {
		localValue := getNestedValue(doc.Data, localField)
		matches := []interface{}{}
		for _, foreignDoc := range foreignDocs {
			foreignValue := getNestedValue(foreignDoc.Data, foreignField)
			if compareEq(localValue, foreignValue) {
				matches = append(matches, deepCopyMap(foreignDoc.Data))
			}
		}
		newData := deepCopyMap(doc.Data)
		newData[as] = matches
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: newData,
		})
	}
	return results, nil
}
// executeCount runs a $count stage. It returns a single document with ID
// "count" whose Data maps the output field name to len(docs). The field name is
// spec when spec is a string, otherwise "count". A count is produced even for
// empty input.
func (e *AggregationEngine) executeCount(spec interface{}, docs []types.Document) ([]types.Document, error) {
	name := "count"
	if s, ok := spec.(string); ok {
		name = s
	}
	counted := types.Document{
		ID:   "count",
		Data: map[string]interface{}{name: len(docs)},
	}
	return []types.Document{counted}, nil
}
// Helper functions.

// isTrue reports whether v counts as "on" when used as a flag.
//
// Booleans are returned as-is. Numbers of every built-in numeric type are true
// when non-zero; previously int64(0), int32(0), float32(0) and similar were
// reported as true. Any other value, including nil and strings, is treated as
// true.
func isTrue(v interface{}) bool {
	switch val := v.(type) {
	case bool:
		return val
	case int:
		return val != 0
	case int8:
		return val != 0
	case int16:
		return val != 0
	case int32:
		return val != 0
	case int64:
		return val != 0
	case uint:
		return val != 0
	case uint8:
		return val != 0
	case uint16:
		return val != 0
	case uint32:
		return val != 0
	case uint64:
		return val != 0
	case float32:
		return val != 0
	case float64:
		return val != 0
	}
	return true
}

// isFalse is the negation of isTrue.
func isFalse(v interface{}) bool {
	return !isTrue(v)
}
// toString renders a scalar value as a string:
//   - strings are returned unchanged;
//   - integers of any built-in width are written in base 10;
//   - floats use fmt's %v formatting.
//
// Any other value (nil, bool, maps, slices, ...) yields "".
//
// Integers must not go through string(rune(n)). That produces the character
// with code point n (65 -> "A") and maps invalid code points, such as negative
// numbers, to "\uFFFD".
func toString(v interface{}) string {
	switch val := v.(type) {
	case string:
		return val
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("%d", val)
	case float32, float64:
		return fmt.Sprintf("%v", val)
	default:
		return ""
	}
}
// 比较操作符辅助方法
func (e *AggregationEngine) compareGt(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return toFloat64(left) > toFloat64(right)
}
func (e *AggregationEngine) compareGte(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return toFloat64(left) >= toFloat64(right)
}
func (e *AggregationEngine) compareLt(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return toFloat64(left) < toFloat64(right)
}
func (e *AggregationEngine) compareLte(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return toFloat64(left) <= toFloat64(right)
}
func (e *AggregationEngine) compareEq(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return left == right
}
func (e *AggregationEngine) compareNe(operand interface{}, data map[string]interface{}) interface{} {
arr, ok := operand.([]interface{})
if !ok || len(arr) != 2 {
return false
}
left := e.evaluateExpression(data, arr[0])
right := e.evaluateExpression(data, arr[1])
return left != right
}