gomog/internal/engine/aggregate_batch3.go

545 lines
12 KiB
Go

package engine
import (
"fmt"
"math/rand"
"sort"
"strings"
"time"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// executeReplaceRoot implements the $replaceRoot stage.
//
// spec is expected to be a map holding a "newRoot" expression. For every input
// document the expression is evaluated against the document's data and the
// result becomes the new document body; ID, CreatedAt and UpdatedAt are carried
// over unchanged. A result that is not a map[string]interface{} is wrapped as
// {"value": result}. If spec is not a map or has no "newRoot" key, docs is
// returned untouched. The error result is always nil.
func (e *AggregationEngine) executeReplaceRoot(spec interface{}, docs []types.Document) ([]types.Document, error) {
	specMap, isMap := spec.(map[string]interface{})
	if !isMap {
		return docs, nil
	}
	rootExpr, found := specMap["newRoot"]
	if !found {
		return docs, nil
	}

	var replaced []types.Document
	for _, original := range docs {
		value := e.evaluateExpression(original.Data, rootExpr)
		body, isObject := value.(map[string]interface{})
		if !isObject {
			// Non-object results are wrapped so the stage still emits a document.
			body = map[string]interface{}{"value": value}
		}
		replaced = append(replaced, types.Document{
			ID:        original.ID,
			Data:      body,
			CreatedAt: original.CreatedAt,
			UpdatedAt: original.UpdatedAt,
		})
	}
	return replaced, nil
}
// executeReplaceWith implements the $replaceWith stage, the shorthand form of
// $replaceRoot: spec itself is the new-root expression rather than a map with a
// "newRoot" key. It is evaluated exactly like executeReplaceRoot, including the
// wrapping of non-object results as {"value": result}, so it simply delegates.
// The error result is always nil.
func (e *AggregationEngine) executeReplaceWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
	return e.executeReplaceRoot(map[string]interface{}{"newRoot": spec}, docs)
}
// executeGraphLookup implements the $graphLookup stage (recursive lookup).
//
// Recognised spec keys: "from", "startWith", "connectFromField",
// "connectToField", "as", "maxDepth" and "restrictSearchWithMatch". If spec is
// not a map, or "as", "connectFromField" or "connectToField" is missing or not
// a string, docs is returned unchanged. For every input document, startWith is
// evaluated against it and the documents found by graphLookupRecursive are
// stored under "as" in a shallow copy of the document (an empty array when
// nothing is connected). The error result is always nil.
func (e *AggregationEngine) executeGraphLookup(spec interface{}, docs []types.Document) ([]types.Document, error) {
	specMap, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}
	from, _ := specMap["from"].(string)
	startWith := specMap["startWith"]
	connectFromField, _ := specMap["connectFromField"].(string)
	connectToField, _ := specMap["connectToField"].(string)
	as, _ := specMap["as"].(string)
	restrictSearchWithMatch := specMap["restrictSearchWithMatch"]
	if as == "" || connectFromField == "" || connectToField == "" {
		return docs, nil
	}

	// MongoDB's maxDepth counts recursion steps after the initial match, so
	// maxDepth 0 means "direct matches only". graphLookupRecursive takes the
	// number of lookup rounds instead, hence the +1. An absent, null,
	// non-numeric or negative maxDepth means no limit (-1). The explicit
	// presence check keeps an explicit 0 distinct from "not given".
	levels := -1
	if rawDepth, present := specMap["maxDepth"]; present && rawDepth != nil {
		if depth, isNum := toNumber(rawDepth); isNum && depth >= 0 {
			levels = int(depth) + 1
		}
	}

	var results []types.Document
	for _, doc := range docs {
		startValue := e.evaluateExpression(doc.Data, startWith)
		connectedDocs := e.graphLookupRecursive(
			from,
			startValue,
			connectFromField,
			connectToField,
			levels,
			restrictSearchWithMatch,
			make(map[string]bool),
		)
		if connectedDocs == nil {
			// Emit an empty array rather than null when nothing is connected.
			connectedDocs = []map[string]interface{}{}
		}

		newDoc := make(map[string]interface{}, len(doc.Data)+1)
		for k, v := range doc.Data {
			newDoc[k] = v
		}
		newDoc[as] = connectedDocs
		results = append(results, types.Document{
			ID:        doc.ID,
			Data:      newDoc,
			CreatedAt: doc.CreatedAt,
			UpdatedAt: doc.UpdatedAt,
		})
	}
	return results, nil
}
// graphLookupRecursive collects the documents of collection reachable from
// startValue for $graphLookup.
//
// Despite its name the search is breadth-first. Round 1 matches documents whose
// connectToField equals startValue. Each later round matches documents whose
// connectToField equals the connectFromField value of a document found in the
// previous round. []interface{} values on either side are matched element by
// element, using valuesEqual.
//
// maxDepth is the number of rounds to run: a negative value means no limit,
// and 0 returns nothing. Documents whose ID is already in visited are skipped,
// and every match is added to visited, so each document is returned at most
// once. restrictSearchWithMatch, when it is a map, must also pass MatchFilter.
// The returned maps are shallow copies of the stored document data.
func (e *AggregationEngine) graphLookupRecursive(
	collection string,
	startValue interface{},
	connectFromField string,
	connectToField string,
	maxDepth int,
	restrictSearchWithMatch interface{},
	visited map[string]bool,
) []map[string]interface{} {
	var results []map[string]interface{}
	if maxDepth == 0 {
		return results
	}
	// NOTE(review): the store is read without any visible locking here; confirm
	// that callers hold whatever lock guards e.store.collections.
	targetCollection := e.store.collections[collection]
	if targetCollection == nil {
		return results
	}
	if visited == nil {
		visited = make(map[string]bool)
	}
	matchSpec, _ := restrictSearchWithMatch.(map[string]interface{})

	// Breadth-first, so every document is reached at its shallowest depth. A
	// depth-first walk that shares visited can reach a document first via a long
	// path, hit the depth limit there, and then never expand it from a shorter
	// path. Random map iteration order made that loss intermittent.
	frontier := expandGraphLookupValue(startValue)
	for depth := 0; len(frontier) > 0 && (maxDepth < 0 || depth < maxDepth); depth++ {
		var next []interface{}
		for docID, doc := range targetCollection.documents {
			if visited[docID] {
				continue
			}
			if !graphLookupMatchesAny(frontier, getNestedValue(doc.Data, connectToField)) {
				continue
			}
			if matchSpec != nil && !MatchFilter(doc.Data, matchSpec) {
				continue
			}
			visited[docID] = true

			docCopy := make(map[string]interface{}, len(doc.Data))
			for k, v := range doc.Data {
				docCopy[k] = v
			}
			results = append(results, docCopy)
			next = append(next, expandGraphLookupValue(getNestedValue(doc.Data, connectFromField))...)
		}
		frontier = next
	}
	return results
}

// expandGraphLookupValue returns the values that a $graphLookup key contributes
// to matching: the elements of a []interface{}, or the value itself otherwise.
func expandGraphLookupValue(v interface{}) []interface{} {
	if arr, ok := v.([]interface{}); ok {
		return arr
	}
	return []interface{}{v}
}

// graphLookupMatchesAny reports whether any value of docValue (expanded as an
// array) equals any value in frontier according to valuesEqual.
func graphLookupMatchesAny(frontier []interface{}, docValue interface{}) bool {
	for _, candidate := range expandGraphLookupValue(docValue) {
		for _, want := range frontier {
			if valuesEqual(want, candidate) {
				return true
			}
		}
	}
	return false
}
// executeSetWindowFields implements the $setWindowFields stage (window
// functions).
//
// Documents are grouped by the fmt "%v" text of the evaluated "partitionBy"
// expression; all documents share one partition when it is absent.
// NOTE(review): because the key is textual, values such as 1 and "1" land in
// the same partition. Each partition is sorted with sortDocsBySpec when
// "sortBy" is a non-empty map. Every entry of "output" whose value is a map is
// computed with calculateWindowValue and written into a shallow copy of the
// document. Partitions are emitted in the order their first document appeared
// in docs, and each partition's documents are emitted in sorted order. If spec
// is not a map or has no "output" map, docs is returned unchanged. The error
// result is always nil.
func (e *AggregationEngine) executeSetWindowFields(spec interface{}, docs []types.Document) ([]types.Document, error) {
	specMap, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}
	outputsRaw, _ := specMap["output"].(map[string]interface{})
	partitionByRaw := specMap["partitionBy"]
	sortByRaw, _ := specMap["sortBy"].(map[string]interface{})
	if outputsRaw == nil {
		return docs, nil
	}

	// Remember first-appearance order of partition keys. Ranging over the
	// partitions map directly would make the output order random.
	partitions := make(map[string][]types.Document)
	var partitionOrder []string
	for _, doc := range docs {
		key := "all"
		if partitionByRaw != nil {
			key = fmt.Sprintf("%v", e.evaluateExpression(doc.Data, partitionByRaw))
		}
		if _, seen := partitions[key]; !seen {
			partitionOrder = append(partitionOrder, key)
		}
		partitions[key] = append(partitions[key], doc)
	}

	var results []types.Document
	for _, key := range partitionOrder {
		partition := partitions[key]
		if len(sortByRaw) > 0 {
			sortDocsBySpec(partition, sortByRaw)
		}
		for i, doc := range partition {
			newDoc := make(map[string]interface{}, len(doc.Data)+len(outputsRaw))
			for k, v := range doc.Data {
				newDoc[k] = v
			}
			for fieldName, windowSpecRaw := range outputsRaw {
				windowSpec, ok := windowSpecRaw.(map[string]interface{})
				if !ok {
					continue
				}
				newDoc[fieldName] = e.calculateWindowValue(windowSpec, partition, i, doc)
			}
			results = append(results, types.Document{
				ID:        doc.ID,
				Data:      newDoc,
				CreatedAt: doc.CreatedAt,
				UpdatedAt: doc.UpdatedAt,
			})
		}
	}
	return results, nil
}
// calculateWindowValue computes one $setWindowFields output value for the
// document at currentIndex within its (already sorted) partition.
//
// windowSpec is an output specification such as {"$sum": "$qty"}, optionally
// with a "window" key that describes bounds. The "window" key is not an
// operator and is skipped when choosing one. Supported operators:
//   - $documentNumber, $rank: 1-based position in the partition.
//   - $first / $last: operand evaluated against the first / last document.
//   - $shift: MongoDB form {"output": expr, "by": n, "default": value} evaluates
//     output against the document n positions away, or returns default when
//     that is outside the partition. A bare numeric operand is also accepted;
//     it returns the whole target document's data, or nil when out of range.
//   - $fillDefault: operand evaluated on the current document, 0 if nil.
//   - $sum, $avg, $min, $max: delegated to aggregateWindow.
//
// Any other key causes the spec (minus "window") to be evaluated as an
// ordinary expression on the current document. Returns nil when windowSpec
// holds nothing but "window" (or is empty).
func (e *AggregationEngine) calculateWindowValue(
	windowSpec map[string]interface{},
	partition []types.Document,
	currentIndex int,
	currentDoc types.Document,
) interface{} {
	for op, operand := range windowSpec {
		switch op {
		case "window":
			// Bounds modify the operator beside them. Treating this key as the
			// operator depended on random map order and mis-evaluated the spec.
			continue
		case "$documentNumber":
			return float64(currentIndex + 1)
		case "$rank":
			// NOTE(review): ties in the sortBy key are not detected, so this
			// currently equals $documentNumber.
			return float64(currentIndex + 1)
		case "$first":
			return e.evaluateExpression(partition[0].Data, operand)
		case "$last":
			return e.evaluateExpression(partition[len(partition)-1].Data, operand)
		case "$shift":
			if shiftSpec, isSpec := operand.(map[string]interface{}); isSpec {
				by, _ := toNumber(shiftSpec["by"])
				target := currentIndex + int(by)
				if target < 0 || target >= len(partition) {
					return shiftSpec["default"]
				}
				return e.evaluateExpression(partition[target].Data, shiftSpec["output"])
			}
			// Legacy form: a bare offset that yields the whole target document.
			target := currentIndex + int(toFloat64(operand))
			if target < 0 || target >= len(partition) {
				return nil
			}
			return partition[target].Data
		case "$fillDefault":
			val := e.evaluateExpression(currentDoc.Data, operand)
			if val == nil {
				return 0 // default value
			}
			return val
		case "$sum", "$avg", "$min", "$max":
			return e.aggregateWindow(op, operand, partition, currentIndex)
		default:
			// Plain expression: evaluate the spec without its "window" key.
			expr := windowSpec
			if _, hasWindow := windowSpec["window"]; hasWindow {
				expr = make(map[string]interface{}, len(windowSpec)-1)
				for k, v := range windowSpec {
					if k != "window" {
						expr[k] = v
					}
				}
			}
			return e.evaluateExpression(currentDoc.Data, expr)
		}
	}
	return nil
}
// aggregateWindow evaluates a numeric window accumulator ($sum, $avg, $min or
// $max) over the documents of partition that inWindow admits for currentIndex.
//
// operand is evaluated against each admitted document, and only values that
// toNumber accepts take part. With no numeric values, $sum returns 0 (as
// MongoDB does) and the other operators return nil. Results are float64. An
// unrecognised op returns nil.
func (e *AggregationEngine) aggregateWindow(
	op string,
	operand interface{},
	partition []types.Document,
	currentIndex int,
) interface{} {
	// The window bounds do not depend on the row being examined, so resolve
	// them once instead of once per document.
	windowSpec := getWindowRange(op, operand)

	values := make([]float64, 0, len(partition))
	for i, doc := range partition {
		if !inWindow(i, currentIndex, windowSpec) {
			continue
		}
		if num, ok := toNumber(e.evaluateExpression(doc.Data, operand)); ok {
			values = append(values, num)
		}
	}

	if len(values) == 0 {
		if op == "$sum" {
			return 0.0
		}
		return nil
	}

	switch op {
	case "$sum":
		sum := 0.0
		for _, v := range values {
			sum += v
		}
		return sum
	case "$avg":
		sum := 0.0
		for _, v := range values {
			sum += v
		}
		return sum / float64(len(values))
	case "$min":
		lo := values[0]
		for _, v := range values[1:] {
			if v < lo {
				lo = v
			}
		}
		return lo
	case "$max":
		hi := values[0]
		for _, v := range values[1:] {
			if v > hi {
				hi = v
			}
		}
		return hi
	default:
		return nil
	}
}
// getWindowRange returns the window specification an aggregate window operator
// should use. Window bounds are not parsed yet: both arguments are ignored, and
// the result is always {"window": "all"}, meaning the whole partition.
func getWindowRange(_ string, _ interface{}) map[string]interface{} {
	return map[string]interface{}{
		"window": "all",
	}
}
// inWindow reports whether the document at position index falls inside the
// window around the document at position current.
//
// windowSpec may describe a MongoDB-style document window as
// {"documents": []interface{}{lower, upper}}. Each bound is "unbounded",
// "current", or a numeric offset relative to current, and both bounds are
// inclusive. Any other spec, including a nil map or a missing or malformed
// "documents" entry, selects the whole partition, so every index is in the
// window.
func inWindow(index, current int, windowSpec map[string]interface{}) bool {
	bounds, ok := windowSpec["documents"].([]interface{})
	if !ok || len(bounds) != 2 {
		return true
	}
	offset := index - current
	if lower, bounded := windowBoundOffset(bounds[0]); bounded && offset < lower {
		return false
	}
	if upper, bounded := windowBoundOffset(bounds[1]); bounded && offset > upper {
		return false
	}
	return true
}

// windowBoundOffset converts one document-window bound to an offset relative to
// the current document. It reports false for "unbounded" and for unrecognised
// values; both leave that side of the window open.
func windowBoundOffset(bound interface{}) (int, bool) {
	switch b := bound.(type) {
	case string:
		if b == "current" {
			return 0, true
		}
		return 0, false
	case int:
		return b, true
	case int32:
		return int(b), true
	case int64:
		return int(b), true
	case float64:
		return int(b), true
	}
	return 0, false
}
// executeTextSearch implements a simple $text search over docs.
//
// search is split on whitespace into terms. For case-insensitive searches the
// terms are lower-cased first; for case-sensitive searches they are kept as
// typed. Each document is scored with calculateTextScore. Documents with a
// positive score are returned as shallow copies carrying an extra
// "_textScore" field, ordered by descending score. Documents with equal
// scores keep their input order. language is currently ignored. The error
// result is always nil.
func (e *AggregationEngine) executeTextSearch(docs []types.Document, search string, language string, caseSensitive bool) ([]types.Document, error) {
	// Lower-casing the query unconditionally made case-sensitive searches for
	// terms containing capitals impossible to match.
	query := search
	if !caseSensitive {
		query = strings.ToLower(query)
	}
	searchTerms := strings.Fields(query)

	var results []types.Document
	for _, doc := range docs {
		score := e.calculateTextScore(doc.Data, searchTerms, caseSensitive)
		if score <= 0 {
			continue
		}
		newDoc := make(map[string]interface{}, len(doc.Data)+1)
		for k, v := range doc.Data {
			newDoc[k] = v
		}
		newDoc["_textScore"] = score
		results = append(results, types.Document{
			ID:        doc.ID,
			Data:      newDoc,
			CreatedAt: doc.CreatedAt,
			UpdatedAt: doc.UpdatedAt,
		})
	}

	// A stable sort keeps ties in input order instead of an arbitrary one.
	sort.SliceStable(results, func(i, j int) bool {
		return results[i].Data["_textScore"].(float64) > results[j].Data["_textScore"].(float64)
	})
	return results, nil
}
// calculateTextScore returns the text-match score of one document's data: the
// total that searchInValue accumulates while walking every value in doc.
func (e *AggregationEngine) calculateTextScore(doc map[string]interface{}, searchTerms []string, caseSensitive bool) float64 {
	var total float64
	e.searchInValue(doc, searchTerms, caseSensitive, &total)
	return total
}
// searchInValue walks value recursively and adds 1 to *score for every search
// term that occurs as a substring of a string it finds.
//
// Maps and []interface{} values are descended into. When caseSensitive is
// false, both the string and each term are lower-cased before comparing.
// Values of any other type contribute nothing.
func (e *AggregationEngine) searchInValue(value interface{}, searchTerms []string, caseSensitive bool, score *float64) {
	switch node := value.(type) {
	case map[string]interface{}:
		for _, child := range node {
			e.searchInValue(child, searchTerms, caseSensitive, score)
		}
	case []interface{}:
		for _, child := range node {
			e.searchInValue(child, searchTerms, caseSensitive, score)
		}
	case string:
		haystack := node
		if !caseSensitive {
			haystack = strings.ToLower(haystack)
		}
		for _, term := range searchTerms {
			needle := term
			if !caseSensitive {
				needle = strings.ToLower(needle)
			}
			if strings.Contains(haystack, needle) {
				*score++
			}
		}
	}
}
// sortDocsBySpec sorts docs in place according to a $sort-style spec that maps
// field paths to a direction, e.g. {"score": -1, "name": 1}.
//
// A direction that toNumber reads as negative sorts that field descending;
// anything else sorts ascending. Field values are read with
// getFieldValueStrFromDoc and ordered as follows: missing/null first, then
// values accepted by toNumber (numerically), then strings (lexicographically),
// then any other value by its fmt "%v" text. Go maps do not keep insertion
// order, so when the spec has several fields they are applied in alphabetical
// order of their names. The sort is stable: documents with equal keys keep
// their relative order.
func sortDocsBySpec(docs []types.Document, sortByRaw map[string]interface{}) {
	if len(docs) < 2 || len(sortByRaw) == 0 {
		return
	}

	// The spec's keys are the field names. The previous version ranged over its
	// values (the 1/-1 directions), so every document got identical keys and
	// nothing was ever reordered.
	fields := make([]string, 0, len(sortByRaw))
	for field := range sortByRaw {
		fields = append(fields, field)
	}
	sort.Strings(fields)

	directions := make([]int, len(fields))
	for i, field := range fields {
		directions[i] = 1
		if dir, ok := toNumber(sortByRaw[field]); ok && dir < 0 {
			directions[i] = -1
		}
	}

	// sortKey is a pre-computed, type-ranked form of one field value, so the
	// comparator does not re-inspect interface values on every comparison.
	type sortKey struct {
		rank int // 0 missing/null, 1 number, 2 string, 3 anything else
		num  float64
		str  string
	}
	makeKey := func(v interface{}) sortKey {
		if v == nil {
			return sortKey{rank: 0}
		}
		if num, ok := toNumber(v); ok {
			return sortKey{rank: 1, num: num}
		}
		if s, ok := v.(string); ok {
			return sortKey{rank: 2, str: s}
		}
		return sortKey{rank: 3, str: fmt.Sprintf("%v", v)}
	}
	compareKeys := func(a, b sortKey) int {
		switch {
		case a.rank != b.rank:
			if a.rank < b.rank {
				return -1
			}
			return 1
		case a.rank == 1:
			switch {
			case a.num < b.num:
				return -1
			case a.num > b.num:
				return 1
			}
			return 0
		default:
			return strings.Compare(a.str, b.str)
		}
	}

	type entry struct {
		doc  types.Document
		keys []sortKey
	}
	entries := make([]entry, len(docs))
	for i, doc := range docs {
		keys := make([]sortKey, len(fields))
		for j, field := range fields {
			keys[j] = makeKey(getFieldValueStrFromDoc(doc, field))
		}
		entries[i] = entry{doc: doc, keys: keys}
	}

	sort.SliceStable(entries, func(i, j int) bool {
		for k, dir := range directions {
			if c := compareKeys(entries[i].keys[k], entries[j].keys[k]); c != 0 {
				return c*dir < 0
			}
		}
		return false
	})
	for i := range entries {
		docs[i] = entries[i].doc
	}
}
// getFieldValueStrFromDoc resolves fieldRaw against doc. A string fieldRaw is
// treated as a field path and looked up in doc.Data with getNestedValue. Any
// other value is returned unchanged, as a literal.
func getFieldValueStrFromDoc(doc types.Document, fieldRaw interface{}) interface{} {
	path, isPath := fieldRaw.(string)
	if !isPath {
		return fieldRaw
	}
	return getNestedValue(doc.Data, path)
}
// valuesEqual reports whether two document values are equal for lookup
// purposes.
//
// nil only equals nil. Numbers compare by numeric value whatever their Go type,
// so int 1 equals float64 1. Strings only equal strings with identical content.
// Any other pair is compared by its fmt "%v" text. A number or bool never
// equals a string with the same text (1 != "1", true != "true"), which a plain
// "%v" comparison would get wrong.
func valuesEqual(a, b interface{}) bool {
	if a == nil || b == nil {
		return a == nil && b == nil
	}

	na, aIsNum := valuesEqualNumber(a)
	nb, bIsNum := valuesEqualNumber(b)
	if aIsNum || bIsNum {
		return aIsNum && bIsNum && na == nb
	}

	sa, aIsStr := a.(string)
	sb, bIsStr := b.(string)
	if aIsStr || bIsStr {
		return aIsStr && bIsStr && sa == sb
	}

	return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
}

// valuesEqualNumber converts Go's built-in numeric types to float64 and reports
// whether v was numeric.
func valuesEqualNumber(v interface{}) (float64, bool) {
	switch n := v.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int:
		return float64(n), true
	case int8:
		return float64(n), true
	case int16:
		return float64(n), true
	case int32:
		return float64(n), true
	case int64:
		return float64(n), true
	case uint:
		return float64(n), true
	case uint8:
		return float64(n), true
	case uint16:
		return float64(n), true
	case uint32:
		return float64(n), true
	case uint64:
		return float64(n), true
	}
	return 0, false
}
// getRandomDocuments returns min(n, len(docs)) documents chosen uniformly at
// random from docs, in random order.
//
// docs itself is not modified; the selection is made on a copy. When n is at
// least len(docs), all documents are returned in shuffled order. When n <= 0,
// an empty slice is returned.
func getRandomDocuments(docs []types.Document, n int) []types.Document {
	if n <= 0 {
		return []types.Document{}
	}
	if n > len(docs) {
		n = len(docs)
	}
	picked := make([]types.Document, len(docs))
	copy(picked, docs)

	// A private generator avoids reseeding the process-wide one that other code
	// shares. rand.Seed has been deprecated since Go 1.20.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	// Partial Fisher–Yates: only the first n positions need to be settled.
	for i := 0; i < n; i++ {
		j := i + rng.Intn(len(picked)-i)
		picked[i], picked[j] = picked[j], picked[i]
	}
	return picked[:n:n]
}