KevinStephenson
Adding in weaviate code
b110593
raw
history blame
17.4 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package aggregate
import (
"fmt"
"github.com/tailor-inc/graphql"
"github.com/weaviate/weaviate/adapters/handlers/graphql/descriptions"
"github.com/weaviate/weaviate/entities/aggregation"
"github.com/weaviate/weaviate/entities/models"
)
// numericPropertyFields builds the GraphQL object type that exposes the
// aggregations of a single numeric property of class: sum, minimum, maximum,
// mean, mode, median, count, and the property's schema type.
//
// Every generated name has the form "<prefix><class><property><Suffix>". The
// aggregator fields read their value from the aggregation.Property passed as
// the resolver source (see makeResolveNumericFieldAggregator).
func numericPropertyFields(class *models.Class, property *models.Property, prefix string) *graphql.Object {
	name := func(suffix string) string {
		return fmt.Sprintf("%s%s%s%s", prefix, class.Class, property.Name, suffix)
	}

	fields := graphql.Fields{
		"sum": &graphql.Field{
			Name:        name("Sum"),
			Description: descriptions.AggregateSum,
			Type:        graphql.Float,
			Resolve:     makeResolveNumericFieldAggregator("sum"),
		},
		"minimum": &graphql.Field{
			Name:        name("Minimum"),
			Description: descriptions.AggregateMin,
			Type:        graphql.Float,
			Resolve:     makeResolveNumericFieldAggregator("minimum"),
		},
		"maximum": &graphql.Field{
			Name:        name("Maximum"),
			Description: descriptions.AggregateMax,
			Type:        graphql.Float,
			Resolve:     makeResolveNumericFieldAggregator("maximum"),
		},
		"mean": &graphql.Field{
			Name:        name("Mean"),
			Description: descriptions.AggregateMean,
			Type:        graphql.Float,
			Resolve:     makeResolveNumericFieldAggregator("mean"),
		},
		"mode": &graphql.Field{
			Name:        name("Mode"),
			Description: descriptions.AggregateMode,
			Type:        graphql.Float,
			Resolve:     makeResolveNumericFieldAggregator("mode"),
		},
		"median": &graphql.Field{
			Name:        name("Median"),
			Description: descriptions.AggregateMedian,
			Type:        graphql.Float,
			Resolve:     makeResolveNumericFieldAggregator("median"),
		},
		"count": &graphql.Field{
			Name:        name("Count"),
			Description: descriptions.AggregateCount,
			Type:        graphql.Int,
			Resolve:     makeResolveNumericFieldAggregator("count"),
		},
		"type": &graphql.Field{
			Name: name("Type"),
			// This field returns the property's schema type, so it is described
			// as such rather than as a count.
			Description: descriptions.AggregatePropertyType,
			Type:        graphql.String,
			Resolve: func(p graphql.ResolveParams) (interface{}, error) {
				prop, ok := p.Source.(aggregation.Property)
				if !ok {
					return nil, fmt.Errorf("numerical: type: expected aggregation.Property, got %T", p.Source)
				}
				return prop.SchemaType, nil
			},
		},
	}

	return graphql.NewObject(graphql.ObjectConfig{
		Name:        name("Obj"),
		Fields:      fields,
		Description: descriptions.AggregatePropertyObject,
	})
}
// datePropertyFields builds the GraphQL object type that exposes the
// aggregations of a single date property of class: count, minimum, maximum,
// mode and median.
//
// Every generated name has the form "<prefix><class><property><Suffix>". Values
// are read from the aggregation.Property passed as the resolver source (see
// makeResolveDateFieldAggregator). The date aggregators are exposed as
// strings and count as an int.
func datePropertyFields(class *models.Class,
	property *models.Property, prefix string,
) *graphql.Object {
	name := func(suffix string) string {
		return fmt.Sprintf("%s%s%s%s", prefix, class.Class, property.Name, suffix)
	}

	fields := graphql.Fields{
		"count": &graphql.Field{
			// Include the property name like every other field of this object.
			Name:        name("Count"),
			Description: descriptions.AggregateCount,
			Type:        graphql.Int,
			Resolve:     makeResolveDateFieldAggregator("count"),
		},
		"minimum": &graphql.Field{
			Name:        name("Minimum"),
			Description: descriptions.AggregateMin,
			Type:        graphql.String,
			Resolve:     makeResolveDateFieldAggregator("minimum"),
		},
		"maximum": &graphql.Field{
			Name:        name("Maximum"),
			Description: descriptions.AggregateMax,
			Type:        graphql.String,
			Resolve:     makeResolveDateFieldAggregator("maximum"),
		},
		"mode": &graphql.Field{
			Name:        name("Mode"),
			Description: descriptions.AggregateMode,
			Type:        graphql.String,
			Resolve:     makeResolveDateFieldAggregator("mode"),
		},
		"median": &graphql.Field{
			Name:        name("Median"),
			Description: descriptions.AggregateMedian,
			Type:        graphql.String,
			Resolve:     makeResolveDateFieldAggregator("median"),
		},
	}

	return graphql.NewObject(graphql.ObjectConfig{
		Name:        name("Obj"),
		Fields:      fields,
		Description: descriptions.AggregatePropertyObject,
	})
}
// referencePropertyFields builds the GraphQL object type for a single
// cross-reference property of class. It exposes the property's schema type and
// the list of classes the reference points to.
//
// Every generated name has the form "<prefix><class><property><Suffix>". The
// "pointingTo" field is marked deprecated/experimental.
func referencePropertyFields(class *models.Class,
	property *models.Property, prefix string,
) *graphql.Object {
	name := func(suffix string) string {
		return fmt.Sprintf("%s%s%s%s", prefix, class.Class, property.Name, suffix)
	}

	fields := graphql.Fields{
		"type": &graphql.Field{
			Name:        name("Type"),
			Description: descriptions.AggregatePropertyType,
			Type:        graphql.String,
			Resolve: func(p graphql.ResolveParams) (interface{}, error) {
				prop, ok := p.Source.(aggregation.Property)
				if !ok {
					return nil, fmt.Errorf("ref property type: expected aggregation.Property, got %T",
						p.Source)
				}
				return prop.SchemaType, nil
			},
		},
		"pointingTo": &graphql.Field{
			Name:        name("PointingTo"),
			Description: descriptions.AggregateClassPropertyPointingTo,
			Type:        graphql.NewList(graphql.String),
			Resolve: func(p graphql.ResolveParams) (interface{}, error) {
				ref, err := extractReferenceAggregation(p.Source)
				if err != nil {
					return nil, fmt.Errorf("ref property pointingTo: %v", err)
				}
				return ref.PointingTo, nil
			},
			DeprecationReason: "Experimental, the format will change",
		},
	}

	return graphql.NewObject(graphql.ObjectConfig{
		Name:        name("Obj"),
		Fields:      fields,
		Description: descriptions.AggregatePropertyObject,
	})
}
// extractReferenceAggregation returns a pointer to the ReferenceAggregation
// held by source. It fails if source is not an aggregation.Property. It also
// fails if that property's Type is not aggregation.PropertyTypeReference.
func extractReferenceAggregation(source interface{}) (*aggregation.Reference, error) {
	prop, isProp := source.(aggregation.Property)
	switch {
	case !isProp:
		return nil, fmt.Errorf("expected aggregation.Property, got %T", source)
	case prop.Type != aggregation.PropertyTypeReference:
		return nil, fmt.Errorf("expected property to be of type reference, got %s", prop.Type)
	default:
		return &prop.ReferenceAggregation, nil
	}
}
// booleanPropertyFields builds the GraphQL object type that exposes the
// aggregations of a single boolean property of class. These are count,
// totalTrue, percentageTrue, totalFalse, percentageFalse and the property's
// schema type.
//
// Every generated name has the form "<prefix><class><property><Suffix>". The
// aggregator fields read from the aggregation.Boolean held by the resolver
// source (see booleanResolver).
func booleanPropertyFields(class *models.Class,
	property *models.Property, prefix string,
) *graphql.Object {
	name := func(suffix string) string {
		return fmt.Sprintf("%s%s%s%s", prefix, class.Class, property.Name, suffix)
	}

	fields := graphql.Fields{
		"count": &graphql.Field{
			Name:        name("Count"),
			Description: descriptions.AggregatePropertyCount,
			Type:        graphql.Int,
			Resolve:     booleanResolver(func(b aggregation.Boolean) interface{} { return b.Count }),
		},
		"totalTrue": &graphql.Field{
			Name:        name("TotalTrue"),
			Description: descriptions.AggregateClassPropertyTotalTrue,
			Type:        graphql.Int,
			Resolve:     booleanResolver(func(b aggregation.Boolean) interface{} { return b.TotalTrue }),
		},
		"percentageTrue": &graphql.Field{
			Name:        name("PercentageTrue"),
			Description: descriptions.AggregateClassPropertyPercentageTrue,
			Type:        graphql.Float,
			Resolve:     booleanResolver(func(b aggregation.Boolean) interface{} { return b.PercentageTrue }),
		},
		"totalFalse": &graphql.Field{
			Name:        name("TotalFalse"),
			Description: descriptions.AggregateClassPropertyTotalFalse,
			Type:        graphql.Int,
			Resolve:     booleanResolver(func(b aggregation.Boolean) interface{} { return b.TotalFalse }),
		},
		"percentageFalse": &graphql.Field{
			Name:        name("PercentageFalse"),
			Description: descriptions.AggregateClassPropertyPercentageFalse,
			Type:        graphql.Float,
			Resolve:     booleanResolver(func(b aggregation.Boolean) interface{} { return b.PercentageFalse }),
		},
		"type": &graphql.Field{
			Name: name("Type"),
			// This field returns the property's schema type, so it is described
			// as such rather than as a count.
			Description: descriptions.AggregatePropertyType,
			Type:        graphql.String,
			Resolve: func(p graphql.ResolveParams) (interface{}, error) {
				prop, ok := p.Source.(aggregation.Property)
				if !ok {
					return nil, fmt.Errorf("boolean: type: expected aggregation.Property, got %T", p.Source)
				}
				return prop.SchemaType, nil
			},
		},
	}

	return graphql.NewObject(graphql.ObjectConfig{
		Name:        name("Obj"),
		Fields:      fields,
		Description: descriptions.AggregatePropertyObject,
	})
}
// booleanExtractorFunc selects one value from a boolean aggregation.
type booleanExtractorFunc func(aggregation.Boolean) interface{}

// booleanResolver wraps extractor in a graphql resolver. The resolver pulls the
// aggregation.Boolean out of p.Source via extractBooleanAggregation. On failure
// it returns the extraction error prefixed with "boolean: ".
func booleanResolver(extractor booleanExtractorFunc) func(p graphql.ResolveParams) (interface{}, error) {
	resolve := func(params graphql.ResolveParams) (interface{}, error) {
		agg, extractErr := extractBooleanAggregation(params.Source)
		if extractErr != nil {
			return nil, fmt.Errorf("boolean: %v", extractErr)
		}
		return extractor(*agg), nil
	}
	return resolve
}
// extractBooleanAggregation returns a pointer to the BooleanAggregation held
// by source. source must be an aggregation.Property whose Type is
// aggregation.PropertyTypeBoolean; anything else yields an error.
func extractBooleanAggregation(source interface{}) (*aggregation.Boolean, error) {
	prop, isProp := source.(aggregation.Property)
	switch {
	case !isProp:
		return nil, fmt.Errorf("expected aggregation.Property, got %T", source)
	case prop.Type != aggregation.PropertyTypeBoolean:
		return nil, fmt.Errorf("expected property to be of type boolean, got %s", prop.Type)
	}
	return &prop.BooleanAggregation, nil
}
// stringPropertyFields builds the GraphQL object type that exposes the
// aggregations of a single text property of class: count, the property's
// schema type, and the list of top occurrences.
//
// Every generated name has the form "<prefix><class><property><Suffix>". count
// and topOccurrences read from the aggregation.Text obtained from the resolver
// source (see textResolver).
func stringPropertyFields(class *models.Class,
	property *models.Property, prefix string,
) *graphql.Object {
	name := func(suffix string) string {
		return fmt.Sprintf("%s%s%s%s", prefix, class.Class, property.Name, suffix)
	}

	fields := graphql.Fields{
		"count": &graphql.Field{
			Name:        name("Count"),
			Description: descriptions.AggregatePropertyCount,
			Type:        graphql.Int,
			Resolve: textResolver(func(text aggregation.Text) (interface{}, error) {
				return text.Count, nil
			}),
		},
		"type": &graphql.Field{
			Name: name("Type"),
			// This field returns the property's schema type, so it is described
			// as such rather than as a count.
			Description: descriptions.AggregatePropertyType,
			Type:        graphql.String,
			Resolve: func(p graphql.ResolveParams) (interface{}, error) {
				prop, ok := p.Source.(aggregation.Property)
				if !ok {
					return nil, fmt.Errorf("text type: expected aggregation.Property, got %T", p.Source)
				}
				return prop.SchemaType, nil
			},
		},
		"topOccurrences": &graphql.Field{
			Name:        name("TopOccurrences"),
			Description: descriptions.AggregatePropertyTopOccurrences,
			Type:        graphql.NewList(stringTopOccurrences(class, property, prefix)),
			// NOTE(review): the "limit" argument below is not applied in this
			// resolver; presumably it is consumed where the query is parsed —
			// confirm.
			Resolve: textResolver(func(text aggregation.Text) (interface{}, error) {
				// Each element stays an aggregation.TextOccurrence value, which
				// is what textOccurrenceResolver expects as its source.
				list := make([]interface{}, len(text.Items))
				for i, to := range text.Items {
					list[i] = to
				}
				return list, nil
			}),
			Args: graphql.FieldConfigArgument{
				"limit": &graphql.ArgumentConfig{
					Description: descriptions.First,
					Type:        graphql.Int,
				},
			},
		},
	}

	return graphql.NewObject(graphql.ObjectConfig{
		Name:        name("Obj"),
		Fields:      fields,
		Description: descriptions.AggregatePropertyObject,
	})
}
// textExtractorFunc selects one value from a text aggregation and may fail.
type textExtractorFunc func(aggregation.Text) (interface{}, error)

// textResolver wraps extractor in a graphql resolver. The resolver obtains the
// aggregation.Text from p.Source via extractTextAggregation. It returns the
// extractor's result and error unchanged. Extraction failures are prefixed
// with "text: ".
func textResolver(extractor textExtractorFunc) func(p graphql.ResolveParams) (interface{}, error) {
	resolve := func(params graphql.ResolveParams) (interface{}, error) {
		agg, extractErr := extractTextAggregation(params.Source)
		if extractErr != nil {
			return nil, fmt.Errorf("text: %v", extractErr)
		}
		return extractor(agg)
	}
	return resolve
}
// stringTopOccurrences builds the GraphQL object type for one entry of a text
// property's topOccurrences list. The entry exposes the occurring value and
// its occurrence count. Both fields read from an aggregation.TextOccurrence
// source (see textOccurrenceResolver).
func stringTopOccurrences(class *models.Class,
	property *models.Property, prefix string,
) *graphql.Object {
	base := fmt.Sprintf("%s%s%sTopOccurrences", prefix, class.Class, property.Name)

	valueField := &graphql.Field{
		Name:        base + "Value",
		Description: descriptions.AggregatePropertyTopOccurrencesValue,
		Type:        graphql.String,
		Resolve:     textOccurrenceResolver(func(occ aggregation.TextOccurrence) interface{} { return occ.Value }),
	}
	occursField := &graphql.Field{
		Name:        base + "Occurs",
		Description: descriptions.AggregatePropertyTopOccurrencesOccurs,
		Type:        graphql.Int,
		Resolve:     textOccurrenceResolver(func(occ aggregation.TextOccurrence) interface{} { return occ.Occurs }),
	}

	return graphql.NewObject(graphql.ObjectConfig{
		Name: base + "Obj",
		Fields: graphql.Fields{
			"value":  valueField,
			"occurs": occursField,
		},
		Description: descriptions.AggregatePropertyTopOccurrences,
	})
}
// textOccurrenceExtractorFunc selects one value from a single top-occurrence
// entry.
type textOccurrenceExtractorFunc func(aggregation.TextOccurrence) interface{}

// textOccurrenceResolver wraps extractor in a graphql resolver whose source
// must be an aggregation.TextOccurrence value. Any other source yields an
// error that names the requested field and the actual source type.
func textOccurrenceResolver(extractor textOccurrenceExtractorFunc) func(p graphql.ResolveParams) (interface{}, error) {
	return func(params graphql.ResolveParams) (interface{}, error) {
		if occurrence, isOccurrence := params.Source.(aggregation.TextOccurrence); isOccurrence {
			return extractor(occurrence), nil
		}
		return nil, fmt.Errorf("textOccurrence: %s: expected aggregation.TextOccurrence, but got %T",
			params.Info.FieldName, params.Source)
	}
}
// extractTextAggregation returns the text aggregation carried by source, which
// must be an aggregation.Property.
//
// Numerical properties are also accepted. For them only the "count" entry of
// NumericalAggregations is used, and it must be an int. Any other property
// type, or a numerical property without an int "count", is reported as an
// error.
func extractTextAggregation(source interface{}) (aggregation.Text, error) {
	property, ok := source.(aggregation.Property)
	if !ok {
		return aggregation.Text{}, fmt.Errorf("expected aggregation.Property, got %T", source)
	}

	if property.Type == aggregation.PropertyTypeNumerical {
		// in this case we can only use count
		rawCount := property.NumericalAggregations["count"]
		// Checked assertion: a missing entry (or nil map) yields nil, and an
		// unchecked .(int) would panic inside the resolver.
		count, ok := rawCount.(int)
		if !ok {
			return aggregation.Text{}, fmt.Errorf("expected numerical count aggregation of type int, got %T", rawCount)
		}
		return aggregation.Text{Count: count}, nil
	}

	if property.Type != aggregation.PropertyTypeText {
		return aggregation.Text{}, fmt.Errorf("expected property to be of type text, got %s (%#v)", property.Type, property)
	}

	return property.TextAggregation, nil
}
// groupedByProperty builds the GraphQL object type "Aggregate<class>GroupedByObj".
// It exposes the Path and Value of the *aggregation.GroupedBy passed as the
// resolver source (see groupedByResolver).
func groupedByProperty(class *models.Class) *graphql.Object {
	pathField := &graphql.Field{
		Description: descriptions.AggregateGroupedByGroupedByPath,
		Type:        graphql.NewList(graphql.String),
		Resolve:     groupedByResolver(func(gb *aggregation.GroupedBy) interface{} { return gb.Path }),
	}
	valueField := &graphql.Field{
		Description: descriptions.AggregateGroupedByGroupedByValue,
		Type:        graphql.String,
		Resolve:     groupedByResolver(func(gb *aggregation.GroupedBy) interface{} { return gb.Value }),
	}

	return graphql.NewObject(graphql.ObjectConfig{
		Name: fmt.Sprintf("Aggregate%sGroupedByObj", class.Class),
		Fields: graphql.Fields{
			"path":  pathField,
			"value": valueField,
		},
		Description: descriptions.AggregateGroupedByObj,
	})
}
// groupedByExtractorFunc selects one value from a grouped-by descriptor.
type groupedByExtractorFunc func(*aggregation.GroupedBy) interface{}

// groupedByResolver wraps extractor in a graphql resolver whose source must be
// a non-nil *aggregation.GroupedBy. Otherwise it returns an error naming the
// requested field.
func groupedByResolver(extractor groupedByExtractorFunc) func(p graphql.ResolveParams) (interface{}, error) {
	return func(p graphql.ResolveParams) (interface{}, error) {
		groupedBy, ok := p.Source.(*aggregation.GroupedBy)
		if !ok {
			// The assertion is on the pointer type, so report it as such.
			return nil, fmt.Errorf("groupedBy: %s: expected *aggregation.GroupedBy, but got %T",
				p.Info.FieldName, p.Source)
		}
		if groupedBy == nil {
			// A typed nil passes the assertion but would panic in extractor.
			return nil, fmt.Errorf("groupedBy: %s: got nil *aggregation.GroupedBy", p.Info.FieldName)
		}
		return extractor(groupedBy), nil
	}
}
// makeResolveNumericFieldAggregator returns a resolver that looks up the entry
// keyed by aggregator (e.g. "sum", "mean") in the numerical aggregations of the
// aggregation.Property found in p.Source. A missing key resolves to nil.
func makeResolveNumericFieldAggregator(aggregator string) func(p graphql.ResolveParams) (interface{}, error) {
	resolve := func(params graphql.ResolveParams) (interface{}, error) {
		aggregations, extractErr := extractNumericAggregation(params.Source)
		if extractErr == nil {
			return aggregations[aggregator], nil
		}
		return nil, fmt.Errorf("numerical aggregator %s: %v", aggregator, extractErr)
	}
	return resolve
}
// extractNumericAggregation returns the NumericalAggregations map of source.
// source must be an aggregation.Property of type
// aggregation.PropertyTypeNumerical; anything else is an error.
func extractNumericAggregation(source interface{}) (map[string]interface{}, error) {
	prop, isProp := source.(aggregation.Property)
	if !isProp {
		return nil, fmt.Errorf("expected aggregation.Property, got %T", source)
	}
	if prop.Type == aggregation.PropertyTypeNumerical {
		return prop.NumericalAggregations, nil
	}
	return nil, fmt.Errorf("expected property to be of type numerical, got %s", prop.Type)
}
// makeResolveDateFieldAggregator returns a resolver that looks up the entry
// keyed by aggregator (e.g. "minimum", "median") in the date aggregations of
// the aggregation.Property found in p.Source. A missing key resolves to nil.
func makeResolveDateFieldAggregator(aggregator string) func(p graphql.ResolveParams) (interface{}, error) {
	resolve := func(params graphql.ResolveParams) (interface{}, error) {
		aggregations, extractErr := extractDateAggregation(params.Source)
		if extractErr == nil {
			return aggregations[aggregator], nil
		}
		return nil, fmt.Errorf("date aggregator %s: %v", aggregator, extractErr)
	}
	return resolve
}
// extractDateAggregation returns the DateAggregations map of source. source
// must be an aggregation.Property of type aggregation.PropertyTypeDate;
// anything else is an error.
func extractDateAggregation(source interface{}) (map[string]interface{}, error) {
	prop, isProp := source.(aggregation.Property)
	switch {
	case !isProp:
		return nil, fmt.Errorf("expected aggregation.Property, got %T", source)
	case prop.Type != aggregation.PropertyTypeDate:
		return nil, fmt.Errorf("expected property to be of type date, got %s", prop.Type)
	default:
		return prop.DateAggregations, nil
	}
}