//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package get

import (
	"context"
	"fmt"
	"regexp"
	"strings"

	"github.com/tailor-inc/graphql"
	"github.com/tailor-inc/graphql/language/ast"
	"github.com/weaviate/weaviate/adapters/handlers/graphql/descriptions"
	"github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters"
	"github.com/weaviate/weaviate/entities/additional"
	"github.com/weaviate/weaviate/entities/dto"
	enterrors "github.com/weaviate/weaviate/entities/errors"
	"github.com/weaviate/weaviate/entities/filters"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/modulecapabilities"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/search"
	"github.com/weaviate/weaviate/entities/searchparams"
)

func (b *classBuilder) primitiveField(propertyType schema.PropertyDataType,
	property *models.Property, className string,
) *graphql.Field {
	switch propertyType.AsPrimitive() {
	case schema.DataTypeText:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.String,
		}
	case schema.DataTypeInt:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.Int,
		}
	case schema.DataTypeNumber:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.Float,
		}
	case schema.DataTypeBoolean:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.Boolean,
		}
	case schema.DataTypeDate:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.String, // String since no graphql date datatype exists
		}
	case schema.DataTypeGeoCoordinates:
		obj := newGeoCoordinatesObject(className, property.Name)

		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        obj,
			Resolve:     resolveGeoCoordinates,
		}
	case schema.DataTypePhoneNumber:
		obj := newPhoneNumberObject(className, property.Name)

		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        obj,
			Resolve:     resolvePhoneNumber,
		}
	case schema.DataTypeBlob:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.String,
		}
	case schema.DataTypeTextArray:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.NewList(graphql.String),
		}
	case schema.DataTypeIntArray:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.NewList(graphql.Int),
		}
	case schema.DataTypeNumberArray:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.NewList(graphql.Float),
		}
	case schema.DataTypeBooleanArray:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.NewList(graphql.Boolean),
		}
	case schema.DataTypeDateArray:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.NewList(graphql.String), // String since no graphql date datatype exists
		}
	case schema.DataTypeUUIDArray:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.NewList(graphql.String), // Always return UUID as string representation to the user
		}
	case schema.DataTypeUUID:
		return &graphql.Field{
			Description: property.Description,
			Name:        property.Name,
			Type:        graphql.String, // Always return UUID as string representation to the user
		}
	default:
		panic(fmt.Sprintf("buildGetClass: unknown primitive type for %s.%s; %s",
			className, property.Name, propertyType.AsPrimitive()))
	}
}

func newGeoCoordinatesObject(className string, propertyName string) *graphql.Object {
	return graphql.NewObject(graphql.ObjectConfig{
		Description: "GeoCoordinates as latitude and longitude in decimal form",
		Name:        fmt.Sprintf("%s%sGeoCoordinatesObj", className, propertyName),
		Fields: graphql.Fields{
			"latitude": &graphql.Field{
				Name:        "Latitude",
				Description: "The Latitude of the point in decimal form.",
				Type:        graphql.Float,
			},
			"longitude": &graphql.Field{
				Name:        "Longitude",
				Description: "The Longitude of the point in decimal form.",
				Type:        graphql.Float,
			},
		},
	})
}

func newPhoneNumberObject(className string, propertyName string) *graphql.Object {
	return graphql.NewObject(graphql.ObjectConfig{
		Description: "PhoneNumber in various parsed formats",
		Name:        fmt.Sprintf("%s%sPhoneNumberObj", className, propertyName),
		Fields: graphql.Fields{
			"input": &graphql.Field{
				Name:        "Input",
				Description: "The raw phone number as put in by the user prior to parsing",
				Type:        graphql.String,
			},
			"internationalFormatted": &graphql.Field{
				Name:        "Input",
				Description: "The parsed phone number in the international format",
				Type:        graphql.String,
			},
			"nationalFormatted": &graphql.Field{
				Name:        "Input",
				Description: "The parsed phone number in the national format",
				Type:        graphql.String,
			},
			"national": &graphql.Field{
				Name:        "Input",
				Description: "The parsed phone number in the national format",
				Type:        graphql.Int,
			},
			"valid": &graphql.Field{
				Name:        "Input",
				Description: "Whether the phone number could be successfully parsed and was considered valid by the parser",
				Type:        graphql.Boolean,
			},
			"countryCode": &graphql.Field{
				Name:        "Input",
				Description: "The parsed country code, i.e. the leading numbers identifying the country in an international format",
				Type:        graphql.Int,
			},
			"defaultCountry": &graphql.Field{
				Name:        "Input",
				Description: "The defaultCountry as put in by the user. (This is used to help parse national numbers into an international format)",
				Type:        graphql.String,
			},
		},
	})
}

func buildGetClassField(classObject *graphql.Object,
	class *models.Class, modulesProvider ModulesProvider, fusionEnum *graphql.Enum,
) graphql.Field {
	field := graphql.Field{
		Type:        graphql.NewList(classObject),
		Description: class.Description,
		Args: graphql.FieldConfigArgument{
			"after": &graphql.ArgumentConfig{
				Description: descriptions.AfterID,
				Type:        graphql.String,
			},
			"limit": &graphql.ArgumentConfig{
				Description: descriptions.Limit,
				Type:        graphql.Int,
			},
			"offset": &graphql.ArgumentConfig{
				Description: descriptions.After,
				Type:        graphql.Int,
			},
			"autocut": &graphql.ArgumentConfig{
				Description: "Cut off number of results after the Nth extrema. Off by default, negative numbers mean off.",
				Type:        graphql.Int,
			},

			"sort":       sortArgument(class.Class),
			"nearVector": nearVectorArgument(class.Class),
			"nearObject": nearObjectArgument(class.Class),
			"where":      whereArgument(class.Class),
			"group":      groupArgument(class.Class),
			"groupBy":    groupByArgument(class.Class),
		},
		Resolve: newResolver(modulesProvider).makeResolveGetClass(class.Class),
	}

	field.Args["bm25"] = bm25Argument(class.Class)
	field.Args["hybrid"] = hybridArgument(classObject, class, modulesProvider, fusionEnum)

	if modulesProvider != nil {
		for name, argument := range modulesProvider.GetArguments(class) {
			field.Args[name] = argument
		}
	}

	if replicationEnabled(class) {
		field.Args["consistencyLevel"] = consistencyLevelArgument(class)
	}

	if schema.MultiTenancyEnabled(class) {
		field.Args["tenant"] = tenantArgument()
	}

	return field
}

func resolveGeoCoordinates(p graphql.ResolveParams) (interface{}, error) {
	field := p.Source.(map[string]interface{})[p.Info.FieldName]
	if field == nil {
		return nil, nil
	}

	geo, ok := field.(*models.GeoCoordinates)
	if !ok {
		return nil, fmt.Errorf("expected a *models.GeoCoordinates, but got: %T", field)
	}

	return map[string]interface{}{
		"latitude":  geo.Latitude,
		"longitude": geo.Longitude,
	}, nil
}

func resolvePhoneNumber(p graphql.ResolveParams) (interface{}, error) {
	field := p.Source.(map[string]interface{})[p.Info.FieldName]
	if field == nil {
		return nil, nil
	}

	phone, ok := field.(*models.PhoneNumber)
	if !ok {
		return nil, fmt.Errorf("expected a *models.PhoneNumber, but got: %T", field)
	}

	return map[string]interface{}{
		"input":                  phone.Input,
		"internationalFormatted": phone.InternationalFormatted,
		"nationalFormatted":      phone.NationalFormatted,
		"national":               phone.National,
		"valid":                  phone.Valid,
		"countryCode":            phone.CountryCode,
		"defaultCountry":         phone.DefaultCountry,
	}, nil
}

func whereArgument(className string) *graphql.ArgumentConfig {
	return &graphql.ArgumentConfig{
		Description: descriptions.GetWhere,
		Type: graphql.NewInputObject(
			graphql.InputObjectConfig{
				Name:        fmt.Sprintf("GetObjects%sWhereInpObj", className),
				Fields:      common_filters.BuildNew(fmt.Sprintf("GetObjects%s", className)),
				Description: descriptions.GetWhereInpObj,
			},
		),
	}
}

type resolver struct {
	modulesProvider ModulesProvider
}

func newResolver(modulesProvider ModulesProvider) *resolver {
	return &resolver{modulesProvider}
}

func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn {
	return func(p graphql.ResolveParams) (interface{}, error) {
		result, err := r.resolveGet(p, className)
		if err != nil {
			return result, enterrors.NewErrGraphQLUser(err, "Get", className)
		}
		return result, nil
	}
}

func (r *resolver) resolveGet(p graphql.ResolveParams, className string) (interface{}, error) {
	source, ok := p.Source.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("expected graphql root to be a map, but was %T", p.Source)
	}

	resolver, ok := source["Resolver"].(Resolver)
	if !ok {
		return nil, fmt.Errorf("expected source map to have a usable Resolver, but got %#v", source["Resolver"])
	}

	pagination, err := filters.ExtractPaginationFromArgs(p.Args)
	if err != nil {
		return nil, err
	}

	cursor, err := filters.ExtractCursorFromArgs(p.Args)
	if err != nil {
		return nil, err
	}

	// There can only be exactly one ast.Field; it is the class name.
	if len(p.Info.FieldASTs) != 1 {
		panic("Only one Field expected here")
	}

	selectionsOfClass := p.Info.FieldASTs[0].SelectionSet

	properties, addlProps, err := extractProperties(
		className, selectionsOfClass, p.Info.Fragments, r.modulesProvider)
	if err != nil {
		return nil, err
	}

	var sort []filters.Sort
	if sortArg, ok := p.Args["sort"]; ok {
		sort = filters.ExtractSortFromArgs(sortArg.([]interface{}))
	}

	filters, err := common_filters.ExtractFilters(p.Args, p.Info.FieldName)
	if err != nil {
		return nil, fmt.Errorf("could not extract filters: %s", err)
	}

	var nearVectorParams *searchparams.NearVector
	if nearVector, ok := p.Args["nearVector"]; ok {
		p, err := common_filters.ExtractNearVector(nearVector.(map[string]interface{}))
		if err != nil {
			return nil, fmt.Errorf("failed to extract nearVector params: %s", err)
		}
		nearVectorParams = &p
	}

	var nearObjectParams *searchparams.NearObject
	if nearObject, ok := p.Args["nearObject"]; ok {
		p, err := common_filters.ExtractNearObject(nearObject.(map[string]interface{}))
		if err != nil {
			return nil, fmt.Errorf("failed to extract nearObject params: %s", err)
		}
		nearObjectParams = &p
	}

	var moduleParams map[string]interface{}
	if r.modulesProvider != nil {
		extractedParams := r.modulesProvider.ExtractSearchParams(p.Args, className)
		if len(extractedParams) > 0 {
			moduleParams = extractedParams
		}
	}

	// extracts bm25 (sparseSearch) from the query
	var keywordRankingParams *searchparams.KeywordRanking
	if bm25, ok := p.Args["bm25"]; ok {
		if len(sort) > 0 {
			return nil, fmt.Errorf("bm25 search is not compatible with sort")
		}
		p := common_filters.ExtractBM25(bm25.(map[string]interface{}), addlProps.ExplainScore)
		keywordRankingParams = &p
	}

	// Extract hybrid search params from the processed query
	// Everything hybrid can go in another namespace AFTER modulesprovider is
	// refactored
	var hybridParams *searchparams.HybridSearch
	if hybrid, ok := p.Args["hybrid"]; ok {
		if len(sort) > 0 {
			return nil, fmt.Errorf("hybrid search is not compatible with sort")
		}
		p, err := common_filters.ExtractHybridSearch(hybrid.(map[string]interface{}), addlProps.ExplainScore)
		if err != nil {
			return nil, fmt.Errorf("failed to extract hybrid params: %w", err)
		}
		hybridParams = p
	}

	var replProps *additional.ReplicationProperties
	if cl, ok := p.Args["consistencyLevel"]; ok {
		replProps = &additional.ReplicationProperties{
			ConsistencyLevel: cl.(string),
		}
	}

	group := extractGroup(p.Args)

	var groupByParams *searchparams.GroupBy
	if groupBy, ok := p.Args["groupBy"]; ok {
		p := common_filters.ExtractGroupBy(groupBy.(map[string]interface{}))
		groupByParams = &p
	}

	var tenant string
	if tk, ok := p.Args["tenant"]; ok {
		tenant = tk.(string)
	}

	params := dto.GetParams{
		Filters:               filters,
		ClassName:             className,
		Pagination:            pagination,
		Cursor:                cursor,
		Properties:            properties,
		Sort:                  sort,
		NearVector:            nearVectorParams,
		NearObject:            nearObjectParams,
		Group:                 group,
		ModuleParams:          moduleParams,
		AdditionalProperties:  addlProps,
		KeywordRanking:        keywordRankingParams,
		HybridSearch:          hybridParams,
		ReplicationProperties: replProps,
		GroupBy:               groupByParams,
		Tenant:                tenant,
	}

	// need to perform vector search by distance
	// under certain conditions
	setLimitBasedOnVectorSearchParams(&params)

	return func() (interface{}, error) {
		result, err := resolver.GetClass(p.Context, principalFromContext(p.Context), params)
		if err != nil {
			return result, enterrors.NewErrGraphQLUser(err, "Get", params.ClassName)
		}
		return result, nil
	}, nil
}

// the limit needs to be set according to the vector search parameters.
// for example, if a certainty is provided by any of the near* options,
// and no limit was provided, weaviate will want to execute a vector
// search by distance. it knows to do this by watching for a limit
// flag, specifically filters.LimitFlagSearchByDist
func setLimitBasedOnVectorSearchParams(params *dto.GetParams) {
	setLimit := func(params *dto.GetParams) {
		if params.Pagination == nil {
			// limit was omitted entirely, implicitly
			// indicating to do unlimited search
			params.Pagination = &filters.Pagination{
				Limit: filters.LimitFlagSearchByDist,
			}
		} else if params.Pagination.Limit < 0 {
			// a negative limit was set, explicitly
			// indicating to do unlimited search
			params.Pagination.Limit = filters.LimitFlagSearchByDist
		}
	}

	if params.NearVector != nil &&
		(params.NearVector.Certainty != 0 || params.NearVector.WithDistance) {
		setLimit(params)
		return
	}

	if params.NearObject != nil &&
		(params.NearObject.Certainty != 0 || params.NearObject.WithDistance) {
		setLimit(params)
		return
	}

	for _, param := range params.ModuleParams {
		nearParam, ok := param.(modulecapabilities.NearParam)
		if ok && nearParam.SimilarityMetricProvided() {
			setLimit(params)
			return
		}
	}
}

func extractGroup(args map[string]interface{}) *dto.GroupParams {
	group, ok := args["group"]
	if !ok {
		return nil
	}

	asMap := group.(map[string]interface{}) // guaranteed by graphql
	strategy := asMap["type"].(string)
	force := asMap["force"].(float64)
	return &dto.GroupParams{
		Strategy: strategy,
		Force:    float32(force),
	}
}

func principalFromContext(ctx context.Context) *models.Principal {
	principal := ctx.Value("principal")
	if principal == nil {
		return nil
	}

	return principal.(*models.Principal)
}

func isPrimitive(selectionSet *ast.SelectionSet) bool {
	if selectionSet == nil {
		return true
	}

	// if there is a selection set it could either be a cross-ref or a map-type
	// field like GeoCoordinates or PhoneNumber
	for _, subSelection := range selectionSet.Selections {
		if subsectionField, ok := subSelection.(*ast.Field); ok {
			if fieldNameIsOfObjectButNonReferenceType(subsectionField.Name.Value) {
				return true
			}
		}
	}

	// must be a ref field
	return false
}

type additionalCheck struct {
	modulesProvider ModulesProvider
}

func (ac *additionalCheck) isAdditional(parentName, name string) bool {
	if parentName == "_additional" {
		if name == "classification" || name == "certainty" ||
			name == "distance" || name == "id" || name == "vector" ||
			name == "creationTimeUnix" || name == "lastUpdateTimeUnix" ||
			name == "score" || name == "explainScore" || name == "isConsistent" ||
			name == "group" {
			return true
		}
		if ac.isModuleAdditional(name) {
			return true
		}
	}
	return false
}

func (ac *additionalCheck) isModuleAdditional(name string) bool {
	if ac.modulesProvider != nil {
		if len(ac.modulesProvider.GraphQLAdditionalFieldNames()) > 0 {
			for _, moduleAdditionalProperty := range ac.modulesProvider.GraphQLAdditionalFieldNames() {
				if name == moduleAdditionalProperty {
					return true
				}
			}
		}
	}
	return false
}

func fieldNameIsOfObjectButNonReferenceType(field string) bool {
	switch field {
	case "latitude", "longitude":
		// must be a geo prop
		return true
	case "input", "internationalFormatted", "nationalFormatted", "national",
		"valid", "countryCode", "defaultCountry":
		// must be a phone number
		return true
	default:
		return false
	}
}

func extractProperties(className string, selections *ast.SelectionSet,
	fragments map[string]ast.Definition,
	modulesProvider ModulesProvider,
) ([]search.SelectProperty, additional.Properties, error) {
	var properties []search.SelectProperty
	var additionalProps additional.Properties
	additionalCheck := &additionalCheck{modulesProvider}

	for _, selection := range selections.Selections {
		field := selection.(*ast.Field)
		name := field.Name.Value
		property := search.SelectProperty{Name: name}

		property.IsPrimitive = isPrimitive(field.SelectionSet)
		if !property.IsPrimitive {
			// We can interpret this property in different ways
			for _, subSelection := range field.SelectionSet.Selections {
				switch s := subSelection.(type) {
				case *ast.Field:
					// Is it a field with the name __typename?
					if s.Name.Value == "__typename" {
						property.IncludeTypeName = true
						continue
					} else if additionalCheck.isAdditional(name, s.Name.Value) {
						additionalProperty := s.Name.Value
						if additionalProperty == "classification" {
							additionalProps.Classification = true
							continue
						}
						if additionalProperty == "certainty" {
							additionalProps.Certainty = true
							continue
						}
						if additionalProperty == "distance" {
							additionalProps.Distance = true
							continue
						}
						if additionalProperty == "id" {
							additionalProps.ID = true
							continue
						}
						if additionalProperty == "vector" {
							additionalProps.Vector = true
							continue
						}
						if additionalProperty == "creationTimeUnix" {
							additionalProps.CreationTimeUnix = true
							continue
						}
						if additionalProperty == "score" {
							additionalProps.Score = true
							continue
						}
						if additionalProperty == "explainScore" {
							additionalProps.ExplainScore = true
							continue
						}
						if additionalProperty == "lastUpdateTimeUnix" {
							additionalProps.LastUpdateTimeUnix = true
							continue
						}
						if additionalProperty == "isConsistent" {
							additionalProps.IsConsistent = true
							continue
						}
						if additionalProperty == "group" {
							additionalProps.Group = true
							additionalGroupHitProperties, err := extractGroupHitProperties(className, additionalProps, subSelection, fragments, modulesProvider)
							if err != nil {
								return nil, additionalProps, err
							}
							properties = append(properties, additionalGroupHitProperties...)
							continue
						}
						if modulesProvider != nil {
							if additionalCheck.isModuleAdditional(additionalProperty) {
								additionalProps.ModuleParams = getModuleParams(additionalProps.ModuleParams)
								additionalProps.ModuleParams[additionalProperty] = modulesProvider.ExtractAdditionalField(className, additionalProperty, s.Arguments)
								continue
							}
						}
					} else {
						// It's an object / object array property
						continue
					}

				case *ast.FragmentSpread:
					ref, err := extractFragmentSpread(className, s, fragments, modulesProvider)
					if err != nil {
						return nil, additionalProps, err
					}

					property.Refs = append(property.Refs, ref)

				case *ast.InlineFragment:
					ref, err := extractInlineFragment(className, s, fragments, modulesProvider)
					if err != nil {
						return nil, additionalProps, err
					}

					property.Refs = append(property.Refs, ref)

				default:
					return nil, additionalProps, fmt.Errorf("unrecognized type in sub-selection: %T", subSelection)
				}
			}
		}

		if name == "_additional" {
			continue
		}

		properties = append(properties, property)
	}

	return properties, additionalProps, nil
}

func extractGroupHitProperties(
	className string,
	additionalProps additional.Properties,
	subSelection ast.Selection,
	fragments map[string]ast.Definition,
	modulesProvider ModulesProvider,
) ([]search.SelectProperty, error) {
	additionalGroupProperties := []search.SelectProperty{}
	if subSelection != nil {
		if selectionSet := subSelection.GetSelectionSet(); selectionSet != nil {
			for _, groupSubSelection := range selectionSet.Selections {
				if groupSubSelection != nil {
					if groupSubSelectionField, ok := groupSubSelection.(*ast.Field); ok {
						if groupSubSelectionField.Name.Value == "hits" && groupSubSelectionField.SelectionSet != nil {
							for _, groupHitsSubSelection := range groupSubSelectionField.SelectionSet.Selections {
								if hf, ok := groupHitsSubSelection.(*ast.Field); ok {
									if hf.SelectionSet != nil {
										for _, ss := range hf.SelectionSet.Selections {
											if inlineFrag, ok := ss.(*ast.InlineFragment); ok {
												ref, err := extractInlineFragment(className, inlineFrag, fragments, modulesProvider)
												if err != nil {
													return nil, err
												}

												additionalGroupHitProp := search.SelectProperty{Name: fmt.Sprintf("_additional:group:hits:%v", hf.Name.Value)}
												additionalGroupHitProp.Refs = append(additionalGroupHitProp.Refs, ref)
												additionalGroupProperties = append(additionalGroupProperties, additionalGroupHitProp)
											}
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}
	return additionalGroupProperties, nil
}

func getModuleParams(moduleParams map[string]interface{}) map[string]interface{} {
	if moduleParams == nil {
		return map[string]interface{}{}
	}
	return moduleParams
}

func extractInlineFragment(class string, fragment *ast.InlineFragment,
	fragments map[string]ast.Definition,
	modulesProvider ModulesProvider,
) (search.SelectClass, error) {
	var className schema.ClassName
	var err error
	var result search.SelectClass

	if strings.Contains(fragment.TypeCondition.Name.Value, "__") {
		// is a helper type for a network ref
		// don't validate anything as of now
		className = schema.ClassName(fragment.TypeCondition.Name.Value)
	} else {
		className, err = schema.ValidateClassName(fragment.TypeCondition.Name.Value)
		if err != nil {
			return result, fmt.Errorf("the inline fragment type name '%s' is not a valid class name", fragment.TypeCondition.Name.Value)
		}
	}

	if className == "Beacon" {
		return result, fmt.Errorf("retrieving cross-refs by beacon is not supported yet - coming soon!")
	}

	subProperties, additionalProperties, err := extractProperties(class, fragment.SelectionSet, fragments, modulesProvider)
	if err != nil {
		return result, err
	}

	result.ClassName = string(className)
	result.RefProperties = subProperties
	result.AdditionalProperties = additionalProperties
	return result, nil
}

func extractFragmentSpread(class string, spread *ast.FragmentSpread,
	fragments map[string]ast.Definition,
	modulesProvider ModulesProvider,
) (search.SelectClass, error) {
	var result search.SelectClass
	name := spread.Name.Value

	def, ok := fragments[name]
	if !ok {
		return result, fmt.Errorf("spread fragment '%s' refers to unknown fragment", name)
	}

	className, err := hackyWorkaroundToExtractClassName(def, name)
	if err != nil {
		return result, err
	}

	subProperties, additionalProperties, err := extractProperties(class, def.GetSelectionSet(), fragments, modulesProvider)
	if err != nil {
		return result, err
	}

	result.ClassName = string(className)
	result.RefProperties = subProperties
	result.AdditionalProperties = additionalProperties
	return result, nil
}

// It seems there's no proper way to extract this info unfortunately:
// https://github.com/tailor-inc/graphql/issues/455
func hackyWorkaroundToExtractClassName(def ast.Definition, name string) (string, error) {
	loc := def.GetLoc()
	raw := loc.Source.Body[loc.Start:loc.End]
	r := regexp.MustCompile(fmt.Sprintf(`fragment\s*%s\s*on\s*(\w*)\s*{`, name))
	matches := r.FindSubmatch(raw)
	if len(matches) < 2 {
		return "", fmt.Errorf("could not extract a className from fragment")
	}

	return string(matches[1]), nil
}