Spaces:
Running
Running
//                           _       _
//  __      _____  __ ___   ___  __ _| |_ ___
//  \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//   \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//    \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

//go:build integrationTest
// +build integrationTest

package classification_integration_test | |
import ( | |
"context" | |
"fmt" | |
"io" | |
"math/rand" | |
"sync" | |
"github.com/go-openapi/strfmt" | |
"github.com/google/uuid" | |
"github.com/weaviate/weaviate/entities/additional" | |
"github.com/weaviate/weaviate/entities/aggregation" | |
"github.com/weaviate/weaviate/entities/filters" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/schema" | |
"github.com/weaviate/weaviate/entities/search" | |
"github.com/weaviate/weaviate/entities/searchparams" | |
"github.com/weaviate/weaviate/entities/storobj" | |
enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw" | |
"github.com/weaviate/weaviate/usecases/objects" | |
"github.com/weaviate/weaviate/usecases/replica" | |
"github.com/weaviate/weaviate/usecases/sharding" | |
) | |
type fakeSchemaGetter struct { | |
schema schema.Schema | |
shardState *sharding.State | |
} | |
func (f *fakeSchemaGetter) GetSchemaSkipAuth() schema.Schema { | |
return f.schema | |
} | |
func (f *fakeSchemaGetter) CopyShardingState(class string) *sharding.State { | |
return f.shardState | |
} | |
func (f *fakeSchemaGetter) ShardOwner(class, shard string) (string, error) { | |
ss := f.shardState | |
x, ok := ss.Physical[shard] | |
if !ok { | |
return "", fmt.Errorf("shard not found") | |
} | |
if len(x.BelongsToNodes) < 1 || x.BelongsToNodes[0] == "" { | |
return "", fmt.Errorf("owner node not found") | |
} | |
return ss.Physical[shard].BelongsToNodes[0], nil | |
} | |
func (f *fakeSchemaGetter) ShardReplicas(class, shard string) ([]string, error) { | |
ss := f.shardState | |
x, ok := ss.Physical[shard] | |
if !ok { | |
return nil, fmt.Errorf("shard not found") | |
} | |
return x.BelongsToNodes, nil | |
} | |
func (f *fakeSchemaGetter) TenantShard(class, tenant string) (string, string) { | |
return tenant, models.TenantActivityStatusHOT | |
} | |
func (f *fakeSchemaGetter) ShardFromUUID(class string, uuid []byte) string { | |
ss := f.shardState | |
return ss.Shard("", string(uuid)) | |
} | |
func (f *fakeSchemaGetter) Nodes() []string { | |
return []string{"node1"} | |
} | |
func (m *fakeSchemaGetter) NodeName() string { | |
return "node1" | |
} | |
func (m *fakeSchemaGetter) ClusterHealthScore() int { | |
return 0 | |
} | |
func (m *fakeSchemaGetter) ResolveParentNodes(_ string, shard string, | |
) (map[string]string, error) { | |
return nil, nil | |
} | |
func singleShardState() *sharding.State { | |
config, err := sharding.ParseConfig(nil, 1) | |
if err != nil { | |
panic(err) | |
} | |
s, err := sharding.InitState("test-index", config, | |
fakeNodes{[]string{"node1"}}, 1, false) | |
if err != nil { | |
panic(err) | |
} | |
return s | |
} | |
type fakeClassificationRepo struct { | |
sync.Mutex | |
db map[strfmt.UUID]models.Classification | |
} | |
func newFakeClassificationRepo() *fakeClassificationRepo { | |
return &fakeClassificationRepo{ | |
db: map[strfmt.UUID]models.Classification{}, | |
} | |
} | |
func (f *fakeClassificationRepo) Put(ctx context.Context, class models.Classification) error { | |
f.Lock() | |
defer f.Unlock() | |
f.db[class.ID] = class | |
return nil | |
} | |
func (f *fakeClassificationRepo) Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error) { | |
f.Lock() | |
defer f.Unlock() | |
class, ok := f.db[id] | |
if !ok { | |
return nil, nil | |
} | |
return &class, nil | |
} | |
func testSchema() schema.Schema { | |
return schema.Schema{ | |
Objects: &models.Schema{ | |
Classes: []*models.Class{ | |
{ | |
Class: "ExactCategory", | |
VectorIndexConfig: enthnsw.NewDefaultUserConfig(), | |
InvertedIndexConfig: invertedConfig(), | |
}, | |
{ | |
Class: "MainCategory", | |
VectorIndexConfig: enthnsw.NewDefaultUserConfig(), | |
InvertedIndexConfig: invertedConfig(), | |
}, | |
{ | |
Class: "Article", | |
VectorIndexConfig: enthnsw.NewDefaultUserConfig(), | |
InvertedIndexConfig: invertedConfig(), | |
Properties: []*models.Property{ | |
{ | |
Name: "description", | |
DataType: []string{string(schema.DataTypeText)}, | |
}, | |
{ | |
Name: "name", | |
DataType: schema.DataTypeText.PropString(), | |
Tokenization: models.PropertyTokenizationWhitespace, | |
}, | |
{ | |
Name: "exactCategory", | |
DataType: []string{"ExactCategory"}, | |
}, | |
{ | |
Name: "mainCategory", | |
DataType: []string{"MainCategory"}, | |
}, | |
{ | |
Name: "categories", | |
DataType: []string{"ExactCategory"}, | |
}, | |
{ | |
Name: "anyCategory", | |
DataType: []string{"MainCategory", "ExactCategory"}, | |
}, | |
}, | |
}, | |
}, | |
}, | |
} | |
} | |
// only used for knn-type | |
func testDataAlreadyClassified() search.Results { | |
return search.Results{ | |
search.Result{ | |
ID: "8aeecd06-55a0-462c-9853-81b31a284d80", | |
ClassName: "Article", | |
Vector: []float32{1, 0, 0}, | |
Schema: map[string]interface{}{ | |
"description": "This article talks about politics", | |
"exactCategory": models.MultipleRef{beaconRef(idCategoryPolitics)}, | |
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)}, | |
}, | |
}, | |
search.Result{ | |
ID: "9f4c1847-2567-4de7-8861-34cf47a071ae", | |
ClassName: "Article", | |
Vector: []float32{0, 1, 0}, | |
Schema: map[string]interface{}{ | |
"description": "This articles talks about society", | |
"exactCategory": models.MultipleRef{beaconRef(idCategorySociety)}, | |
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)}, | |
}, | |
}, | |
search.Result{ | |
ID: "926416ec-8fb1-4e40-ab8c-37b226b3d68e", | |
ClassName: "Article", | |
Vector: []float32{0, 0, 1}, | |
Schema: map[string]interface{}{ | |
"description": "This article talks about food", | |
"exactCategory": models.MultipleRef{beaconRef(idCategoryFoodAndDrink)}, | |
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryFoodAndDrink)}, | |
}, | |
}, | |
} | |
} | |
// only used for zeroshot-type | |
func testDataZeroShotUnclassified() search.Results { | |
return search.Results{ | |
search.Result{ | |
ID: "8aeecd06-55a0-462c-9853-81b31a284d80", | |
ClassName: "FoodType", | |
Vector: []float32{1, 0, 0}, | |
Schema: map[string]interface{}{ | |
"text": "Ice cream", | |
}, | |
}, | |
search.Result{ | |
ID: "9f4c1847-2567-4de7-8861-34cf47a071ae", | |
ClassName: "FoodType", | |
Vector: []float32{0, 1, 0}, | |
Schema: map[string]interface{}{ | |
"text": "Meat", | |
}, | |
}, | |
search.Result{ | |
ID: "926416ec-8fb1-4e40-ab8c-37b226b3d68e", | |
ClassName: "Recipes", | |
Vector: []float32{0, 0, 1}, | |
Schema: map[string]interface{}{ | |
"text": "Cut the steak in half and put it into pan", | |
}, | |
}, | |
search.Result{ | |
ID: "926416ec-8fb1-4e40-ab8c-37b226b3d688", | |
ClassName: "Recipes", | |
Vector: []float32{0, 1, 1}, | |
Schema: map[string]interface{}{ | |
"description": "There are flavors of vanilla, chocolate and strawberry", | |
}, | |
}, | |
} | |
} | |
func mustUUID() strfmt.UUID { | |
id, err := uuid.NewRandom() | |
if err != nil { | |
panic(err) | |
} | |
return strfmt.UUID(id.String()) | |
} | |
func largeTestDataSize(size int) search.Results { | |
out := make(search.Results, size) | |
for i := range out { | |
out[i] = search.Result{ | |
ID: mustUUID(), | |
ClassName: "Article", | |
Vector: []float32{0.02, 0, rand.Float32()}, | |
Schema: map[string]interface{}{ | |
"description": "does not matter much", | |
}, | |
} | |
} | |
return out | |
} | |
type fakeAuthorizer struct{} | |
func (f *fakeAuthorizer) Authorize(principal *models.Principal, verb, resource string) error { | |
return nil | |
} | |
func beaconRef(target string) *models.SingleRef { | |
beacon := fmt.Sprintf("weaviate://localhost/%s", target) | |
return &models.SingleRef{Beacon: strfmt.URI(beacon)} | |
} | |
// Fixed object IDs of the category objects that the already-classified test
// articles reference.
const (
	// Targets of the "mainCategory" references.
	idMainCategoryPoliticsAndSociety = "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e"
	idMainCategoryFoodAndDrink       = "5a3d909a-4f0d-4168-8f5c-cd3074d1e79a"

	// Targets of the "exactCategory" references.
	idCategoryPolitics     = "1b204f16-7da6-44fd-bbd2-8cc4a7414bc3"
	idCategorySociety      = "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2"
	idCategoryFoodAndDrink = "027b708a-31ca-43ea-9001-88bec864c79c"
)
func invertedConfig() *models.InvertedIndexConfig { | |
return &models.InvertedIndexConfig{ | |
CleanupIntervalSeconds: 60, | |
} | |
} | |
func testSchemaForZeroShot() schema.Schema { | |
return schema.Schema{ | |
Objects: &models.Schema{ | |
Classes: []*models.Class{ | |
{ | |
Class: "FoodType", | |
VectorIndexConfig: enthnsw.NewDefaultUserConfig(), | |
InvertedIndexConfig: invertedConfig(), | |
Properties: []*models.Property{ | |
{ | |
Name: "text", | |
DataType: []string{string(schema.DataTypeText)}, | |
}, | |
}, | |
}, | |
{ | |
Class: "Recipes", | |
VectorIndexConfig: enthnsw.NewDefaultUserConfig(), | |
InvertedIndexConfig: invertedConfig(), | |
Properties: []*models.Property{ | |
{ | |
Name: "text", | |
DataType: []string{string(schema.DataTypeText)}, | |
}, | |
{ | |
Name: "ofFoodType", | |
DataType: []string{"FoodType"}, | |
}, | |
}, | |
}, | |
}, | |
}, | |
} | |
} | |
// fakeNodes is a static list of node names, used when initialising sharding
// state in tests.
type fakeNodes struct {
	nodes []string
}

// Candidates returns every configured node name.
func (f fakeNodes) Candidates() []string {
	return f.nodes
}

// LocalName reports the first configured node as the local node. When no
// nodes are configured it returns "" instead of panicking with an
// index-out-of-range error.
func (f fakeNodes) LocalName() string {
	if len(f.nodes) == 0 {
		return ""
	}
	return f.nodes[0]
}
type fakeRemoteClient struct{} | |
func (f *fakeRemoteClient) PutObject(ctx context.Context, hostName, indexName, | |
shardName string, obj *storobj.Object, | |
) error { | |
return nil | |
} | |
func (f *fakeRemoteClient) PutFile(ctx context.Context, hostName, indexName, | |
shardName, fileName string, payload io.ReadSeekCloser, | |
) error { | |
return nil | |
} | |
func (f *fakeRemoteClient) GetObject(ctx context.Context, hostName, indexName, | |
shardName string, id strfmt.UUID, props search.SelectProperties, | |
additional additional.Properties, | |
) (*storobj.Object, error) { | |
return nil, nil | |
} | |
func (f *fakeRemoteClient) FindObject(ctx context.Context, hostName, indexName, | |
shardName string, id strfmt.UUID, props search.SelectProperties, | |
additional additional.Properties, | |
) (*storobj.Object, error) { | |
return nil, nil | |
} | |
func (f *fakeRemoteClient) OverwriteObjects(ctx context.Context, | |
host, index, shard string, objects []*objects.VObject, | |
) ([]replica.RepairResponse, error) { | |
return nil, nil | |
} | |
func (f *fakeRemoteClient) Exists(ctx context.Context, hostName, indexName, | |
shardName string, id strfmt.UUID, | |
) (bool, error) { | |
return false, nil | |
} | |
func (f *fakeRemoteClient) DeleteObject(ctx context.Context, hostName, indexName, | |
shardName string, id strfmt.UUID, | |
) error { | |
return nil | |
} | |
func (f *fakeRemoteClient) MergeObject(ctx context.Context, hostName, indexName, | |
shardName string, mergeDoc objects.MergeDocument, | |
) error { | |
return nil | |
} | |
func (f *fakeRemoteClient) SearchShard(ctx context.Context, hostName, indexName, | |
shardName string, vector []float32, limit int, filters *filters.LocalFilter, | |
keywordRanking *searchparams.KeywordRanking, sort []filters.Sort, | |
cursor *filters.Cursor, groupBy *searchparams.GroupBy, additional additional.Properties, | |
) ([]*storobj.Object, []float32, error) { | |
return nil, nil, nil | |
} | |
func (f *fakeRemoteClient) BatchPutObjects(ctx context.Context, hostName, indexName, shardName string, objs []*storobj.Object, repl *additional.ReplicationProperties) []error { | |
return nil | |
} | |
func (f *fakeRemoteClient) MultiGetObjects(ctx context.Context, hostName, indexName, | |
shardName string, ids []strfmt.UUID, | |
) ([]*storobj.Object, error) { | |
return nil, nil | |
} | |
func (f *fakeRemoteClient) BatchAddReferences(ctx context.Context, hostName, | |
indexName, shardName string, refs objects.BatchReferences, | |
) []error { | |
return nil | |
} | |
func (f *fakeRemoteClient) Aggregate(ctx context.Context, hostName, indexName, | |
shardName string, params aggregation.Params, | |
) (*aggregation.Result, error) { | |
return nil, nil | |
} | |
func (f *fakeRemoteClient) FindUUIDs(ctx context.Context, hostName, indexName, shardName string, | |
filters *filters.LocalFilter, | |
) ([]strfmt.UUID, error) { | |
return nil, nil | |
} | |
func (f *fakeRemoteClient) DeleteObjectBatch(ctx context.Context, hostName, indexName, shardName string, | |
uuids []strfmt.UUID, dryRun bool, | |
) objects.BatchSimpleObjects { | |
return nil | |
} | |
func (f *fakeRemoteClient) GetShardQueueSize(ctx context.Context, | |
hostName, indexName, shardName string, | |
) (int64, error) { | |
return 0, nil | |
} | |
func (f *fakeRemoteClient) GetShardStatus(ctx context.Context, | |
hostName, indexName, shardName string, | |
) (string, error) { | |
return "", nil | |
} | |
func (f *fakeRemoteClient) UpdateShardStatus(ctx context.Context, hostName, indexName, shardName, | |
targetStatus string, | |
) error { | |
return nil | |
} | |
func (f *fakeRemoteClient) DigestObjects(ctx context.Context, | |
hostName, indexName, shardName string, ids []strfmt.UUID, | |
) (result []replica.RepairResponse, err error) { | |
return nil, nil | |
} | |
// fakeNodeResolver never resolves a node name to a hostname.
type fakeNodeResolver struct{}

// NodeHostname always reports ("", false), whatever node name is asked for.
func (*fakeNodeResolver) NodeHostname(string) (hostname string, ok bool) {
	return "", false
}
type fakeRemoteNodeClient struct{} | |
func (f *fakeRemoteNodeClient) GetNodeStatus(ctx context.Context, hostName, className, output string) (*models.NodeStatus, error) { | |
return &models.NodeStatus{}, nil | |
} | |
type fakeReplicationClient struct{} | |
func (f *fakeReplicationClient) PutObject(ctx context.Context, host, index, shard, requestID string, | |
obj *storobj.Object, | |
) (replica.SimpleResponse, error) { | |
return replica.SimpleResponse{}, nil | |
} | |
func (f *fakeReplicationClient) DeleteObject(ctx context.Context, host, index, shard, requestID string, | |
id strfmt.UUID, | |
) (replica.SimpleResponse, error) { | |
return replica.SimpleResponse{}, nil | |
} | |
func (f *fakeReplicationClient) PutObjects(ctx context.Context, host, index, shard, requestID string, | |
objs []*storobj.Object, | |
) (replica.SimpleResponse, error) { | |
return replica.SimpleResponse{}, nil | |
} | |
func (f *fakeReplicationClient) MergeObject(ctx context.Context, host, index, shard, requestID string, | |
mergeDoc *objects.MergeDocument, | |
) (replica.SimpleResponse, error) { | |
return replica.SimpleResponse{}, nil | |
} | |
func (f *fakeReplicationClient) DeleteObjects(ctx context.Context, host, index, shard, requestID string, | |
uuids []strfmt.UUID, dryRun bool, | |
) (replica.SimpleResponse, error) { | |
return replica.SimpleResponse{}, nil | |
} | |
func (f *fakeReplicationClient) AddReferences(ctx context.Context, host, index, shard, requestID string, | |
refs []objects.BatchReference, | |
) (replica.SimpleResponse, error) { | |
return replica.SimpleResponse{}, nil | |
} | |
func (f *fakeReplicationClient) Commit(ctx context.Context, host, index, shard, requestID string, resp interface{}) error { | |
return nil | |
} | |
func (f *fakeReplicationClient) Abort(ctx context.Context, host, index, shard, requestID string) (replica.SimpleResponse, error) { | |
return replica.SimpleResponse{}, nil | |
} | |
func (c *fakeReplicationClient) Exists(ctx context.Context, host, index, | |
shard string, id strfmt.UUID, | |
) (bool, error) { | |
return false, nil | |
} | |
func (f *fakeReplicationClient) FetchObject(_ context.Context, host, index, | |
shard string, id strfmt.UUID, props search.SelectProperties, | |
additional additional.Properties, | |
) (objects.Replica, error) { | |
return objects.Replica{}, nil | |
} | |
func (c *fakeReplicationClient) FetchObjects(ctx context.Context, host, | |
index, shard string, ids []strfmt.UUID, | |
) ([]objects.Replica, error) { | |
return nil, nil | |
} | |
func (c *fakeReplicationClient) DigestObjects(ctx context.Context, | |
host, index, shard string, ids []strfmt.UUID, | |
) (result []replica.RepairResponse, err error) { | |
return nil, nil | |
} | |
func (c *fakeReplicationClient) OverwriteObjects(ctx context.Context, | |
host, index, shard string, vobjects []*objects.VObject, | |
) ([]replica.RepairResponse, error) { | |
return nil, nil | |
} | |