KevinStephenson
Adding in weaviate code
b110593
raw
history blame
9.03 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package aggregate
import (
"fmt"
"github.com/tailor-inc/graphql"
"github.com/weaviate/weaviate/adapters/handlers/graphql/descriptions"
"github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters"
"github.com/weaviate/weaviate/adapters/handlers/graphql/utils"
"github.com/weaviate/weaviate/entities/aggregation"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/usecases/config"
)
type ModulesProvider interface {
AggregateArguments(class *models.Class) map[string]*graphql.ArgumentConfig
ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{}
}
// Build the Aggregate Kinds schema
func Build(dbSchema *schema.Schema, config config.Config,
modulesProvider ModulesProvider,
) (*graphql.Field, error) {
if len(dbSchema.Objects.Classes) == 0 {
return nil, utils.ErrEmptySchema
}
var err error
var localAggregateObjects *graphql.Object
if len(dbSchema.Objects.Classes) > 0 {
localAggregateObjects, err = classFields(dbSchema.Objects.Classes, config, modulesProvider)
if err != nil {
return nil, err
}
}
field := graphql.Field{
Name: "Aggregate",
Description: descriptions.AggregateWhere,
Type: localAggregateObjects,
Resolve: passThroughResolver,
}
return &field, nil
}
func classFields(databaseSchema []*models.Class,
config config.Config, modulesProvider ModulesProvider,
) (*graphql.Object, error) {
fields := graphql.Fields{}
for _, class := range databaseSchema {
field, err := classField(class, class.Description, config, modulesProvider)
if err != nil {
return nil, err
}
fields[class.Class] = field
}
return graphql.NewObject(graphql.ObjectConfig{
Name: "AggregateObjectsObj",
Fields: fields,
Description: descriptions.AggregateObjectsObj,
}), nil
}
func classField(class *models.Class, description string,
config config.Config, modulesProvider ModulesProvider,
) (*graphql.Field, error) {
metaClassName := fmt.Sprintf("Aggregate%s", class.Class)
fields := graphql.ObjectConfig{
Name: metaClassName,
Fields: (graphql.FieldsThunk)(func() graphql.Fields {
fields, err := classPropertyFields(class)
if err != nil {
// we cannot return an error in this FieldsThunk and have to panic unfortunately
panic(fmt.Sprintf("Failed to assemble single Local Aggregate Class field: %s", err))
}
return fields
}),
Description: description,
}
fieldsObject := graphql.NewObject(fields)
fieldsField := &graphql.Field{
Type: graphql.NewList(fieldsObject),
Description: description,
Args: graphql.FieldConfigArgument{
"limit": &graphql.ArgumentConfig{
Description: descriptions.First,
Type: graphql.Int,
},
"where": &graphql.ArgumentConfig{
Description: descriptions.GetWhere,
Type: graphql.NewInputObject(
graphql.InputObjectConfig{
Name: fmt.Sprintf("AggregateObjects%sWhereInpObj", class.Class),
Fields: common_filters.BuildNew(fmt.Sprintf("AggregateObjects%s", class.Class)),
Description: descriptions.GetWhereInpObj,
},
),
},
"groupBy": &graphql.ArgumentConfig{
Description: descriptions.GroupBy,
Type: graphql.NewList(graphql.String),
},
"nearVector": nearVectorArgument(class.Class),
"nearObject": nearObjectArgument(class.Class),
"objectLimit": &graphql.ArgumentConfig{
Description: descriptions.First,
Type: graphql.Int,
},
"hybrid": hybridArgument(fieldsObject, class, modulesProvider),
},
Resolve: makeResolveClass(modulesProvider, class),
}
if modulesProvider != nil {
for name, argument := range modulesProvider.AggregateArguments(class) {
fieldsField.Args[name] = argument
}
}
if schema.MultiTenancyEnabled(class) {
fieldsField.Args["tenant"] = tenantArgument()
}
return fieldsField, nil
}
func classPropertyFields(class *models.Class) (graphql.Fields, error) {
fields := graphql.Fields{}
for _, property := range class.Properties {
propertyType, err := schema.GetPropertyDataType(class, property.Name)
if err != nil {
return nil, fmt.Errorf("%s.%s: %s", class.Class, property.Name, err)
}
convertedDataType, err := classPropertyField(*propertyType, class, property)
if err != nil {
return nil, err
}
fields[property.Name] = convertedDataType
}
// Special case: meta { count } appended to all regular props
fields["meta"] = &graphql.Field{
Description: descriptions.LocalMetaObj,
Type: metaObject(fmt.Sprintf("Aggregate%s", class.Class)),
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
// pass-through
return p.Source, nil
},
}
// Always append Grouped By field
fields["groupedBy"] = &graphql.Field{
Description: descriptions.AggregateGroupedBy,
Type: groupedByProperty(class),
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
switch typed := p.Source.(type) {
case aggregation.Group:
return typed.GroupedBy, nil
case map[string]interface{}:
return typed["groupedBy"], nil
default:
return nil, fmt.Errorf("groupedBy: unsupported type %T", p.Source)
}
},
}
return fields, nil
}
func metaObject(prefix string) *graphql.Object {
return graphql.NewObject(graphql.ObjectConfig{
Name: fmt.Sprintf("%sMetaObject", prefix),
Fields: graphql.Fields{
"count": &graphql.Field{
Type: graphql.Int,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
group, ok := p.Source.(aggregation.Group)
if !ok {
return nil, fmt.Errorf("meta count: expected aggregation.Group, got %T", p.Source)
}
return group.Count, nil
},
},
},
})
}
func classPropertyField(dataType schema.DataType, class *models.Class, property *models.Property) (*graphql.Field, error) {
switch dataType {
case schema.DataTypeText:
return makePropertyField(class, property, stringPropertyFields)
case schema.DataTypeInt:
return makePropertyField(class, property, numericPropertyFields)
case schema.DataTypeNumber:
return makePropertyField(class, property, numericPropertyFields)
case schema.DataTypeBoolean:
return makePropertyField(class, property, booleanPropertyFields)
case schema.DataTypeDate:
return makePropertyField(class, property, datePropertyFields)
case schema.DataTypeCRef:
return makePropertyField(class, property, referencePropertyFields)
case schema.DataTypeGeoCoordinates:
// simply skip for now, see gh-729
return nil, nil
case schema.DataTypePhoneNumber:
// skipping for now, see gh-1088 where it was outscoped
return nil, nil
case schema.DataTypeBlob:
return makePropertyField(class, property, stringPropertyFields)
case schema.DataTypeTextArray:
return makePropertyField(class, property, stringPropertyFields)
case schema.DataTypeIntArray, schema.DataTypeNumberArray:
return makePropertyField(class, property, numericPropertyFields)
case schema.DataTypeBooleanArray:
return makePropertyField(class, property, booleanPropertyFields)
case schema.DataTypeDateArray:
return makePropertyField(class, property, datePropertyFields)
case schema.DataTypeUUID, schema.DataTypeUUIDArray:
// not aggregatable
return nil, nil
case schema.DataTypeObject, schema.DataTypeObjectArray:
// TODO: check if it's aggregable, skip for now
return nil, nil
default:
return nil, fmt.Errorf(schema.ErrorNoSuchDatatype+": %s", dataType)
}
}
type propertyFieldMaker func(class *models.Class,
property *models.Property, prefix string) *graphql.Object
func makePropertyField(class *models.Class, property *models.Property,
fieldMaker propertyFieldMaker,
) (*graphql.Field, error) {
prefix := "Aggregate"
return &graphql.Field{
Description: fmt.Sprintf(`%s"%s"`, descriptions.AggregateProperty, property.Name),
Type: fieldMaker(class, property, prefix),
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
switch typed := p.Source.(type) {
case aggregation.Group:
res, ok := typed.Properties[property.Name]
if !ok {
return nil, fmt.Errorf("missing property '%s'", property.Name)
}
return res, nil
default:
return nil, fmt.Errorf("property %s, unsupported type %T", property.Name, p.Source)
}
},
}, nil
}
func passThroughResolver(p graphql.ResolveParams) (interface{}, error) {
// bubble up root resolver
return p.Source, nil
}