// Source: SemanticSearchPOC/test/benchmark/benchmark_sift.go
// (commit b110593, "Adding in weaviate code", by KevinStephenson; 7.88 kB)
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package main
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"os"
"sync"
"github.com/go-openapi/strfmt"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/weaviate/weaviate/entities/models"
)
// Constants shared by the SIFT benchmark.
const (
	// class is the Weaviate class that every benchmark object is stored in
	// and every benchmark query targets.
	class = "Benchmark"

	// nrSearchResults is the default page size requested when listing the
	// imported objects back from the server.
	nrSearchResults = 79
)
// createSchemaSIFTRequest builds the POST request to url+"schema" that
// registers the benchmark class. The class has a single "int" property
// named "counter", uses no vectorizer, and carries the vector index
// settings below (l2-squared distance).
func createSchemaSIFTRequest(url string) *http.Request {
	// Index parameters mirror the values used by the benchmark script.
	indexConfig := map[string]interface{}{
		"distance":              "l2-squared",
		"ef":                    -1,
		"efConstruction":        64,
		"maxConnections":        64,
		"vectorCacheMaxObjects": 1000000000,
	}

	counterProperty := &models.Property{
		Name:        "counter",
		DataType:    []string{"int"},
		Description: "The value of the counter in the dataset",
	}

	return createRequest(url+"schema", "POST", &models.Class{
		Class:             class,
		Description:       "Dummy class for benchmarking purposes",
		Properties:        []*models.Property{counterProperty},
		VectorIndexConfig: indexConfig,
		Vectorizer:        "none",
	})
}
// float32FromBytes decodes the first four bytes of bytes as a little-endian
// IEEE-754 single-precision float. Extra trailing bytes are ignored; fewer
// than four bytes causes a panic (from encoding/binary).
func float32FromBytes(bytes []byte) float32 {
	return math.Float32frombits(binary.LittleEndian.Uint32(bytes))
}
// int32FromBytes decodes the first four bytes of bytes as a little-endian
// two's-complement int32 and widens it to int.
//
// The bits are reinterpreted as int32 before widening so that, as the name
// promises, values with the high bit set come back negative. A plain
// int(uint32) would instead yield a large positive number on 64-bit
// platforms. Fewer than four bytes causes a panic (from encoding/binary).
func int32FromBytes(bytes []byte) int {
	return int(int32(binary.LittleEndian.Uint32(bytes)))
}
// readSiftFloat loads exactly maxObjects vectors from the file "sift/"+file
// and wraps each one in a *models.Object of class `class`.
//
// The file is a sequence of binary records. Each record is a 4-byte
// little-endian vector length (counted in floats, not bytes), followed by
// that many 4-byte little-endian float32 components:
//
//	|-len vec1 (4 bytes)-|-vec1 data (4*len bytes)-|-len vec2 (4 bytes)-|-vec2 data (4*len bytes)-|...
//
// Every record must have length 128. The stored values are floats, though
// the comment in the original notes they are integral (e.g. 2.0).
//
// The k-th object returned (0-based) gets a fresh random UUID and the
// property "counter" = k. If maxObjects <= 0, no records are returned.
//
// It panics if:
//   - the file cannot be opened or stat'ed;
//   - the file is under 1,000,000 bytes (typically a git-lfs pointer rather
//     than the real data);
//   - a record is not 128-dimensional;
//   - the file ends in the middle of a record;
//   - fewer than maxObjects records are available.
func readSiftFloat(file string, maxObjects int) []*models.Object {
	f, err := os.Open("sift/" + file)
	if err != nil {
		panic(errors.Wrap(err, "Could not open SIFT file"))
	}
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		panic(errors.Wrap(err, "Could not get SIFT file properties"))
	}
	fileSize := fi.Size()
	if fileSize < 1000000 {
		panic("The file is only " + fmt.Sprint(fileSize) + " bytes long. Did you forget to install git lfs?")
	}

	const (
		bytesPerF         = 4
		vectorLengthFloat = 128
	)
	// One record is the 4-byte length header plus the vector payload.
	vectorBytes := make([]byte, bytesPerF+vectorLengthFloat*bytesPerF)

	var objects []*models.Object
	if maxObjects > 0 {
		objects = make([]*models.Object, 0, maxObjects)
	}

	// Stop as soon as maxObjects records are collected. The original loop
	// appended one record too many before checking the limit.
	for len(objects) < maxObjects {
		// Unlike a bare Read, io.ReadFull never returns a short read without
		// an error, so a partially filled buffer (holding stale bytes from
		// the previous record) can never be decoded as a vector.
		// io.EOF means a clean end on a record boundary.
		// io.ErrUnexpectedEOF means the file is truncated mid-record.
		if _, err := io.ReadFull(f, vectorBytes); err == io.EOF {
			break
		} else if err != nil {
			panic(errors.Wrap(err, "Could not read SIFT file"))
		}
		if int32FromBytes(vectorBytes[:bytesPerF]) != vectorLengthFloat {
			panic("Each vector must have 128 entries.")
		}

		vectorFloat := make([]float32, vectorLengthFloat)
		for j := range vectorFloat {
			start := (j + 1) * bytesPerF // skip the 4-byte length header
			vectorFloat[j] = float32FromBytes(vectorBytes[start : start+bytesPerF])
		}

		counter := len(objects) // 0-based index of this record
		objects = append(objects, &models.Object{
			Class:  class,
			ID:     strfmt.UUID(uuid.New().String()),
			Vector: models.C11yVector(vectorFloat),
			Properties: map[string]interface{}{
				"counter": counter,
			},
		})
	}

	if len(objects) < maxObjects {
		panic("Could not load all elements.")
	}
	return objects
}
// benchmarkSift runs one round of the SIFT benchmark against the server
// reached through c. url is used as a plain prefix for endpoint paths such
// as "schema" and "batch/objects", so it presumably ends in "/" (confirm
// with callers).
//
// The round consists of these steps:
//  1. Clear existing objects.
//  2. Load maxObjects base vectors and maxObjects/100 query vectors.
//  3. Create the benchmark class.
//  4. Import the objects in numBatches concurrent batch requests.
//  5. List objects back and check the count.
//  6. Run one nearVector GraphQL query per query vector.
//  7. Delete the class so the next round starts fresh.
//
// It returns the timings reported by performRequest, keyed by step:
//   - "AddSchema"
//   - "BatchAdd": the sum over all batches
//   - "GetObjects"
//   - "Query": the sum over all queries
//   - "Delete"
//
// The unit is whatever performRequest returns (not visible here). An error
// is returned for a non-positive numBatches, any transport error, any
// non-200 status, or an unexpected response body.
func benchmarkSift(c *http.Client, url string, maxObjects, numBatches int) (map[string]int64, error) {
	// Fail before touching the server. The original panicked with an
	// integer divide-by-zero only after the schema had already been created.
	if numBatches <= 0 {
		return nil, errors.Errorf("Number of batches must be positive, got %d", numBatches)
	}

	clearExistingObjects(c, url)
	objects := readSiftFloat("sift_base.fvecs", maxObjects)
	queries := readSiftFloat("sift_query.fvecs", maxObjects/100)
	requestSchema := createSchemaSIFTRequest(url)
	passedTime := make(map[string]int64)

	// Add schema
	responseSchemaCode, _, timeSchema, err := performRequest(c, requestSchema)
	passedTime["AddSchema"] = timeSchema
	if err != nil {
		return nil, errors.Wrap(err, "Could not add schema")
	}
	if responseSchemaCode != 200 {
		return nil, errors.Errorf("Could not add schema, http error code: %v", responseSchemaCode)
	}

	// Batch-add: split objects into numBatches contiguous ranges and send
	// them concurrently. The last batch also takes the remainder, so no
	// object is dropped when len(objects) is not a multiple of numBatches.
	passedTime["BatchAdd"] = 0
	batchSize := len(objects) / numBatches
	errorChan := make(chan error, numBatches) // at most one error per batch, so sends never block
	timeChan := make(chan int64, numBatches)  // exactly one timing per batch
	var wg sync.WaitGroup
	for i := 0; i < numBatches; i++ {
		start := i * batchSize
		end := start + batchSize
		if i == numBatches-1 {
			end = len(objects)
		}
		wg.Add(1)
		go func(start, end int) {
			defer wg.Done()
			requestAdd := createRequest(url+"batch/objects", "POST", batch{objects[start:end]})
			responseAddCode, _, timeBatchAdd, err := performRequest(c, requestAdd)
			timeChan <- timeBatchAdd
			if err != nil {
				errorChan <- errors.Wrap(err, "Could not add batch")
			} else if responseAddCode != 200 {
				errorChan <- errors.Errorf("Could not add batch, http error code: %v", responseAddCode)
			}
		}(start, end)
	}
	wg.Wait()
	close(errorChan)
	close(timeChan)
	// Receiving from the closed, drained channel yields nil when no batch failed.
	if err := <-errorChan; err != nil {
		return nil, err
	}
	for timing := range timeChan {
		passedTime["BatchAdd"] += timing
	}

	// Read entries back. The expected count is capped by how many objects
	// were imported.
	nrSearchResultsUse := nrSearchResults
	if maxObjects < nrSearchResultsUse {
		nrSearchResultsUse = maxObjects
	}
	requestRead := createRequest(url+"objects?limit="+fmt.Sprint(nrSearchResultsUse)+"&class="+class, "GET", nil)
	responseReadCode, body, timeGetObjects, err := performRequest(c, requestRead)
	passedTime["GetObjects"] = timeGetObjects
	if err != nil {
		return nil, errors.Wrap(err, "Could not read objects")
	}
	if responseReadCode != 200 {
		return nil, errors.Errorf("Could not read objects, http error code: %v", responseReadCode)
	}
	var readResult map[string]interface{}
	if err := json.Unmarshal(body, &readResult); err != nil {
		return nil, errors.Wrap(err, "Could not unmarshal read response")
	}
	// JSON numbers decode to float64. Use a checked assertion so a missing
	// or malformed field yields an error instead of a panic.
	totalResults, ok := readResult["totalResults"].(float64)
	if !ok {
		return nil, errors.New("Read response has no numeric totalResults field")
	}
	if int(totalResults) != nrSearchResultsUse {
		return nil, errors.Errorf("Found %d results. Expected %d.", int(totalResults), nrSearchResultsUse)
	}

	// Use sample queries. Initialize the key so it is present even when
	// there are no queries (maxObjects < 100).
	passedTime["Query"] = 0
	for _, query := range queries {
		queryString := "{Get{" + class + "(nearVector: {vector:" + fmt.Sprint(query.Vector) + " }){counter}}}"
		requestQuery := createRequest(url+"graphql", "POST", models.GraphQLQuery{
			Query: queryString,
		})
		responseQueryCode, queryBody, timeQuery, err := performRequest(c, requestQuery)
		passedTime["Query"] += timeQuery
		if err != nil {
			return nil, errors.Wrap(err, "Could not query objects")
		}
		if responseQueryCode != 200 {
			return nil, errors.Errorf("Could not query objects, http error code: %v", responseQueryCode)
		}
		var queryResult map[string]interface{}
		if err := json.Unmarshal(queryBody, &queryResult); err != nil {
			return nil, errors.Wrap(err, "Could not unmarshal query response")
		}
		if queryResult["data"] == nil || queryResult["errors"] != nil {
			return nil, errors.New("GraphQL Error")
		}
	}

	// Delete the class (schema and all entries) so the next round can start fresh.
	requestDelete := createRequest(url+"schema/"+class, "DELETE", nil)
	responseDeleteCode, _, timeDelete, err := performRequest(c, requestDelete)
	passedTime["Delete"] = timeDelete
	if err != nil {
		return nil, errors.Wrap(err, "Could not delete class")
	}
	if responseDeleteCode != 200 {
		return nil, errors.Errorf("Could not delete class, http error code: %v", responseDeleteCode)
	}
	return passedTime, nil
}