KevinStephenson
Adding in weaviate code
b110593
raw
history blame
9.94 kB
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package hybrid
import (
"context"
"fmt"
"github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters"
"github.com/weaviate/weaviate/entities/autocut"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/search"
"github.com/weaviate/weaviate/entities/searchparams"
"github.com/weaviate/weaviate/entities/storobj"
)
// DefaultLimit is not referenced in this file.
// NOTE(review): presumably the default number of results requested by callers
// of hybrid search when no explicit limit is given — confirm against callers.
const DefaultLimit = 100
// Params bundles the inputs of a hybrid search.
type Params struct {
// The embedded HybridSearch promotes its fields onto Params; this file reads
// Query, Alpha, Vector, SubSearches and FusionAlgorithm from it.
*searchparams.HybridSearch
// Keyword holds keyword-ranking parameters. sparseSubSearch overwrites it
// with the parameters of the sub-search it is about to execute.
Keyword *searchparams.KeywordRanking
// Class is the class name passed to the modules provider when a search
// vector has to be derived from text input.
Class string
// Autocut, when greater than zero, is handed to autocut.Autocut to decide
// where the fused result list is truncated.
Autocut int
}
// Result facilitates the pairing of a search result with its internal doc id.
//
// This type is key in generalising hybrid search across different use cases.
// Some use cases require a full search result (Get{} queries) and others need
// only a doc id (Aggregate{}) which the search.Result type does not contain.
type Result struct {
// DocID is the internal document id of the object the result refers to.
DocID uint64
// The embedded search result; its fields (e.g. Score) are promoted onto Result.
*search.Result
}
// Results is a list of hybrid search results, each paired with its doc id.
type Results []*Result
// SearchResults returns a copy of the search.Result embedded in every entry,
// in the same order, dropping the doc ids. Every entry and its embedded result
// must be non-nil.
func (res Results) SearchResults() []search.Result {
	plain := make([]search.Result, 0, len(res))
	for idx := range res {
		plain = append(plain, *res[idx].Result)
	}
	return plain
}
// sparseSearchFunc is the signature of a closure which performs sparse search.
// Any package which wishes to use hybrid search must provide this. The weights
// are used in calculating the final scores of the result set.
type sparseSearchFunc func() (results []*storobj.Object, weights []float32, err error)
// denseSearchFunc is the signature of a closure which performs dense search.
// A search vector argument is required to pass along to the vector index.
// Any package which wishes to use hybrid search must provide this. The weights
// are used in calculating the final scores of the result set.
type denseSearchFunc func(searchVector []float32) (results []*storobj.Object, weights []float32, err error)
// postProcFunc takes the results of the hybrid search and applies some transformation.
// It is optional and allows the caller to change the nature of the result set.
// For example, Get{} queries sometimes require resolving references, which is
// implemented by doing the reference resolution within a postProcFunc closure.
type postProcFunc func(hybridResults Results) (postProcResults []search.Result, err error)
// modulesProvider is the subset of the modules layer that hybrid search needs:
// turning a text input into a search vector.
type modulesProvider interface {
// VectorFromInput returns the vector for input in the context of the class
// named className.
VectorFromInput(ctx context.Context,
className string, input string) ([]float32, error)
}
// Search runs the sparse (keyword) and dense (vector) searches selected by
// params and fuses their result sets into one ranked list.
//
// When params.Query is set, params.Alpha decides which searches run: the sparse
// search runs for alpha < 1 (with weight 1-alpha) and the dense search runs for
// alpha > 0 (with weight alpha). When params.Query is empty, every entry of
// params.SubSearches is executed with its own weight instead.
//
// The result sets are combined with the algorithm selected by
// params.FusionAlgorithm (ranked fusion or relative score fusion); any other
// value is an error. If postProc is non-nil it is applied to the fused list,
// its output replaces the embedded search results and the list is truncated to
// the number of results it returned. Finally, if params.Autocut > 0, the list
// is cut at the offset computed by autocut.Autocut from the fused scores.
//
// Missing parameters, sub-searches of an unexpected type and a postProc that
// returns more results than it was given are reported as errors instead of
// panicking. The logger argument is currently unused.
func Search(ctx context.Context, params *Params, logger logrus.FieldLogger, sparseSearch sparseSearchFunc, denseSearch denseSearchFunc, postProc postProcFunc, modules modulesProvider) (Results, error) {
	if params == nil || params.HybridSearch == nil {
		return nil, fmt.Errorf("hybrid search: missing search parameters")
	}

	var (
		found   [][]*Result
		weights []float64
		names   []string
	)

	if params.Query != "" {
		alpha := params.Alpha
		if alpha < 1 {
			res, err := processSparseSearch(sparseSearch())
			if err != nil {
				return nil, err
			}
			found = append(found, res)
			weights = append(weights, 1-alpha)
			names = append(names, "keyword")
		}
		if alpha > 0 {
			res, err := processDenseSearch(ctx, denseSearch, params, modules)
			if err != nil {
				return nil, err
			}
			found = append(found, res)
			weights = append(weights, alpha)
			names = append(names, "vector")
		}
	} else {
		// Resolving the search vector up front surfaces vectorizer errors
		// before any sub-search is executed; the vector itself is not needed.
		if _, err := decideSearchVector(ctx, params, modules); err != nil {
			return nil, err
		}

		// SubSearches is an untyped value; a bare assertion would panic on nil
		// or on an unexpected type.
		subSearches, ok := params.SubSearches.([]searchparams.WeightedSearchResult)
		if !ok {
			return nil, fmt.Errorf("hybrid search: unexpected sub-searches type %T", params.SubSearches)
		}

		for i := range subSearches {
			res, name, weight, err := handleSubSearch(ctx, &subSearches[i], denseSearch, sparseSearch, params, modules)
			if err != nil {
				return nil, err
			}
			// A nil result set means the sub-search was skipped.
			if res == nil {
				continue
			}
			found = append(found, res)
			weights = append(weights, weight)
			names = append(names, name)
		}
	}

	if len(weights) != len(found) {
		return nil, fmt.Errorf("length of weights and results do not match for hybrid search %v vs. %v", len(weights), len(found))
	}

	var fused []*Result
	switch params.FusionAlgorithm {
	case common_filters.HybridRankedFusion:
		fused = FusionRanked(weights, found, names)
	case common_filters.HybridRelativeScoreFusion:
		fused = FusionRelativeScore(weights, found, names)
	default:
		return nil, fmt.Errorf("unknown ranking algorithm %v for hybrid search", params.FusionAlgorithm)
	}

	if postProc != nil {
		sr, err := postProc(fused)
		if err != nil {
			return nil, fmt.Errorf("hybrid search post-processing: %w", err)
		}
		// Re-slicing beyond len(fused) would either panic or expose stale
		// entries of the backing array, so reject a grown result set.
		if len(sr) > len(fused) {
			return nil, fmt.Errorf("hybrid search post-processing: returned %d results for %d inputs", len(sr), len(fused))
		}
		fused = fused[:len(sr)]
		for i := range fused {
			fused[i].Result = &sr[i]
		}
	}

	if params.Autocut > 0 {
		scores := make([]float32, len(fused))
		for i := range fused {
			scores[i] = fused[i].Score
		}
		cutOff := autocut.Autocut(scores, params.Autocut)
		fused = fused[:cutOff]
	}

	return fused, nil
}
// processSparseSearch converts the raw output of a sparse search into hybrid
// results. Its parameters mirror the return values of a sparseSearchFunc so the
// call can be passed through directly.
//
// A non-nil err is wrapped and returned. Otherwise every object is turned into
// a search result using the weight at the same index, its score is copied into
// SecondarySortValue and its ExplainScore is prefixed with "(bm25)".
func processSparseSearch(results []*storobj.Object, weights []float32, err error) ([]*Result, error) {
	if err != nil {
		return nil, fmt.Errorf("sparse search: %w", err)
	}

	paired := make([]*Result, 0, len(results))
	for idx := range results {
		object := results[idx]
		hit := object.SearchResultWithDist(additional.Properties{}, weights[idx])
		hit.SecondarySortValue = hit.Score
		hit.ExplainScore = "(bm25)" + hit.ExplainScore
		paired = append(paired, &Result{DocID: object.DocID(), Result: &hit})
	}
	return paired, nil
}
// processDenseSearch resolves the search vector for params, runs denseSearch
// with it and converts the returned objects into hybrid results.
//
// Every object is turned into a search result using the distance at the same
// index; SecondarySortValue is set to 1 minus that result's Dist and
// ExplainScore combines a truncated rendering of the search vector with the
// object's own explanation.
//
// Errors from resolving the vector are returned as-is; errors from denseSearch
// are wrapped with a "dense search" prefix.
func processDenseSearch(ctx context.Context, denseSearch denseSearchFunc, params *Params, modules modulesProvider) ([]*Result, error) {
	vector, err := decideSearchVector(ctx, params, modules)
	if err != nil {
		return nil, err
	}

	res, dists, err := denseSearch(vector)
	if err != nil {
		return nil, fmt.Errorf("dense search: %w", err)
	}

	// The vector rendering is the same for every result, so format it once
	// instead of once per object.
	vectorStr := truncateVectorString(10, vector)

	out := make([]*Result, len(res))
	for i, obj := range res {
		sr := obj.SearchResultWithDist(additional.Properties{}, dists[i])
		sr.SecondarySortValue = 1 - sr.Dist
		sr.ExplainScore = fmt.Sprintf(
			"(vector) %v %v ", vectorStr,
			obj.ExplainScore())
		out[i] = &Result{obj.DocID(), &sr}
	}
	return out, nil
}
// handleSubSearch dispatches a single weighted sub-search by its Type:
// "bm25" and "sparseSearch" run the sparse search, "nearText" and "nearVector"
// run the dense search. It returns the results, the name of the result set,
// its weight and an error; an unrecognised Type is an error.
func handleSubSearch(ctx context.Context, subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc, sparseSearch sparseSearchFunc, params *Params, modules modulesProvider) ([]*Result, string, float64, error) {
	switch kind := subsearch.Type; kind {
	case "bm25", "sparseSearch":
		return sparseSubSearch(subsearch, params, sparseSearch)
	case "nearText":
		return nearTextSubSearch(ctx, subsearch, denseSearch, params, modules)
	case "nearVector":
		return nearVectorSubSearch(subsearch, denseSearch)
	}
	return nil, "unknown", 0, fmt.Errorf("unknown hybrid search type %q", subsearch.Type)
}
// sparseSubSearch runs the sparse search for a keyword sub-search.
//
// The sub-search's SearchParams must be a searchparams.KeywordRanking; any
// other type is reported as an error instead of panicking. As a side effect
// params.Keyword is pointed at a copy of those parameters before sparseSearch
// is invoked.
//
// It returns the results, the result-set name "bm25f", the sub-search weight
// and an error.
func sparseSubSearch(subsearch *searchparams.WeightedSearchResult, params *Params, sparseSearch sparseSearchFunc) ([]*Result, string, float64, error) {
	sp, ok := subsearch.SearchParams.(searchparams.KeywordRanking)
	if !ok {
		return nil, "", 0, fmt.Errorf("sparse subsearch: unexpected search params type %T", subsearch.SearchParams)
	}
	params.Keyword = &sp

	res, dists, err := sparseSearch()
	if err != nil {
		return nil, "", 0, fmt.Errorf("sparse subsearch: %w", err)
	}

	out := make([]*Result, len(res))
	for i, obj := range res {
		sr := obj.SearchResultWithDist(additional.Properties{}, dists[i])
		out[i] = &Result{obj.DocID(), &sr}
	}
	return out, "bm25f", subsearch.Weight, nil
}
// nearTextSubSearch runs the dense search for a nearText sub-search.
//
// The sub-search's SearchParams must be a searchparams.NearTextParams with at
// least one value; a wrong type or an empty value list is reported as an error
// instead of panicking. Only the first value is vectorized. When modules is
// nil the sub-search is skipped, signalled by a nil result set and a nil error.
//
// It returns the results, the result-set name "vector,nearText", the
// sub-search weight and an error.
func nearTextSubSearch(ctx context.Context, subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc, params *Params, modules modulesProvider) ([]*Result, string, float64, error) {
	sp, ok := subsearch.SearchParams.(searchparams.NearTextParams)
	if !ok {
		return nil, "", 0, fmt.Errorf("nearText subsearch: unexpected search params type %T", subsearch.SearchParams)
	}

	// Without a modules provider there is nothing to vectorize the text with.
	if modules == nil {
		return nil, "", 0, nil
	}

	if len(sp.Values) == 0 {
		return nil, "", 0, fmt.Errorf("nearText subsearch: no values to vectorize")
	}

	vector, err := vectorFromModuleInput(ctx, params.Class, sp.Values[0], modules)
	if err != nil {
		return nil, "", 0, err
	}

	res, dists, err := denseSearch(vector)
	if err != nil {
		return nil, "", 0, err
	}

	out := make([]*Result, len(res))
	for i, obj := range res {
		sr := obj.SearchResultWithDist(additional.Properties{}, dists[i])
		out[i] = &Result{obj.DocID(), &sr}
	}
	return out, "vector,nearText", subsearch.Weight, nil
}
// nearVectorSubSearch runs the dense search for a nearVector sub-search using
// the vector supplied in the sub-search parameters.
//
// The sub-search's SearchParams must be a searchparams.NearVector; any other
// type is reported as an error instead of panicking.
//
// It returns the results, the result-set name "vector,nearVector", the
// sub-search weight and an error.
func nearVectorSubSearch(subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc) ([]*Result, string, float64, error) {
	sp, ok := subsearch.SearchParams.(searchparams.NearVector)
	if !ok {
		return nil, "", 0, fmt.Errorf("nearVector subsearch: unexpected search params type %T", subsearch.SearchParams)
	}

	res, dists, err := denseSearch(sp.Vector)
	if err != nil {
		return nil, "", 0, err
	}

	out := make([]*Result, len(res))
	for i, obj := range res {
		sr := obj.SearchResultWithDist(additional.Properties{}, dists[i])
		out[i] = &Result{obj.DocID(), &sr}
	}
	return out, "vector,nearVector", subsearch.Weight, nil
}
// decideSearchVector picks the vector to use for a dense search.
//
// A non-empty params.Vector takes precedence. Otherwise, if a modules provider
// is available, params.Query is vectorized for params.Class. With neither, a
// nil vector and a nil error are returned.
func decideSearchVector(ctx context.Context, params *Params, modules modulesProvider) ([]float32, error) {
	if len(params.Vector) > 0 {
		return params.Vector, nil
	}
	if modules == nil {
		return nil, nil
	}

	fromModule, moduleErr := vectorFromModuleInput(ctx, params.Class, params.Query, modules)
	if moduleErr != nil {
		return nil, moduleErr
	}
	return fromModule, nil
}
// vectorFromModuleInput asks the modules provider to vectorize input for the
// given class. Provider errors are wrapped; modules must be non-nil.
func vectorFromModuleInput(ctx context.Context, class, input string, modules modulesProvider) ([]float32, error) {
	embedded, provErr := modules.VectorFromInput(ctx, class, input)
	if provErr == nil {
		return embedded, nil
	}
	return nil, fmt.Errorf("get vector input from modules provider: %w", provErr)
}
// truncateVectorString renders vector with fmt's default slice formatting,
// showing at most maxLength elements. When elements are cut off, "..." is
// appended to mark the truncation. A negative maxLength is treated as 0 rather
// than causing an out-of-range slice.
func truncateVectorString(maxLength int, vector []float32) string {
	if maxLength < 0 {
		maxLength = 0
	}
	if len(vector) <= maxLength {
		return fmt.Sprintf("%v", vector)
	}
	return fmt.Sprintf("%v...", vector[:maxLength])
}