KevinStephenson
Adding in weaviate code
b110593
raw
history blame
15.4 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package hybrid
import (
"context"
"testing"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/searchparams"
"github.com/weaviate/weaviate/entities/storobj"
)
func TestSearcher(t *testing.T) {
ctx := context.Background()
logger, _ := test.NewNullLogger()
class := "HybridClass"
tests := []struct {
name string
f func(t *testing.T)
}{
{
name: "with module provider",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
Alpha: 0.5,
Query: "some query",
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
provider := &fakeModuleProvider{}
provider.On("VectorFromInput", ctx, class, params.Query).Return([]float32{1, 2, 3}, nil)
_, err := Search(ctx, params, logger, sparse, dense, nil, provider)
require.Nil(t, err)
},
},
{
name: "without module provider",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
Alpha: 0.5,
Query: "some query",
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
_, err := Search(ctx, params, logger, sparse, dense, nil, nil)
require.Nil(t, err)
},
},
{
name: "with sparse search only",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
Alpha: 0,
Query: "some query",
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "1889a225-3b28-477d-b8fc-5f6071bb4731",
Properties: map[string]any{"prop": "val"},
Vector: []float32{1, 2, 3},
},
Vector: []float32{1, 2, 3},
},
}, []float32{0.008}, nil
}
dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
res, err := Search(ctx, params, logger, sparse, dense, nil, nil)
require.Nil(t, err)
assert.Len(t, res, 1)
assert.NotNil(t, res[0])
assert.Contains(t, res[0].Result.ExplainScore, "(bm25)")
assert.Contains(t, res[0].Result.ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Equal(t, res[0].Result.Vector, []float32{1, 2, 3})
assert.Equal(t, res[0].Result.Dist, float32(0.008))
},
},
{
name: "with dense search only",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
Alpha: 1,
Query: "some query",
Vector: []float32{1, 2, 3},
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
dense := func([]float32) ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "1889a225-3b28-477d-b8fc-5f6071bb4731",
Properties: map[string]any{"prop": "val"},
Vector: []float32{1, 2, 3},
},
Vector: []float32{1, 2, 3},
},
}, []float32{0.008}, nil
}
res, err := Search(ctx, params, logger, sparse, dense, nil, nil)
require.Nil(t, err)
assert.Len(t, res, 1)
assert.NotNil(t, res[0])
assert.Contains(t, res[0].Result.ExplainScore, "(vector)")
assert.Contains(t, res[0].Result.ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Equal(t, res[0].Result.Vector, []float32{1, 2, 3})
assert.Equal(t, res[0].Result.Dist, float32(0.008))
},
},
{
name: "combined hybrid search",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
Alpha: 0.5,
Query: "some query",
Vector: []float32{1, 2, 3},
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "1889a225-3b28-477d-b8fc-5f6071bb4731",
Properties: map[string]any{"prop": "val"},
Vector: []float32{1, 2, 3},
},
Vector: []float32{1, 2, 3},
},
}, []float32{0.008}, nil
}
dense := func([]float32) ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "79a636c2-3314-442e-a4d1-e94d7c0afc3a",
Properties: map[string]any{"prop": "val"},
Vector: []float32{4, 5, 6},
},
Vector: []float32{4, 5, 6},
},
}, []float32{0.008}, nil
}
res, err := Search(ctx, params, logger, sparse, dense, nil, nil)
require.Nil(t, err)
assert.Len(t, res, 2)
assert.NotNil(t, res[0])
assert.NotNil(t, res[1])
assert.Contains(t, res[0].Result.ExplainScore, "(vector)")
assert.Contains(t, res[0].Result.ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a")
assert.Equal(t, res[0].Result.Vector, []float32{4, 5, 6})
assert.Equal(t, res[0].Result.Dist, float32(0.008))
assert.Contains(t, res[1].Result.ExplainScore, "(bm25)")
assert.Contains(t, res[1].Result.ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Equal(t, res[1].Result.Vector, []float32{1, 2, 3})
assert.Equal(t, res[1].Result.Dist, float32(0.008))
},
},
{
name: "with sparse subsearch filter",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
SubSearches: []searchparams.WeightedSearchResult{
{
Type: "sparseSearch",
SearchParams: searchparams.KeywordRanking{
Type: "bm25",
Properties: []string{"propA", "propB"},
Query: "some query",
},
},
},
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "1889a225-3b28-477d-b8fc-5f6071bb4731",
Properties: map[string]any{"prop": "val"},
Vector: []float32{1, 2, 3},
Additional: map[string]interface{}{"score": float32(0.008)},
},
Vector: []float32{1, 2, 3},
},
}, []float32{0.008}, nil
}
dense := func([]float32) ([]*storobj.Object, []float32, error) {
return nil, nil, nil
}
res, err := Search(ctx, params, logger, sparse, dense, nil, nil)
require.Nil(t, err)
assert.Len(t, res, 1)
assert.NotNil(t, res[0])
assert.Contains(t, res[0].Result.ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Contains(t, res[0].Result.ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Equal(t, res[0].Result.Vector, []float32{1, 2, 3})
assert.Equal(t, res[0].Result.Dist, float32(0.008))
},
},
{
name: "with nearText subsearch filter",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
SubSearches: []searchparams.WeightedSearchResult{
{
Type: "nearText",
SearchParams: searchparams.NearTextParams{
Values: []string{"some query"},
Certainty: 0.8,
},
},
},
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) {
return nil, nil, nil
}
dense := func([]float32) ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "1889a225-3b28-477d-b8fc-5f6071bb4731",
Properties: map[string]any{"prop": "val"},
Vector: []float32{1, 2, 3},
Additional: map[string]interface{}{"score": float32(0.008)},
},
Vector: []float32{1, 2, 3},
},
}, []float32{0.008}, nil
}
provider := &fakeModuleProvider{}
provider.On("VectorFromInput", ctx, class,
params.HybridSearch.SubSearches.([]searchparams.WeightedSearchResult)[0].
SearchParams.(searchparams.NearTextParams).Values[0]).Return([]float32{1, 2, 3}, nil)
provider.On("VectorFromInput", ctx, class, "").Return([]float32{1, 2, 3}, nil)
res, err := Search(ctx, params, logger, sparse, dense, nil, provider)
require.Nil(t, err)
assert.Len(t, res, 1)
assert.NotNil(t, res[0])
assert.Contains(t, res[0].Result.ExplainScore, "(Result Set vector,nearText) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Contains(t, res[0].Result.ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Equal(t, res[0].Result.Vector, []float32{1, 2, 3})
assert.Equal(t, res[0].Result.Dist, float32(0.008))
},
},
{
name: "with nearVector subsearch filter",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
SubSearches: []searchparams.WeightedSearchResult{
{
Type: "nearVector",
SearchParams: searchparams.NearVector{
Vector: []float32{1, 2, 3},
Certainty: 0.8,
},
},
},
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) {
return nil, nil, nil
}
dense := func([]float32) ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "1889a225-3b28-477d-b8fc-5f6071bb4731",
Properties: map[string]any{"prop": "val"},
Vector: []float32{1, 2, 3},
Additional: map[string]interface{}{"score": float32(0.008)},
},
Vector: []float32{1, 2, 3},
},
}, []float32{0.008}, nil
}
provider := &fakeModuleProvider{}
provider.On("VectorFromInput", ctx, class,
params.HybridSearch.SubSearches.([]searchparams.WeightedSearchResult)[0].
SearchParams.(searchparams.NearVector).Vector).Return([]float32{1, 2, 3}, nil)
provider.On("VectorFromInput", ctx, class, "").Return([]float32{1, 2, 3}, nil)
res, err := Search(ctx, params, logger, sparse, dense, nil, provider)
require.Nil(t, err)
assert.Len(t, res, 1)
assert.NotNil(t, res[0])
assert.Contains(t, res[0].Result.ExplainScore, "(Result Set vector,nearVector) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Contains(t, res[0].Result.ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Equal(t, res[0].Result.Vector, []float32{1, 2, 3})
assert.Equal(t, res[0].Result.Dist, float32(0.008))
},
},
{
name: "with all subsearch filters",
f: func(t *testing.T) {
params := &Params{
HybridSearch: &searchparams.HybridSearch{
Type: "hybrid",
SubSearches: []searchparams.WeightedSearchResult{
{
Type: "nearVector",
SearchParams: searchparams.NearVector{
Vector: []float32{1, 2, 3},
Certainty: 0.8,
},
Weight: 100,
},
{
Type: "nearText",
SearchParams: searchparams.NearTextParams{
Values: []string{"some query"},
Certainty: 0.8,
},
Weight: 2,
},
{
Type: "sparseSearch",
SearchParams: searchparams.KeywordRanking{
Type: "bm25",
Properties: []string{"propA", "propB"},
Query: "some query",
},
Weight: 3,
},
},
},
Class: class,
}
sparse := func() ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "1889a225-3b28-477d-b8fc-5f6071bb4731",
Properties: map[string]any{"prop": "val"},
Vector: []float32{1, 2, 3},
Additional: map[string]interface{}{"score": float32(0.008)},
},
Vector: []float32{1, 2, 3},
},
}, []float32{0.008}, nil
}
dense := func([]float32) ([]*storobj.Object, []float32, error) {
return []*storobj.Object{
{
Object: models.Object{
Class: class,
ID: "79a636c2-3314-442e-a4d1-e94d7c0afc3a",
Properties: map[string]any{"prop": "val"},
Vector: []float32{4, 5, 6},
Additional: map[string]interface{}{"score": float32(0.8)},
},
Vector: []float32{4, 5, 6},
},
}, []float32{0.008}, nil
}
provider := &fakeModuleProvider{}
provider.On("VectorFromInput", ctx, class,
params.HybridSearch.SubSearches.([]searchparams.WeightedSearchResult)[0].
SearchParams.(searchparams.NearVector).Vector).Return([]float32{1, 2, 3}, nil)
provider.On("VectorFromInput", ctx, class,
params.HybridSearch.SubSearches.([]searchparams.WeightedSearchResult)[1].
SearchParams.(searchparams.NearTextParams).Values[0]).Return([]float32{1, 2, 3}, nil)
provider.On("VectorFromInput", ctx, class, "").Return([]float32{1, 2, 3}, nil)
res, err := Search(ctx, params, logger, sparse, dense, nil, provider)
require.Nil(t, err)
assert.Len(t, res, 2)
assert.NotNil(t, res[0])
assert.NotNil(t, res[1])
assert.Contains(t, res[0].Result.ExplainScore, "(Result Set vector,nearVector) Document 79a636c2-3314-442e-a4d1-e94d7c0afc3a")
assert.Contains(t, res[0].Result.ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a")
assert.Equal(t, res[0].Result.Vector, []float32{4, 5, 6})
assert.Equal(t, res[0].Result.Dist, float32(0.008))
assert.Contains(t, res[1].Result.ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Contains(t, res[1].Result.ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
assert.Equal(t, res[1].Result.Vector, []float32{1, 2, 3})
assert.Equal(t, res[1].Result.Dist, float32(0.008))
},
},
}
for _, test := range tests {
t.Run(test.name, test.f)
}
}
type fakeModuleProvider struct {
mock.Mock
}
func (f *fakeModuleProvider) VectorFromInput(ctx context.Context,
className string, input string,
) ([]float32, error) {
args := f.Called(ctx, className, input)
return args.Get(0).([]float32), nil
}