KevinStephenson
Adding in weaviate code
b110593
raw
history blame
4.62 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package moduleshelper
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"testing"
"cloud.google.com/go/storage"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/test/helper"
graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql"
"google.golang.org/api/googleapi"
"google.golang.org/api/option"
)
func EnsureClassExists(t *testing.T, className string, tenant string) {
query := fmt.Sprintf("{Aggregate{%s", className)
if tenant != "" {
query += fmt.Sprintf("(tenant:%q)", tenant)
}
query += " { meta { count}}}}"
resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query)
class := resp.Get("Aggregate", className).Result.([]interface{})
require.Len(t, class, 1)
}
func EnsureCompressedVectorsRestored(t *testing.T, className string) {
query := fmt.Sprintf("{Get{%s(limit:1){_additional{vector}}}}", className)
resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query)
class := resp.Get("Get", className).Result.([]interface{})
require.Len(t, class, 1)
vecResp := class[0].(map[string]interface{})["_additional"].(map[string]interface{})["vector"].([]interface{})
searchVec := graphqlhelper.Vec2String(graphqlhelper.ParseVec(t, vecResp))
limit := 10
query = fmt.Sprintf(
"{Get{%s(nearVector:{vector:%s} limit:%d){_additional{vector}}}}",
className, searchVec, limit)
resp = graphqlhelper.AssertGraphQL(t, helper.RootAuth, query)
class = resp.Get("Get", className).Result.([]interface{})
require.Len(t, class, limit)
}
func GetClassCount(t *testing.T, className string, tenant string) int64 {
query := fmt.Sprintf("{Aggregate{%s", className)
if tenant != "" {
query += fmt.Sprintf("(tenant:%q)", tenant)
}
query += " { meta { count}}}}"
resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query)
class := resp.Get("Aggregate", className).Result.([]interface{})
require.Len(t, class, 1)
meta := class[0].(map[string]interface{})["meta"].(map[string]interface{})
countPayload := meta["count"].(json.Number)
count, err := countPayload.Int64()
require.Nil(t, err)
return count
}
func CreateTestFiles(t *testing.T, dirPath string) []string {
count := 5
filePaths := make([]string, count)
var fileName string
for i := 0; i < count; i += 1 {
fileName = fmt.Sprintf("file_%d.db", i)
filePaths[i] = filepath.Join(dirPath, fileName)
file, err := os.Create(filePaths[i])
if err != nil {
t.Fatalf("failed to create test file '%s': %s", fileName, err)
}
fmt.Fprintf(file, "This is content of db file named %s", fileName)
file.Close()
}
return filePaths
}
func CreateGCSBucket(ctx context.Context, t *testing.T, projectID, bucketName string) {
client, err := storage.NewClient(ctx, option.WithoutAuthentication())
require.Nil(t, err)
err = client.Bucket(bucketName).Create(ctx, projectID, nil)
gcsErr, ok := err.(*googleapi.Error)
if ok {
// the bucket persists from the previous test.
// if the bucket already exists, we can proceed
if gcsErr.Code == http.StatusConflict {
return
}
}
require.Nil(t, err)
}
func CreateAzureContainer(ctx context.Context, t *testing.T, endpoint, containerName string) {
connectionString := "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://%s/devstoreaccount1;"
client, err := azblob.NewClientFromConnectionString(fmt.Sprintf(connectionString, endpoint), nil)
require.Nil(t, err)
_, err = client.CreateContainer(ctx, containerName, nil)
require.Nil(t, err)
}
func DeleteAzureContainer(ctx context.Context, t *testing.T, endpoint, containerName string) {
connectionString := "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://%s/devstoreaccount1;"
client, err := azblob.NewClientFromConnectionString(fmt.Sprintf(connectionString, endpoint), nil)
require.Nil(t, err)
_, err = client.DeleteContainer(ctx, containerName, nil)
require.Nil(t, err)
}