Spaces:
Sleeping
Sleeping
| // _ _ | |
| // __ _____ __ ___ ___ __ _| |_ ___ | |
| // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
| // \ V V / __/ (_| |\ V /| | (_| | || __/ | |
| // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
| // | |
| // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
| // | |
| // CONTACT: [email protected] | |
| // | |
| package moduleshelper | |
| import ( | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "net/http" | |
| "os" | |
| "path/filepath" | |
| "testing" | |
| "cloud.google.com/go/storage" | |
| "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" | |
| "github.com/stretchr/testify/require" | |
| "github.com/weaviate/weaviate/test/helper" | |
| graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql" | |
| "google.golang.org/api/googleapi" | |
| "google.golang.org/api/option" | |
| ) | |
| func EnsureClassExists(t *testing.T, className string, tenant string) { | |
| query := fmt.Sprintf("{Aggregate{%s", className) | |
| if tenant != "" { | |
| query += fmt.Sprintf("(tenant:%q)", tenant) | |
| } | |
| query += " { meta { count}}}}" | |
| resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) | |
| class := resp.Get("Aggregate", className).Result.([]interface{}) | |
| require.Len(t, class, 1) | |
| } | |
| func EnsureCompressedVectorsRestored(t *testing.T, className string) { | |
| query := fmt.Sprintf("{Get{%s(limit:1){_additional{vector}}}}", className) | |
| resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) | |
| class := resp.Get("Get", className).Result.([]interface{}) | |
| require.Len(t, class, 1) | |
| vecResp := class[0].(map[string]interface{})["_additional"].(map[string]interface{})["vector"].([]interface{}) | |
| searchVec := graphqlhelper.Vec2String(graphqlhelper.ParseVec(t, vecResp)) | |
| limit := 10 | |
| query = fmt.Sprintf( | |
| "{Get{%s(nearVector:{vector:%s} limit:%d){_additional{vector}}}}", | |
| className, searchVec, limit) | |
| resp = graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) | |
| class = resp.Get("Get", className).Result.([]interface{}) | |
| require.Len(t, class, limit) | |
| } | |
| func GetClassCount(t *testing.T, className string, tenant string) int64 { | |
| query := fmt.Sprintf("{Aggregate{%s", className) | |
| if tenant != "" { | |
| query += fmt.Sprintf("(tenant:%q)", tenant) | |
| } | |
| query += " { meta { count}}}}" | |
| resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) | |
| class := resp.Get("Aggregate", className).Result.([]interface{}) | |
| require.Len(t, class, 1) | |
| meta := class[0].(map[string]interface{})["meta"].(map[string]interface{}) | |
| countPayload := meta["count"].(json.Number) | |
| count, err := countPayload.Int64() | |
| require.Nil(t, err) | |
| return count | |
| } | |
| func CreateTestFiles(t *testing.T, dirPath string) []string { | |
| count := 5 | |
| filePaths := make([]string, count) | |
| var fileName string | |
| for i := 0; i < count; i += 1 { | |
| fileName = fmt.Sprintf("file_%d.db", i) | |
| filePaths[i] = filepath.Join(dirPath, fileName) | |
| file, err := os.Create(filePaths[i]) | |
| if err != nil { | |
| t.Fatalf("failed to create test file '%s': %s", fileName, err) | |
| } | |
| fmt.Fprintf(file, "This is content of db file named %s", fileName) | |
| file.Close() | |
| } | |
| return filePaths | |
| } | |
| func CreateGCSBucket(ctx context.Context, t *testing.T, projectID, bucketName string) { | |
| client, err := storage.NewClient(ctx, option.WithoutAuthentication()) | |
| require.Nil(t, err) | |
| err = client.Bucket(bucketName).Create(ctx, projectID, nil) | |
| gcsErr, ok := err.(*googleapi.Error) | |
| if ok { | |
| // the bucket persists from the previous test. | |
| // if the bucket already exists, we can proceed | |
| if gcsErr.Code == http.StatusConflict { | |
| return | |
| } | |
| } | |
| require.Nil(t, err) | |
| } | |
| func CreateAzureContainer(ctx context.Context, t *testing.T, endpoint, containerName string) { | |
| connectionString := "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://%s/devstoreaccount1;" | |
| client, err := azblob.NewClientFromConnectionString(fmt.Sprintf(connectionString, endpoint), nil) | |
| require.Nil(t, err) | |
| _, err = client.CreateContainer(ctx, containerName, nil) | |
| require.Nil(t, err) | |
| } | |
| func DeleteAzureContainer(ctx context.Context, t *testing.T, endpoint, containerName string) { | |
| connectionString := "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://%s/devstoreaccount1;" | |
| client, err := azblob.NewClientFromConnectionString(fmt.Sprintf(connectionString, endpoint), nil) | |
| require.Nil(t, err) | |
| _, err = client.DeleteContainer(ctx, containerName, nil) | |
| require.Nil(t, err) | |
| } | |