Spaces:
Sleeping
Sleeping
| // _ _ | |
| // __ _____ __ ___ ___ __ _| |_ ___ | |
| // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
| // \ V V / __/ (_| |\ V /| | (_| | || __/ | |
| // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
| // | |
| // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
| // | |
| // CONTACT: [email protected] | |
| // | |
| package hybrid | |
| import ( | |
| "context" | |
| "fmt" | |
| "github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters" | |
| "github.com/weaviate/weaviate/entities/autocut" | |
| "github.com/sirupsen/logrus" | |
| "github.com/weaviate/weaviate/entities/additional" | |
| "github.com/weaviate/weaviate/entities/search" | |
| "github.com/weaviate/weaviate/entities/searchparams" | |
| "github.com/weaviate/weaviate/entities/storobj" | |
| ) | |
| const DefaultLimit = 100 | |
| type Params struct { | |
| *searchparams.HybridSearch | |
| Keyword *searchparams.KeywordRanking | |
| Class string | |
| Autocut int | |
| } | |
| // Result facilitates the pairing of a search result with its internal doc id. | |
| // | |
| // This type is key in generalising hybrid search across different use cases. | |
| // Some use cases require a full search result (Get{} queries) and others need | |
| // only a doc id (Aggregate{}) which the search.Result type does not contain. | |
| type Result struct { | |
| DocID uint64 | |
| *search.Result | |
| } | |
| type Results []*Result | |
| func (res Results) SearchResults() []search.Result { | |
| out := make([]search.Result, len(res)) | |
| for i, r := range res { | |
| out[i] = *r.Result | |
| } | |
| return out | |
| } | |
| // sparseSearchFunc is the signature of a closure which performs sparse search. | |
| // Any package which wishes use hybrid search must provide this. The weights are | |
| // used in calculating the final scores of the result set. | |
| type sparseSearchFunc func() (results []*storobj.Object, weights []float32, err error) | |
| // denseSearchFunc is the signature of a closure which performs dense search. | |
| // A search vector argument is required to pass along to the vector index. | |
| // Any package which wishes use hybrid search must provide this The weights are | |
| // used in calculating the final scores of the result set. | |
| type denseSearchFunc func(searchVector []float32) (results []*storobj.Object, weights []float32, err error) | |
| // postProcFunc takes the results of the hybrid search and applies some transformation. | |
| // This is optionally provided, and allows the caller to somehow change the nature of | |
| // the result set. For example, Get{} queries sometimes require resolving references, | |
| // which is implemented by doing the reference resolution within a postProcFunc closure. | |
| type postProcFunc func(hybridResults Results) (postProcResults []search.Result, err error) | |
| type modulesProvider interface { | |
| VectorFromInput(ctx context.Context, | |
| className string, input string) ([]float32, error) | |
| } | |
| // Search executes sparse and dense searches and combines the result sets using Reciprocal Rank Fusion | |
| func Search(ctx context.Context, params *Params, logger logrus.FieldLogger, sparseSearch sparseSearchFunc, denseSearch denseSearchFunc, postProc postProcFunc, modules modulesProvider) (Results, error) { | |
| var ( | |
| found [][]*Result | |
| weights []float64 | |
| names []string | |
| ) | |
| if params.Query != "" { | |
| alpha := params.Alpha | |
| if alpha < 1 { | |
| res, err := processSparseSearch(sparseSearch()) | |
| if err != nil { | |
| return nil, err | |
| } | |
| found = append(found, res) | |
| weights = append(weights, 1-alpha) | |
| names = append(names, "keyword") | |
| } | |
| if alpha > 0 { | |
| res, err := processDenseSearch(ctx, denseSearch, params, modules) | |
| if err != nil { | |
| return nil, err | |
| } | |
| found = append(found, res) | |
| weights = append(weights, alpha) | |
| names = append(names, "vector") | |
| } | |
| } else { | |
| ss := params.SubSearches | |
| // To catch error if ss is empty | |
| _, err := decideSearchVector(ctx, params, modules) | |
| if err != nil { | |
| return nil, err | |
| } | |
| for _, subsearch := range ss.([]searchparams.WeightedSearchResult) { | |
| res, name, weight, err := handleSubSearch(ctx, &subsearch, denseSearch, sparseSearch, params, modules) | |
| if err != nil { | |
| return nil, err | |
| } | |
| if res == nil { | |
| continue | |
| } | |
| found = append(found, res) | |
| weights = append(weights, weight) | |
| names = append(names, name) | |
| } | |
| } | |
| if len(weights) != len(found) { | |
| return nil, fmt.Errorf("length of weights and results do not match for hybrid search %v vs. %v", len(weights), len(found)) | |
| } | |
| var fused []*Result | |
| if params.FusionAlgorithm == common_filters.HybridRankedFusion { | |
| fused = FusionRanked(weights, found, names) | |
| } else if params.FusionAlgorithm == common_filters.HybridRelativeScoreFusion { | |
| fused = FusionRelativeScore(weights, found, names) | |
| } else { | |
| return nil, fmt.Errorf("unknown ranking algorithm %v for hybrid search", params.FusionAlgorithm) | |
| } | |
| if postProc != nil { | |
| sr, err := postProc(fused) | |
| if err != nil { | |
| return nil, fmt.Errorf("hybrid search post-processing: %w", err) | |
| } | |
| fused = fused[:len(sr)] | |
| for i := range fused { | |
| fused[i].Result = &(sr[i]) | |
| } | |
| } | |
| if params.Autocut > 0 { | |
| scores := make([]float32, len(fused)) | |
| for i := range fused { | |
| scores[i] = fused[i].Score | |
| } | |
| cutOff := autocut.Autocut(scores, params.Autocut) | |
| fused = fused[:cutOff] | |
| } | |
| return fused, nil | |
| } | |
| func processSparseSearch(results []*storobj.Object, weights []float32, err error) ([]*Result, error) { | |
| if err != nil { | |
| return nil, fmt.Errorf("sparse search: %w", err) | |
| } | |
| out := make([]*Result, len(results)) | |
| for i, obj := range results { | |
| sr := obj.SearchResultWithDist(additional.Properties{}, weights[i]) | |
| sr.SecondarySortValue = sr.Score | |
| sr.ExplainScore = "(bm25)" + sr.ExplainScore | |
| out[i] = &Result{obj.DocID(), &sr} | |
| } | |
| return out, nil | |
| } | |
| func processDenseSearch(ctx context.Context, denseSearch denseSearchFunc, params *Params, modules modulesProvider) ([]*Result, error) { | |
| vector, err := decideSearchVector(ctx, params, modules) | |
| if err != nil { | |
| return nil, err | |
| } | |
| res, dists, err := denseSearch(vector) | |
| if err != nil { | |
| return nil, fmt.Errorf("dense search: %w", err) | |
| } | |
| out := make([]*Result, len(res)) | |
| for i, obj := range res { | |
| sr := obj.SearchResultWithDist(additional.Properties{}, dists[i]) | |
| sr.SecondarySortValue = 1 - sr.Dist | |
| sr.ExplainScore = fmt.Sprintf( | |
| "(vector) %v %v ", truncateVectorString(10, vector), | |
| res[i].ExplainScore()) | |
| out[i] = &Result{obj.DocID(), &sr} | |
| } | |
| return out, nil | |
| } | |
| func handleSubSearch(ctx context.Context, subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc, sparseSearch sparseSearchFunc, params *Params, modules modulesProvider) ([]*Result, string, float64, error) { | |
| switch subsearch.Type { | |
| case "bm25": | |
| fallthrough | |
| case "sparseSearch": | |
| return sparseSubSearch(subsearch, params, sparseSearch) | |
| case "nearText": | |
| return nearTextSubSearch(ctx, subsearch, denseSearch, params, modules) | |
| case "nearVector": | |
| return nearVectorSubSearch(subsearch, denseSearch) | |
| default: | |
| return nil, "unknown", 0, fmt.Errorf("unknown hybrid search type %q", subsearch.Type) | |
| } | |
| } | |
| func sparseSubSearch(subsearch *searchparams.WeightedSearchResult, params *Params, sparseSearch sparseSearchFunc) ([]*Result, string, float64, error) { | |
| sp := subsearch.SearchParams.(searchparams.KeywordRanking) | |
| params.Keyword = &sp | |
| res, dists, err := sparseSearch() | |
| if err != nil { | |
| return nil, "", 0, fmt.Errorf("sparse subsearch: %w", err) | |
| } | |
| out := make([]*Result, len(res)) | |
| for i, obj := range res { | |
| sr := obj.SearchResultWithDist(additional.Properties{}, dists[i]) | |
| out[i] = &Result{obj.DocID(), &sr} | |
| } | |
| return out, "bm25f", subsearch.Weight, nil | |
| } | |
| func nearTextSubSearch(ctx context.Context, subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc, params *Params, modules modulesProvider) ([]*Result, string, float64, error) { | |
| sp := subsearch.SearchParams.(searchparams.NearTextParams) | |
| if modules == nil { | |
| return nil, "", 0, nil | |
| } | |
| vector, err := vectorFromModuleInput(ctx, params.Class, sp.Values[0], modules) | |
| if err != nil { | |
| return nil, "", 0, err | |
| } | |
| res, dists, err := denseSearch(vector) | |
| if err != nil { | |
| return nil, "", 0, err | |
| } | |
| out := make([]*Result, len(res)) | |
| for i, obj := range res { | |
| sr := obj.SearchResultWithDist(additional.Properties{}, dists[i]) | |
| out[i] = &Result{obj.DocID(), &sr} | |
| } | |
| return out, "vector,nearText", subsearch.Weight, nil | |
| } | |
| func nearVectorSubSearch(subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc) ([]*Result, string, float64, error) { | |
| sp := subsearch.SearchParams.(searchparams.NearVector) | |
| res, dists, err := denseSearch(sp.Vector) | |
| if err != nil { | |
| return nil, "", 0, err | |
| } | |
| out := make([]*Result, len(res)) | |
| for i, obj := range res { | |
| sr := obj.SearchResultWithDist(additional.Properties{}, dists[i]) | |
| out[i] = &Result{obj.DocID(), &sr} | |
| } | |
| return out, "vector,nearVector", subsearch.Weight, nil | |
| } | |
| func decideSearchVector(ctx context.Context, params *Params, modules modulesProvider) ([]float32, error) { | |
| var ( | |
| vector []float32 | |
| err error | |
| ) | |
| if params.Vector != nil && len(params.Vector) != 0 { | |
| vector = params.Vector | |
| } else { | |
| if modules != nil { | |
| vector, err = vectorFromModuleInput(ctx, params.Class, params.Query, modules) | |
| if err != nil { | |
| return nil, err | |
| } | |
| } | |
| } | |
| return vector, nil | |
| } | |
| func vectorFromModuleInput(ctx context.Context, class, input string, modules modulesProvider) ([]float32, error) { | |
| vector, err := modules.VectorFromInput(ctx, class, input) | |
| if err != nil { | |
| return nil, fmt.Errorf("get vector input from modules provider: %w", err) | |
| } | |
| return vector, nil | |
| } | |
| func truncateVectorString(maxLength int, vector []float32) string { | |
| if len(vector) <= maxLength { | |
| return fmt.Sprintf("%v", vector) | |
| } | |
| return fmt.Sprintf("%v...", vector[:maxLength]) | |
| } | |