Spaces:
Running
Running
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package moduleshelper | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"net/http" | |
"os" | |
"path/filepath" | |
"testing" | |
"cloud.google.com/go/storage" | |
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/test/helper" | |
graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql" | |
"google.golang.org/api/googleapi" | |
"google.golang.org/api/option" | |
) | |
func EnsureClassExists(t *testing.T, className string, tenant string) { | |
query := fmt.Sprintf("{Aggregate{%s", className) | |
if tenant != "" { | |
query += fmt.Sprintf("(tenant:%q)", tenant) | |
} | |
query += " { meta { count}}}}" | |
resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) | |
class := resp.Get("Aggregate", className).Result.([]interface{}) | |
require.Len(t, class, 1) | |
} | |
func EnsureCompressedVectorsRestored(t *testing.T, className string) { | |
query := fmt.Sprintf("{Get{%s(limit:1){_additional{vector}}}}", className) | |
resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) | |
class := resp.Get("Get", className).Result.([]interface{}) | |
require.Len(t, class, 1) | |
vecResp := class[0].(map[string]interface{})["_additional"].(map[string]interface{})["vector"].([]interface{}) | |
searchVec := graphqlhelper.Vec2String(graphqlhelper.ParseVec(t, vecResp)) | |
limit := 10 | |
query = fmt.Sprintf( | |
"{Get{%s(nearVector:{vector:%s} limit:%d){_additional{vector}}}}", | |
className, searchVec, limit) | |
resp = graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) | |
class = resp.Get("Get", className).Result.([]interface{}) | |
require.Len(t, class, limit) | |
} | |
func GetClassCount(t *testing.T, className string, tenant string) int64 { | |
query := fmt.Sprintf("{Aggregate{%s", className) | |
if tenant != "" { | |
query += fmt.Sprintf("(tenant:%q)", tenant) | |
} | |
query += " { meta { count}}}}" | |
resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) | |
class := resp.Get("Aggregate", className).Result.([]interface{}) | |
require.Len(t, class, 1) | |
meta := class[0].(map[string]interface{})["meta"].(map[string]interface{}) | |
countPayload := meta["count"].(json.Number) | |
count, err := countPayload.Int64() | |
require.Nil(t, err) | |
return count | |
} | |
// CreateTestFiles writes five small placeholder files, file_0.db through
// file_4.db, into dirPath and returns their paths in creation order.
//
// Each file contains the text "This is content of db file named <name>".
// dirPath must already exist. Any failure to create, write or close a file
// aborts the test through t.Fatalf.
func CreateTestFiles(t *testing.T, dirPath string) []string {
	t.Helper()

	const count = 5
	filePaths := make([]string, count)
	for i := range filePaths {
		fileName := fmt.Sprintf("file_%d.db", i)
		filePaths[i] = filepath.Join(dirPath, fileName)
		content := fmt.Sprintf("This is content of db file named %s", fileName)

		// os.WriteFile creates (or truncates) the file, writes the content and
		// closes it, surfacing an error from any of those steps. The previous
		// implementation ignored write and close errors.
		if err := os.WriteFile(filePaths[i], []byte(content), 0o666); err != nil {
			t.Fatalf("failed to create test file '%s': %s", fileName, err)
		}
	}
	return filePaths
}
func CreateGCSBucket(ctx context.Context, t *testing.T, projectID, bucketName string) { | |
client, err := storage.NewClient(ctx, option.WithoutAuthentication()) | |
require.Nil(t, err) | |
err = client.Bucket(bucketName).Create(ctx, projectID, nil) | |
gcsErr, ok := err.(*googleapi.Error) | |
if ok { | |
// the bucket persists from the previous test. | |
// if the bucket already exists, we can proceed | |
if gcsErr.Code == http.StatusConflict { | |
return | |
} | |
} | |
require.Nil(t, err) | |
} | |
func CreateAzureContainer(ctx context.Context, t *testing.T, endpoint, containerName string) { | |
connectionString := "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://%s/devstoreaccount1;" | |
client, err := azblob.NewClientFromConnectionString(fmt.Sprintf(connectionString, endpoint), nil) | |
require.Nil(t, err) | |
_, err = client.CreateContainer(ctx, containerName, nil) | |
require.Nil(t, err) | |
} | |
func DeleteAzureContainer(ctx context.Context, t *testing.T, endpoint, containerName string) { | |
connectionString := "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://%s/devstoreaccount1;" | |
client, err := azblob.NewClientFromConnectionString(fmt.Sprintf(connectionString, endpoint), nil) | |
require.Nil(t, err) | |
_, err = client.DeleteContainer(ctx, containerName, nil) | |
require.Nil(t, err) | |
} | |