SemanticSearchPOC / adapters /repos /db /batch_reference_integration_test.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
13.1 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
//go:build integrationTest
package db
import (
"context"
"fmt"
"testing"
"github.com/go-openapi/strfmt"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/dto"
"github.com/weaviate/weaviate/entities/filters"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/schema/crossref"
enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
"github.com/weaviate/weaviate/usecases/objects"
)
func Test_AddingReferencesInBatches(t *testing.T) {
dirName := t.TempDir()
logger := logrus.New()
schemaGetter := &fakeSchemaGetter{
schema: schema.Schema{Objects: &models.Schema{Classes: nil}},
shardState: singleShardState(),
}
repo, err := New(logger, Config{
MemtablesFlushIdleAfter: 60,
RootPath: dirName,
QueryMaximumResults: 10000,
MaxImportGoroutinesFactor: 1,
}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil)
require.Nil(t, err)
repo.SetSchemaGetter(schemaGetter)
require.Nil(t, repo.WaitForStartup(testCtx()))
defer repo.Shutdown(context.Background())
migrator := NewMigrator(repo, logger)
s := schema.Schema{
Objects: &models.Schema{
Classes: []*models.Class{
{
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
Class: "AddingBatchReferencesTestTarget",
Properties: []*models.Property{
{
Name: "name",
DataType: schema.DataTypeText.PropString(),
Tokenization: models.PropertyTokenizationWhitespace,
},
},
},
{
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
Class: "AddingBatchReferencesTestSource",
Properties: []*models.Property{
{
Name: "name",
DataType: schema.DataTypeText.PropString(),
Tokenization: models.PropertyTokenizationWhitespace,
},
{
Name: "toTarget",
DataType: []string{"AddingBatchReferencesTestTarget"},
},
},
},
},
},
}
t.Run("add required classes", func(t *testing.T) {
for _, class := range s.Objects.Classes {
t.Run(fmt.Sprintf("add %s", class.Class), func(t *testing.T) {
err := migrator.AddClass(context.Background(), class, schemaGetter.shardState)
require.Nil(t, err)
})
}
})
schemaGetter.schema = s
target1 := strfmt.UUID("7b395e5c-cf4d-4297-b8cc-1d849a057de3")
target2 := strfmt.UUID("8f9f54f3-a7db-415e-881a-0e6fb79a7ec7")
target3 := strfmt.UUID("046251cf-cb02-4102-b854-c7c4691cf16f")
target4 := strfmt.UUID("bc7d8875-3a24-4137-8203-e152096dea4f")
sourceID := strfmt.UUID("a3c98a66-be4a-4eaf-8cf3-04648a11d0f7")
t.Run("add objects", func(t *testing.T) {
err := repo.PutObject(context.Background(), &models.Object{
ID: sourceID,
Class: "AddingBatchReferencesTestSource",
Properties: map[string]interface{}{
"name": "source item",
},
}, []float32{0.5}, nil)
require.Nil(t, err)
targets := []strfmt.UUID{target1, target2, target3, target4}
for i, target := range targets {
err = repo.PutObject(context.Background(), &models.Object{
ID: target,
Class: "AddingBatchReferencesTestTarget",
Properties: map[string]interface{}{
"name": fmt.Sprintf("target item %d", i),
},
}, []float32{0.7}, nil)
require.Nil(t, err)
}
})
t.Run("verify ref count through filters", func(t *testing.T) {
t.Run("count==0 should return the source", func(t *testing.T) {
filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt)
res, err := repo.Search(context.Background(), dto.GetParams{
Filters: filter,
ClassName: "AddingBatchReferencesTestSource",
Pagination: &filters.Pagination{
Limit: 10,
},
})
require.Nil(t, err)
require.Len(t, res, 1)
assert.Equal(t, res[0].ID, sourceID)
})
t.Run("count>0 should not return anything", func(t *testing.T) {
filter := buildFilter("toTarget", 0, gt, schema.DataTypeInt)
res, err := repo.Search(context.Background(), dto.GetParams{
Filters: filter,
ClassName: "AddingBatchReferencesTestSource",
Pagination: &filters.Pagination{
Limit: 10,
},
})
require.Nil(t, err)
require.Len(t, res, 0)
})
})
t.Run("add reference between them - first batch", func(t *testing.T) {
source, err := crossref.ParseSource(fmt.Sprintf(
"weaviate://localhost/AddingBatchReferencesTestSource/%s/toTarget",
sourceID))
require.Nil(t, err)
targets := []strfmt.UUID{target1, target2}
refs := make(objects.BatchReferences, len(targets))
for i, target := range targets {
to, err := crossref.Parse(fmt.Sprintf("weaviate://localhost/%s",
target))
require.Nil(t, err)
refs[i] = objects.BatchReference{
Err: nil,
From: source,
To: to,
OriginalIndex: i,
}
}
_, err = repo.AddBatchReferences(context.Background(), refs, nil)
assert.Nil(t, err)
})
t.Run("verify ref count through filters", func(t *testing.T) {
// so far we have imported two refs (!)
t.Run("count==2 should return the source", func(t *testing.T) {
filter := buildFilter("toTarget", 2, eq, schema.DataTypeInt)
res, err := repo.Search(context.Background(), dto.GetParams{
Filters: filter,
ClassName: "AddingBatchReferencesTestSource",
Pagination: &filters.Pagination{
Limit: 10,
},
})
require.Nil(t, err)
require.Len(t, res, 1)
assert.Equal(t, res[0].ID, sourceID)
})
t.Run("count==0 should not return anything", func(t *testing.T) {
filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt)
res, err := repo.Search(context.Background(), dto.GetParams{
Filters: filter,
ClassName: "AddingBatchReferencesTestSource",
Pagination: &filters.Pagination{
Limit: 10,
},
})
require.Nil(t, err)
require.Len(t, res, 0)
})
})
t.Run("add reference between them - second batch including errors", func(t *testing.T) {
source, err := crossref.ParseSource(fmt.Sprintf(
"weaviate://localhost/AddingBatchReferencesTestSource/%s/toTarget",
sourceID))
require.Nil(t, err)
sourceNonExistingClass, err := crossref.ParseSource(fmt.Sprintf(
"weaviate://localhost/NonExistingClass/%s/toTarget",
sourceID))
require.Nil(t, err)
sourceNonExistingProp, err := crossref.ParseSource(fmt.Sprintf(
"weaviate://localhost/AddingBatchReferencesTestSource/%s/nonExistingProp",
sourceID))
require.Nil(t, err)
targets := []strfmt.UUID{target3, target4}
refs := make(objects.BatchReferences, 3*len(targets))
for i, target := range targets {
to, err := crossref.Parse(fmt.Sprintf("weaviate://localhost/%s", target))
require.Nil(t, err)
refs[3*i] = objects.BatchReference{
Err: nil,
From: source,
To: to,
OriginalIndex: 3 * i,
}
refs[3*i+1] = objects.BatchReference{
Err: nil,
From: sourceNonExistingClass,
To: to,
OriginalIndex: 3*i + 1,
}
refs[3*i+2] = objects.BatchReference{
Err: nil,
From: sourceNonExistingProp,
To: to,
OriginalIndex: 3*i + 2,
}
}
batchRefs, err := repo.AddBatchReferences(context.Background(), refs, nil)
assert.Nil(t, err)
require.Len(t, batchRefs, 6)
assert.Nil(t, batchRefs[0].Err)
assert.Nil(t, batchRefs[3].Err)
assert.Contains(t, batchRefs[1].Err.Error(), "NonExistingClass")
assert.Contains(t, batchRefs[4].Err.Error(), "NonExistingClass")
assert.Contains(t, batchRefs[2].Err.Error(), "nonExistingProp")
assert.Contains(t, batchRefs[5].Err.Error(), "nonExistingProp")
})
t.Run("check all references are now present", func(t *testing.T) {
source, err := repo.ObjectByID(context.Background(), sourceID, nil, additional.Properties{}, "")
require.Nil(t, err)
refs := source.Object().Properties.(map[string]interface{})["toTarget"]
refsSlice, ok := refs.(models.MultipleRef)
require.True(t, ok, fmt.Sprintf("toTarget must be models.MultipleRef, but got %#v", refs))
foundBeacons := []string{}
for _, ref := range refsSlice {
foundBeacons = append(foundBeacons, ref.Beacon.String())
}
expectedBeacons := []string{
fmt.Sprintf("weaviate://localhost/%s", target1),
fmt.Sprintf("weaviate://localhost/%s", target2),
fmt.Sprintf("weaviate://localhost/%s", target3),
fmt.Sprintf("weaviate://localhost/%s", target4),
}
assert.ElementsMatch(t, foundBeacons, expectedBeacons)
})
t.Run("verify ref count through filters", func(t *testing.T) {
// so far we have imported two refs (!)
t.Run("count==4 should return the source", func(t *testing.T) {
filter := buildFilter("toTarget", 4, eq, schema.DataTypeInt)
res, err := repo.Search(context.Background(), dto.GetParams{
Filters: filter,
ClassName: "AddingBatchReferencesTestSource",
Pagination: &filters.Pagination{
Limit: 10,
},
})
require.Nil(t, err)
require.Len(t, res, 1)
assert.Equal(t, res[0].ID, sourceID)
})
t.Run("count==0 should not return anything", func(t *testing.T) {
filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt)
res, err := repo.Search(context.Background(), dto.GetParams{
Filters: filter,
ClassName: "AddingBatchReferencesTestSource",
Pagination: &filters.Pagination{
Limit: 10,
},
})
require.Nil(t, err)
require.Len(t, res, 0)
})
t.Run("count==2 should not return anything", func(t *testing.T) {
filter := buildFilter("toTarget", 2, eq, schema.DataTypeInt)
res, err := repo.Search(context.Background(), dto.GetParams{
Filters: filter,
ClassName: "AddingBatchReferencesTestSource",
Pagination: &filters.Pagination{
Limit: 10,
},
})
require.Nil(t, err)
require.Len(t, res, 0)
})
})
t.Run("verify search by cross-ref", func(t *testing.T) {
filter := &filters.LocalFilter{
Root: &filters.Clause{
Operator: eq,
On: &filters.Path{
Class: schema.ClassName("AddingBatchReferencesTestSource"),
Property: schema.PropertyName("toTarget"),
Child: &filters.Path{
Class: schema.ClassName("AddingBatchReferencesTestTarget"),
Property: schema.PropertyName("name"),
},
},
Value: &filters.Value{
Value: "item",
Type: schema.DataTypeText,
},
},
}
res, err := repo.Search(context.Background(), dto.GetParams{
Filters: filter,
ClassName: "AddingBatchReferencesTestSource",
Pagination: &filters.Pagination{
Limit: 10,
},
})
require.Nil(t, err)
require.Len(t, res, 1)
assert.Equal(t, res[0].ID, sourceID)
})
t.Run("verify objects are still searchable through the vector index",
func(t *testing.T) {
// prior to making the inverted index and its docIDs immutable, a ref
// update would not change the doc ID, therefore the batch reference
// never had to interact with the vector index. Now that they're
// immutable, the updated doc ID needs to be "re-inserted" even if the
// vector is still the same
// UPDATE gh-1334: Since batch refs are now a special case where we
// tolerate a re-use of the doc id, the above assumption is no longer
// correct. However, this test still adds value, since we were now able
// to remove the additional storage updates. By still including this
// test we verify that such an update is indeed no longer necessary
res, err := repo.VectorSearch(context.Background(), dto.GetParams{
ClassName: "AddingBatchReferencesTestSource",
SearchVector: []float32{0.49},
Pagination: &filters.Pagination{
Limit: 1,
},
})
require.Nil(t, err)
require.Len(t, res, 1)
assert.Equal(t, sourceID, res[0].ID)
})
t.Run("remove source and target classes", func(t *testing.T) {
err := repo.DeleteIndex("AddingBatchReferencesTestSource")
assert.Nil(t, err)
err = repo.DeleteIndex("AddingBatchReferencesTestTarget")
assert.Nil(t, err)
t.Run("verify classes do not exist", func(t *testing.T) {
assert.False(t, repo.IndexExists("AddingBatchReferencesTestSource"))
assert.False(t, repo.IndexExists("AddingBatchReferencesTestTarget"))
})
})
}