File size: 3,419 Bytes
b110593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package db

import (
	"fmt"
	"sort"

	"github.com/weaviate/weaviate/entities/additional"
	"github.com/weaviate/weaviate/entities/searchparams"
	"github.com/weaviate/weaviate/entities/storobj"
)

// groupMerger bundles the inputs that Do needs to merge the "group"
// additional properties carried by a list of objects.
type groupMerger struct {
	// objects are the results whose "group" additional property gets merged.
	objects []*storobj.Object
	// dists is read in parallel with objects: Do uses dists[i] as the
	// distance belonging to objects[i].
	dists   []float32
	// groupBy provides the Groups and ObjectsPerGroup limits as well as the
	// Property that is reported as the grouped-by path.
	groupBy *searchparams.GroupBy
}

// newGroupMerger builds a groupMerger from the given objects, their
// distances and the group-by parameters. The arguments are stored as passed;
// no copies are made.
func newGroupMerger(objects []*storobj.Object, dists []float32,
	groupBy *searchparams.GroupBy,
) *groupMerger {
	merger := groupMerger{
		objects: objects,
		dists:   dists,
		groupBy: groupBy,
	}
	return &merger
}

// Do merges the partial groups attached to gm.objects (under the "group"
// additional property) into a single group per grouped-by value.
//
// The merged groups are ordered by their smallest MinDistance (ties are broken
// by the grouped-by value so the order does not depend on map iteration) and
// capped at groupBy.Groups. Within each merged group the hits of all partial
// groups are concatenated, sorted by ascending distance and capped at
// groupBy.ObjectsPerGroup.
//
// For every merged group the first input object carrying that grouped-by value
// is returned as its representative, together with that object's distance. The
// representative's "group" additional property is overwritten in place with
// the merged group.
//
// An error is returned when the group-by parameters are missing or negative,
// when an object has no usable "group" property, or when a hit lacks a
// *additional.GroupHitAdditional under "_additional".
func (gm *groupMerger) Do() ([]*storobj.Object, []float32, error) {
	if gm.groupBy == nil {
		return nil, nil, fmt.Errorf("group merger: missing groupBy params")
	}
	if gm.groupBy.Groups < 0 || gm.groupBy.ObjectsPerGroup < 0 {
		return nil, nil, fmt.Errorf(
			"group merger: negative limit (groups: %d, objects per group: %d)",
			gm.groupBy.Groups, gm.groupBy.ObjectsPerGroup)
	}

	// groups collects the partial groups per grouped-by value, objects the
	// indexes (into gm.objects) of the objects that carried them.
	groups := map[string][]*additional.Group{}
	objects := map[string][]int{}

	for i, obj := range gm.objects {
		g, ok := obj.AdditionalProperties()["group"]
		if !ok {
			return nil, nil, fmt.Errorf("group not found for object: %v", obj.ID())
		}
		group, ok := g.(*additional.Group)
		if !ok {
			return nil, nil, fmt.Errorf("wrong group type for object: %v", obj.ID())
		}
		if group == nil || group.GroupedBy == nil {
			return nil, nil, fmt.Errorf("group without groupedBy value for object: %v", obj.ID())
		}
		val := group.GroupedBy.Value
		groups[val] = append(groups[val], group)
		objects[val] = append(objects[val], i)
	}

	// getMinDistance is only called with non-empty slices: every map entry
	// above is created by an append.
	getMinDistance := func(groups []*additional.Group) float32 {
		lowest := groups[0].MinDistance
		for _, g := range groups[1:] {
			if g.MinDistance < lowest {
				lowest = g.MinDistance
			}
		}
		return lowest
	}

	type groupMinDistance struct {
		value    string
		distance float32
	}

	groupDistances := make([]groupMinDistance, 0, len(groups))
	for val, group := range groups {
		groupDistances = append(groupDistances, groupMinDistance{
			value: val, distance: getMinDistance(group),
		})
	}

	// Map iteration order is random, so equal distances need an explicit
	// tie-breaker to keep the result reproducible.
	sort.Slice(groupDistances, func(a, b int) bool {
		if groupDistances[a].distance != groupDistances[b].distance {
			return groupDistances[a].distance < groupDistances[b].distance
		}
		return groupDistances[a].value < groupDistances[b].value
	})

	desiredLength := len(groups)
	if desiredLength > gm.groupBy.Groups {
		desiredLength = gm.groupBy.Groups
	}

	// hitDistance must only be used on hits that passed the validation loop
	// below; it would panic otherwise.
	hitDistance := func(hit map[string]interface{}) float32 {
		return hit["_additional"].(*additional.GroupHitAdditional).Distance
	}

	objs := make([]*storobj.Object, desiredLength)
	dists := make([]float32, desiredLength)
	for i, groupDistance := range groupDistances[:desiredLength] {
		val := groupDistance.value
		count := 0
		hits := []map[string]interface{}{}
		for _, g := range groups[val] {
			count += g.Count
			hits = append(hits, g.Hits...)
		}

		// Validate once up front so the sort comparator and the min/max
		// lookups cannot panic on a malformed hit.
		for _, hit := range hits {
			if add, ok := hit["_additional"].(*additional.GroupHitAdditional); !ok || add == nil {
				return nil, nil, fmt.Errorf("hit without distance in group: %v", val)
			}
		}

		// Stable, so hits with equal distance keep their input order.
		sort.SliceStable(hits, func(a, b int) bool {
			return hitDistance(hits[a]) < hitDistance(hits[b])
		})

		if len(hits) > gm.groupBy.ObjectsPerGroup {
			hits = hits[:gm.groupBy.ObjectsPerGroup]
			count = len(hits)
		}

		merged := &additional.Group{
			ID: i,
			GroupedBy: &additional.GroupedBy{
				Value: val,
				Path:  []string{gm.groupBy.Property},
			},
			Count: count,
			Hits:  hits,
		}
		// hits are sorted ascending: the first one is the closest, the last
		// one the farthest. With no hits both distances stay at zero.
		if len(hits) > 0 {
			merged.MinDistance = hitDistance(hits[0])
			merged.MaxDistance = hitDistance(hits[len(hits)-1])
		}

		indx := objects[val][0]
		obj, dist := gm.objects[indx], gm.dists[indx]
		obj.AdditionalProperties()["group"] = merged
		objs[i], dists[i] = obj, dist
	}

	return objs, dists, nil
}