KevinStephenson
Adding in weaviate code
b110593
raw
history blame
16.2 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
//go:build integrationTest
// +build integrationTest
package classification_integration_test
import (
"context"
"fmt"
"io"
"math/rand"
"sync"
"github.com/go-openapi/strfmt"
"github.com/google/uuid"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/aggregation"
"github.com/weaviate/weaviate/entities/filters"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/search"
"github.com/weaviate/weaviate/entities/searchparams"
"github.com/weaviate/weaviate/entities/storobj"
enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
"github.com/weaviate/weaviate/usecases/objects"
"github.com/weaviate/weaviate/usecases/replica"
"github.com/weaviate/weaviate/usecases/sharding"
)
type fakeSchemaGetter struct {
schema schema.Schema
shardState *sharding.State
}
func (f *fakeSchemaGetter) GetSchemaSkipAuth() schema.Schema {
return f.schema
}
func (f *fakeSchemaGetter) CopyShardingState(class string) *sharding.State {
return f.shardState
}
func (f *fakeSchemaGetter) ShardOwner(class, shard string) (string, error) {
ss := f.shardState
x, ok := ss.Physical[shard]
if !ok {
return "", fmt.Errorf("shard not found")
}
if len(x.BelongsToNodes) < 1 || x.BelongsToNodes[0] == "" {
return "", fmt.Errorf("owner node not found")
}
return ss.Physical[shard].BelongsToNodes[0], nil
}
func (f *fakeSchemaGetter) ShardReplicas(class, shard string) ([]string, error) {
ss := f.shardState
x, ok := ss.Physical[shard]
if !ok {
return nil, fmt.Errorf("shard not found")
}
return x.BelongsToNodes, nil
}
func (f *fakeSchemaGetter) TenantShard(class, tenant string) (string, string) {
return tenant, models.TenantActivityStatusHOT
}
func (f *fakeSchemaGetter) ShardFromUUID(class string, uuid []byte) string {
ss := f.shardState
return ss.Shard("", string(uuid))
}
func (f *fakeSchemaGetter) Nodes() []string {
return []string{"node1"}
}
func (m *fakeSchemaGetter) NodeName() string {
return "node1"
}
func (m *fakeSchemaGetter) ClusterHealthScore() int {
return 0
}
func (m *fakeSchemaGetter) ResolveParentNodes(_ string, shard string,
) (map[string]string, error) {
return nil, nil
}
func singleShardState() *sharding.State {
config, err := sharding.ParseConfig(nil, 1)
if err != nil {
panic(err)
}
s, err := sharding.InitState("test-index", config,
fakeNodes{[]string{"node1"}}, 1, false)
if err != nil {
panic(err)
}
return s
}
type fakeClassificationRepo struct {
sync.Mutex
db map[strfmt.UUID]models.Classification
}
func newFakeClassificationRepo() *fakeClassificationRepo {
return &fakeClassificationRepo{
db: map[strfmt.UUID]models.Classification{},
}
}
func (f *fakeClassificationRepo) Put(ctx context.Context, class models.Classification) error {
f.Lock()
defer f.Unlock()
f.db[class.ID] = class
return nil
}
func (f *fakeClassificationRepo) Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error) {
f.Lock()
defer f.Unlock()
class, ok := f.db[id]
if !ok {
return nil, nil
}
return &class, nil
}
func testSchema() schema.Schema {
return schema.Schema{
Objects: &models.Schema{
Classes: []*models.Class{
{
Class: "ExactCategory",
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
},
{
Class: "MainCategory",
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
},
{
Class: "Article",
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
Properties: []*models.Property{
{
Name: "description",
DataType: []string{string(schema.DataTypeText)},
},
{
Name: "name",
DataType: schema.DataTypeText.PropString(),
Tokenization: models.PropertyTokenizationWhitespace,
},
{
Name: "exactCategory",
DataType: []string{"ExactCategory"},
},
{
Name: "mainCategory",
DataType: []string{"MainCategory"},
},
{
Name: "categories",
DataType: []string{"ExactCategory"},
},
{
Name: "anyCategory",
DataType: []string{"MainCategory", "ExactCategory"},
},
},
},
},
},
}
}
// only used for knn-type
func testDataAlreadyClassified() search.Results {
return search.Results{
search.Result{
ID: "8aeecd06-55a0-462c-9853-81b31a284d80",
ClassName: "Article",
Vector: []float32{1, 0, 0},
Schema: map[string]interface{}{
"description": "This article talks about politics",
"exactCategory": models.MultipleRef{beaconRef(idCategoryPolitics)},
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)},
},
},
search.Result{
ID: "9f4c1847-2567-4de7-8861-34cf47a071ae",
ClassName: "Article",
Vector: []float32{0, 1, 0},
Schema: map[string]interface{}{
"description": "This articles talks about society",
"exactCategory": models.MultipleRef{beaconRef(idCategorySociety)},
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)},
},
},
search.Result{
ID: "926416ec-8fb1-4e40-ab8c-37b226b3d68e",
ClassName: "Article",
Vector: []float32{0, 0, 1},
Schema: map[string]interface{}{
"description": "This article talks about food",
"exactCategory": models.MultipleRef{beaconRef(idCategoryFoodAndDrink)},
"mainCategory": models.MultipleRef{beaconRef(idMainCategoryFoodAndDrink)},
},
},
}
}
// only used for zeroshot-type
func testDataZeroShotUnclassified() search.Results {
return search.Results{
search.Result{
ID: "8aeecd06-55a0-462c-9853-81b31a284d80",
ClassName: "FoodType",
Vector: []float32{1, 0, 0},
Schema: map[string]interface{}{
"text": "Ice cream",
},
},
search.Result{
ID: "9f4c1847-2567-4de7-8861-34cf47a071ae",
ClassName: "FoodType",
Vector: []float32{0, 1, 0},
Schema: map[string]interface{}{
"text": "Meat",
},
},
search.Result{
ID: "926416ec-8fb1-4e40-ab8c-37b226b3d68e",
ClassName: "Recipes",
Vector: []float32{0, 0, 1},
Schema: map[string]interface{}{
"text": "Cut the steak in half and put it into pan",
},
},
search.Result{
ID: "926416ec-8fb1-4e40-ab8c-37b226b3d688",
ClassName: "Recipes",
Vector: []float32{0, 1, 1},
Schema: map[string]interface{}{
"description": "There are flavors of vanilla, chocolate and strawberry",
},
},
}
}
func mustUUID() strfmt.UUID {
id, err := uuid.NewRandom()
if err != nil {
panic(err)
}
return strfmt.UUID(id.String())
}
func largeTestDataSize(size int) search.Results {
out := make(search.Results, size)
for i := range out {
out[i] = search.Result{
ID: mustUUID(),
ClassName: "Article",
Vector: []float32{0.02, 0, rand.Float32()},
Schema: map[string]interface{}{
"description": "does not matter much",
},
}
}
return out
}
type fakeAuthorizer struct{}
func (f *fakeAuthorizer) Authorize(principal *models.Principal, verb, resource string) error {
return nil
}
func beaconRef(target string) *models.SingleRef {
beacon := fmt.Sprintf("weaviate://localhost/%s", target)
return &models.SingleRef{Beacon: strfmt.URI(beacon)}
}
const (
idMainCategoryPoliticsAndSociety = "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e"
idMainCategoryFoodAndDrink = "5a3d909a-4f0d-4168-8f5c-cd3074d1e79a"
idCategoryPolitics = "1b204f16-7da6-44fd-bbd2-8cc4a7414bc3"
idCategorySociety = "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2"
idCategoryFoodAndDrink = "027b708a-31ca-43ea-9001-88bec864c79c"
)
func invertedConfig() *models.InvertedIndexConfig {
return &models.InvertedIndexConfig{
CleanupIntervalSeconds: 60,
}
}
func testSchemaForZeroShot() schema.Schema {
return schema.Schema{
Objects: &models.Schema{
Classes: []*models.Class{
{
Class: "FoodType",
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
Properties: []*models.Property{
{
Name: "text",
DataType: []string{string(schema.DataTypeText)},
},
},
},
{
Class: "Recipes",
VectorIndexConfig: enthnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
Properties: []*models.Property{
{
Name: "text",
DataType: []string{string(schema.DataTypeText)},
},
{
Name: "ofFoodType",
DataType: []string{"FoodType"},
},
},
},
},
},
}
}
type fakeNodes struct {
nodes []string
}
func (f fakeNodes) Candidates() []string {
return f.nodes
}
func (f fakeNodes) LocalName() string {
return f.nodes[0]
}
type fakeRemoteClient struct{}
func (f *fakeRemoteClient) PutObject(ctx context.Context, hostName, indexName,
shardName string, obj *storobj.Object,
) error {
return nil
}
func (f *fakeRemoteClient) PutFile(ctx context.Context, hostName, indexName,
shardName, fileName string, payload io.ReadSeekCloser,
) error {
return nil
}
func (f *fakeRemoteClient) GetObject(ctx context.Context, hostName, indexName,
shardName string, id strfmt.UUID, props search.SelectProperties,
additional additional.Properties,
) (*storobj.Object, error) {
return nil, nil
}
func (f *fakeRemoteClient) FindObject(ctx context.Context, hostName, indexName,
shardName string, id strfmt.UUID, props search.SelectProperties,
additional additional.Properties,
) (*storobj.Object, error) {
return nil, nil
}
func (f *fakeRemoteClient) OverwriteObjects(ctx context.Context,
host, index, shard string, objects []*objects.VObject,
) ([]replica.RepairResponse, error) {
return nil, nil
}
func (f *fakeRemoteClient) Exists(ctx context.Context, hostName, indexName,
shardName string, id strfmt.UUID,
) (bool, error) {
return false, nil
}
func (f *fakeRemoteClient) DeleteObject(ctx context.Context, hostName, indexName,
shardName string, id strfmt.UUID,
) error {
return nil
}
func (f *fakeRemoteClient) MergeObject(ctx context.Context, hostName, indexName,
shardName string, mergeDoc objects.MergeDocument,
) error {
return nil
}
func (f *fakeRemoteClient) SearchShard(ctx context.Context, hostName, indexName,
shardName string, vector []float32, limit int, filters *filters.LocalFilter,
keywordRanking *searchparams.KeywordRanking, sort []filters.Sort,
cursor *filters.Cursor, groupBy *searchparams.GroupBy, additional additional.Properties,
) ([]*storobj.Object, []float32, error) {
return nil, nil, nil
}
func (f *fakeRemoteClient) BatchPutObjects(ctx context.Context, hostName, indexName, shardName string, objs []*storobj.Object, repl *additional.ReplicationProperties) []error {
return nil
}
func (f *fakeRemoteClient) MultiGetObjects(ctx context.Context, hostName, indexName,
shardName string, ids []strfmt.UUID,
) ([]*storobj.Object, error) {
return nil, nil
}
func (f *fakeRemoteClient) BatchAddReferences(ctx context.Context, hostName,
indexName, shardName string, refs objects.BatchReferences,
) []error {
return nil
}
func (f *fakeRemoteClient) Aggregate(ctx context.Context, hostName, indexName,
shardName string, params aggregation.Params,
) (*aggregation.Result, error) {
return nil, nil
}
func (f *fakeRemoteClient) FindUUIDs(ctx context.Context, hostName, indexName, shardName string,
filters *filters.LocalFilter,
) ([]strfmt.UUID, error) {
return nil, nil
}
func (f *fakeRemoteClient) DeleteObjectBatch(ctx context.Context, hostName, indexName, shardName string,
uuids []strfmt.UUID, dryRun bool,
) objects.BatchSimpleObjects {
return nil
}
func (f *fakeRemoteClient) GetShardQueueSize(ctx context.Context,
hostName, indexName, shardName string,
) (int64, error) {
return 0, nil
}
func (f *fakeRemoteClient) GetShardStatus(ctx context.Context,
hostName, indexName, shardName string,
) (string, error) {
return "", nil
}
func (f *fakeRemoteClient) UpdateShardStatus(ctx context.Context, hostName, indexName, shardName,
targetStatus string,
) error {
return nil
}
func (f *fakeRemoteClient) DigestObjects(ctx context.Context,
hostName, indexName, shardName string, ids []strfmt.UUID,
) (result []replica.RepairResponse, err error) {
return nil, nil
}
type fakeNodeResolver struct{}
func (f *fakeNodeResolver) NodeHostname(string) (string, bool) {
return "", false
}
type fakeRemoteNodeClient struct{}
func (f *fakeRemoteNodeClient) GetNodeStatus(ctx context.Context, hostName, className, output string) (*models.NodeStatus, error) {
return &models.NodeStatus{}, nil
}
type fakeReplicationClient struct{}
func (f *fakeReplicationClient) PutObject(ctx context.Context, host, index, shard, requestID string,
obj *storobj.Object,
) (replica.SimpleResponse, error) {
return replica.SimpleResponse{}, nil
}
func (f *fakeReplicationClient) DeleteObject(ctx context.Context, host, index, shard, requestID string,
id strfmt.UUID,
) (replica.SimpleResponse, error) {
return replica.SimpleResponse{}, nil
}
func (f *fakeReplicationClient) PutObjects(ctx context.Context, host, index, shard, requestID string,
objs []*storobj.Object,
) (replica.SimpleResponse, error) {
return replica.SimpleResponse{}, nil
}
func (f *fakeReplicationClient) MergeObject(ctx context.Context, host, index, shard, requestID string,
mergeDoc *objects.MergeDocument,
) (replica.SimpleResponse, error) {
return replica.SimpleResponse{}, nil
}
func (f *fakeReplicationClient) DeleteObjects(ctx context.Context, host, index, shard, requestID string,
uuids []strfmt.UUID, dryRun bool,
) (replica.SimpleResponse, error) {
return replica.SimpleResponse{}, nil
}
func (f *fakeReplicationClient) AddReferences(ctx context.Context, host, index, shard, requestID string,
refs []objects.BatchReference,
) (replica.SimpleResponse, error) {
return replica.SimpleResponse{}, nil
}
func (f *fakeReplicationClient) Commit(ctx context.Context, host, index, shard, requestID string, resp interface{}) error {
return nil
}
func (f *fakeReplicationClient) Abort(ctx context.Context, host, index, shard, requestID string) (replica.SimpleResponse, error) {
return replica.SimpleResponse{}, nil
}
func (c *fakeReplicationClient) Exists(ctx context.Context, host, index,
shard string, id strfmt.UUID,
) (bool, error) {
return false, nil
}
func (f *fakeReplicationClient) FetchObject(_ context.Context, host, index,
shard string, id strfmt.UUID, props search.SelectProperties,
additional additional.Properties,
) (objects.Replica, error) {
return objects.Replica{}, nil
}
func (c *fakeReplicationClient) FetchObjects(ctx context.Context, host,
index, shard string, ids []strfmt.UUID,
) ([]objects.Replica, error) {
return nil, nil
}
func (c *fakeReplicationClient) DigestObjects(ctx context.Context,
host, index, shard string, ids []strfmt.UUID,
) (result []replica.RepairResponse, err error) {
return nil, nil
}
func (c *fakeReplicationClient) OverwriteObjects(ctx context.Context,
host, index, shard string, vobjects []*objects.VObject,
) ([]replica.RepairResponse, error) {
return nil, nil
}