SemanticSearchPOC / adapters /repos /db /classification_integration_test.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
13.5 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
//go:build integrationTest
// +build integrationTest
package db
import (
"context"
"fmt"
"testing"
"github.com/go-openapi/strfmt"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/filters"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/search"
enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
"github.com/weaviate/weaviate/usecases/classification"
)
func TestClassifications(t *testing.T) {
dirName := t.TempDir()
logger := logrus.New()
schemaGetter := &fakeSchemaGetter{
schema: schema.Schema{Objects: &models.Schema{Classes: nil}},
shardState: singleShardState(),
}
repo, err := New(logger, Config{
MemtablesFlushIdleAfter: 60,
RootPath: dirName,
QueryMaximumResults: 10000,
MaxImportGoroutinesFactor: 1,
}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil)
require.Nil(t, err)
repo.SetSchemaGetter(schemaGetter)
require.Nil(t, repo.WaitForStartup(testCtx()))
defer repo.Shutdown(context.Background())
migrator := NewMigrator(repo, logger)
t.Run("importing classification schema", func(t *testing.T) {
for _, class := range classificationTestSchema() {
err := migrator.AddClass(context.Background(), class, schemaGetter.shardState)
require.Nil(t, err)
}
})
// update schema getter so it's in sync with class
schemaGetter.schema = schema.Schema{Objects: &models.Schema{Classes: classificationTestSchema()}}
t.Run("importing categories", func(t *testing.T) {
for _, res := range classificationTestCategories() {
thing := res.Object()
err := repo.PutObject(context.Background(), thing, res.Vector, nil)
require.Nil(t, err)
}
})
t.Run("importing articles", func(t *testing.T) {
for _, res := range classificationTestArticles() {
thing := res.Object()
err := repo.PutObject(context.Background(), thing, res.Vector, nil)
require.Nil(t, err)
}
})
t.Run("finding all unclassified (no filters)", func(t *testing.T) {
res, err := repo.GetUnclassified(context.Background(),
"Article", []string{"exactCategory", "mainCategory"}, nil)
require.Nil(t, err)
require.Len(t, res, 6)
})
t.Run("finding all unclassified (with filters)", func(t *testing.T) {
filter := &filters.LocalFilter{
Root: &filters.Clause{
Operator: filters.OperatorEqual,
On: &filters.Path{
Property: "description",
},
Value: &filters.Value{
Value: "johnny",
Type: schema.DataTypeText,
},
},
}
res, err := repo.GetUnclassified(context.Background(),
"Article", []string{"exactCategory", "mainCategory"}, filter)
require.Nil(t, err)
require.Len(t, res, 1)
assert.Equal(t, strfmt.UUID("a2bbcbdc-76e1-477d-9e72-a6d2cfb50109"), res[0].ID)
})
t.Run("aggregating over item neighbors", func(t *testing.T) {
t.Run("close to politics (no filters)", func(t *testing.T) {
res, err := repo.AggregateNeighbors(context.Background(),
[]float32{0.7, 0.01, 0.01}, "Article",
[]string{"exactCategory", "mainCategory"}, 1, nil)
expectedRes := []classification.NeighborRef{
{
Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryPolitics)),
Property: "exactCategory",
OverallCount: 1,
WinningCount: 1,
LosingCount: 0,
Distances: classification.NeighborRefDistances{
MeanWinningDistance: 0.00010201335,
ClosestWinningDistance: 0.00010201335,
ClosestOverallDistance: 0.00010201335,
},
},
{
Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryPoliticsAndSociety)),
Property: "mainCategory",
OverallCount: 1,
WinningCount: 1,
LosingCount: 0,
Distances: classification.NeighborRefDistances{
MeanWinningDistance: 0.00010201335,
ClosestWinningDistance: 0.00010201335,
ClosestOverallDistance: 0.00010201335,
},
},
}
require.Nil(t, err)
assert.ElementsMatch(t, expectedRes, res)
})
t.Run("close to food and drink (no filters)", func(t *testing.T) {
res, err := repo.AggregateNeighbors(context.Background(),
[]float32{0.01, 0.01, 0.66}, "Article",
[]string{"exactCategory", "mainCategory"}, 1, nil)
expectedRes := []classification.NeighborRef{
{
Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryFoodAndDrink)),
Property: "exactCategory",
OverallCount: 1,
WinningCount: 1,
LosingCount: 0,
Distances: classification.NeighborRefDistances{
MeanWinningDistance: 0.00011473894,
ClosestWinningDistance: 0.00011473894,
ClosestOverallDistance: 0.00011473894,
},
},
{
Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryFoodAndDrink)),
Property: "mainCategory",
OverallCount: 1,
WinningCount: 1,
LosingCount: 0,
Distances: classification.NeighborRefDistances{
MeanWinningDistance: 0.00011473894,
ClosestWinningDistance: 0.00011473894,
ClosestOverallDistance: 0.00011473894,
},
},
}
require.Nil(t, err)
assert.ElementsMatch(t, expectedRes, res)
})
t.Run("close to food and drink (but limiting to politics through filter)", func(t *testing.T) {
filter := &filters.LocalFilter{
Root: &filters.Clause{
On: &filters.Path{
Property: "description",
},
Value: &filters.Value{
Value: "politics",
Type: schema.DataTypeText,
},
Operator: filters.OperatorEqual,
},
}
res, err := repo.AggregateNeighbors(context.Background(),
[]float32{0.01, 0.01, 0.66}, "Article",
[]string{"exactCategory", "mainCategory"}, 1, filter)
expectedRes := []classification.NeighborRef{
{
Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryPolitics)),
Property: "exactCategory",
OverallCount: 1,
WinningCount: 1,
LosingCount: 0,
Distances: classification.NeighborRefDistances{
MeanWinningDistance: 0.49242598,
ClosestWinningDistance: 0.49242598,
ClosestOverallDistance: 0.49242598,
},
},
{
Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryPoliticsAndSociety)),
Property: "mainCategory",
OverallCount: 1,
WinningCount: 1,
LosingCount: 0,
Distances: classification.NeighborRefDistances{
MeanWinningDistance: 0.49242598,
ClosestWinningDistance: 0.49242598,
ClosestOverallDistance: 0.49242598,
},
},
}
require.Nil(t, err)
assert.ElementsMatch(t, expectedRes, res)
})
})
}
// test fixtures
func classificationTestSchema() []*models.Class {
return []*models.Class{
{
Class: "ExactCategory",
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
Properties: []*models.Property{
{
Name: "name",
DataType: schema.DataTypeText.PropString(),
Tokenization: models.PropertyTokenizationWhitespace,
},
},
},
{
Class: "MainCategory",
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
Properties: []*models.Property{
{
Name: "name",
DataType: schema.DataTypeText.PropString(),
Tokenization: models.PropertyTokenizationWhitespace,
},
},
},
{
Class: "Article",
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
Properties: []*models.Property{
{
Name: "description",
DataType: []string{string(schema.DataTypeText)},
Tokenization: "word",
},
{
Name: "name",
DataType: schema.DataTypeText.PropString(),
Tokenization: models.PropertyTokenizationWhitespace,
},
{
Name: "exactCategory",
DataType: []string{"ExactCategory"},
},
{
Name: "mainCategory",
DataType: []string{"MainCategory"},
},
},
},
}
}
const (
idMainCategoryPoliticsAndSociety = "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e"
idMainCategoryFoodAndDrink = "5a3d909a-4f0d-4168-8f5c-cd3074d1e79a"
idCategoryPolitics = "1b204f16-7da6-44fd-bbd2-8cc4a7414bc3"
idCategorySociety = "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2"
idCategoryFoodAndDrink = "027b708a-31ca-43ea-9001-88bec864c79c"
)
func beaconRef(target string) *models.SingleRef {
beacon := fmt.Sprintf("weaviate://localhost/%s", target)
return &models.SingleRef{Beacon: strfmt.URI(beacon)}
}
func classificationTestCategories() search.Results {
// using search.Results, because it's the perfect grouping of object and
// vector
return search.Results{
// exact categories
search.Result{
ID: idCategoryPolitics,
ClassName: "ExactCategory",
Vector: []float32{1, 0, 0},
Schema: map[string]interface{}{
"name": "Politics",
},
},
search.Result{
ID: idCategorySociety,
ClassName: "ExactCategory",
Vector: []float32{0, 1, 0},
Schema: map[string]interface{}{
"name": "Society",
},
},
search.Result{
ID: idCategoryFoodAndDrink,
ClassName: "ExactCategory",
Vector: []float32{0, 0, 1},
Schema: map[string]interface{}{
"name": "Food and Drink",
},
},
// main categories
search.Result{
ID: idMainCategoryPoliticsAndSociety,
ClassName: "MainCategory",
Vector: []float32{0, 1, 0},
Schema: map[string]interface{}{
"name": "Politics and Society",
},
},
search.Result{
ID: idMainCategoryFoodAndDrink,
ClassName: "MainCategory",
Vector: []float32{0, 0, 1},
Schema: map[string]interface{}{
"name": "Food and Drink",
},
},
}
}
func classificationTestArticles() search.Results {
// using search.Results, because it's the perfect grouping of object and
// vector
return search.Results{
// classified
search.Result{
ID: "8aeecd06-55a0-462c-9853-81b31a284d80",
ClassName: "Article",
Vector: []float32{1, 0, 0},
Schema: map[string]interface{}{
"description": "This article talks about politics",
"exactCategory": models.MultipleRef{beaconRef(idCategoryPolitics)},
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)},
},
},
search.Result{
ID: "9f4c1847-2567-4de7-8861-34cf47a071ae",
ClassName: "Article",
Vector: []float32{0, 1, 0},
Schema: map[string]interface{}{
"description": "This articles talks about society",
"exactCategory": models.MultipleRef{beaconRef(idCategorySociety)},
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)},
},
},
search.Result{
ID: "926416ec-8fb1-4e40-ab8c-37b226b3d68e",
ClassName: "Article",
Vector: []float32{0, 0, 1},
Schema: map[string]interface{}{
"description": "This article talks about food",
"exactCategory": models.MultipleRef{beaconRef(idCategoryFoodAndDrink)},
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryFoodAndDrink)},
},
},
// unclassified
search.Result{
ID: "75ba35af-6a08-40ae-b442-3bec69b355f9",
ClassName: "Article",
Vector: []float32{0.78, 0, 0},
Schema: map[string]interface{}{
"description": "Barack Obama is a former US president",
},
},
search.Result{
ID: "f850439a-d3cd-4f17-8fbf-5a64405645cd",
ClassName: "Article",
Vector: []float32{0.90, 0, 0},
Schema: map[string]interface{}{
"description": "Michelle Obama is Barack Obamas wife",
},
},
search.Result{
ID: "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109",
ClassName: "Article",
Vector: []float32{0, 0.78, 0},
Schema: map[string]interface{}{
"description": "Johnny Depp is an actor",
},
},
search.Result{
ID: "069410c3-4b9e-4f68-8034-32a066cb7997",
ClassName: "Article",
Vector: []float32{0, 0.90, 0},
Schema: map[string]interface{}{
"description": "Brad Pitt starred in a Quentin Tarantino movie",
},
},
search.Result{
ID: "06a1e824-889c-4649-97f9-1ed3fa401d8e",
ClassName: "Article",
Vector: []float32{0, 0, 0.78},
Schema: map[string]interface{}{
"description": "Ice Cream often contains a lot of sugar",
},
},
search.Result{
ID: "6402e649-b1e0-40ea-b192-a64eab0d5e56",
ClassName: "Article",
Vector: []float32{0, 0, 0.90},
Schema: map[string]interface{}{
"description": "French Fries are more common in Belgium and the US than in France",
},
},
}
}