Spaces:
Sleeping
Sleeping
| // _ _ | |
| // __ _____ __ ___ ___ __ _| |_ ___ | |
| // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
| // \ V V / __/ (_| |\ V /| | (_| | || __/ | |
| // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
| // | |
| // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
| // | |
| // CONTACT: [email protected] | |
| // | |
| package main | |
| import ( | |
| "encoding/binary" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "math" | |
| "net/http" | |
| "os" | |
| "sync" | |
| "github.com/go-openapi/strfmt" | |
| "github.com/google/uuid" | |
| "github.com/pkg/errors" | |
| "github.com/weaviate/weaviate/entities/models" | |
| ) | |
| const ( | |
| class = "Benchmark" | |
| nrSearchResults = 79 | |
| ) | |
| func createSchemaSIFTRequest(url string) *http.Request { | |
| classObj := &models.Class{ | |
| Class: class, | |
| Description: "Dummy class for benchmarking purposes", | |
| Properties: []*models.Property{ | |
| { | |
| DataType: []string{"int"}, | |
| Description: "The value of the counter in the dataset", | |
| Name: "counter", | |
| }, | |
| }, | |
| VectorIndexConfig: map[string]interface{}{ // values are from benchmark script | |
| "distance": "l2-squared", | |
| "ef": -1, | |
| "efConstruction": 64, | |
| "maxConnections": 64, | |
| "vectorCacheMaxObjects": 1000000000, | |
| }, | |
| Vectorizer: "none", | |
| } | |
| request := createRequest(url+"schema", "POST", classObj) | |
| return request | |
| } | |
| func float32FromBytes(bytes []byte) float32 { | |
| bits := binary.LittleEndian.Uint32(bytes) | |
| float := math.Float32frombits(bits) | |
| return float | |
| } | |
| func int32FromBytes(bytes []byte) int { | |
| return int(binary.LittleEndian.Uint32(bytes)) | |
| } | |
| func readSiftFloat(file string, maxObjects int) []*models.Object { | |
| var objects []*models.Object | |
| f, err := os.Open("sift/" + file) | |
| if err != nil { | |
| panic(errors.Wrap(err, "Could not open SIFT file")) | |
| } | |
| defer f.Close() | |
| fi, err := f.Stat() | |
| if err != nil { | |
| panic(errors.Wrap(err, "Could not get SIFT file properties")) | |
| } | |
| fileSize := fi.Size() | |
| if fileSize < 1000000 { | |
| panic("The file is only " + fmt.Sprint(fileSize) + " bytes long. Did you forgot to install git lfs?") | |
| } | |
| // The sift data is a binary file containing floating point vectors | |
| // For each entry, the first 4 bytes is the length of the vector (in number of floats, not in bytes) | |
| // which is followed by the vector data with vector length * 4 bytes. | |
| // |-length-vec1 (4bytes)-|-Vec1-data-(4*length-vector-1 bytes)-|-length-vec2 (4bytes)-|-Vec2-data-(4*length-vector-2 bytes)-| | |
| // The vector length needs to be converted from bytes to int | |
| // The vector data needs to be converted from bytes to float | |
| // Note that the vector entries are of type float but are integer numbers eg 2.0 | |
| bytesPerF := 4 | |
| vectorLengthFloat := 128 | |
| vectorBytes := make([]byte, bytesPerF+vectorLengthFloat*bytesPerF) | |
| for i := 0; i >= 0; i++ { | |
| _, err = f.Read(vectorBytes) | |
| if err == io.EOF { | |
| break | |
| } else if err != nil { | |
| panic(err) | |
| } | |
| if int32FromBytes(vectorBytes[0:bytesPerF]) != vectorLengthFloat { | |
| panic("Each vector must have 128 entries.") | |
| } | |
| var vectorFloat []float32 | |
| for j := 0; j < vectorLengthFloat; j++ { | |
| start := (j + 1) * bytesPerF // first bytesPerF are length of vector | |
| vectorFloat = append(vectorFloat, float32FromBytes(vectorBytes[start:start+bytesPerF])) | |
| } | |
| ObjectUuid := uuid.New() | |
| object := &models.Object{ | |
| Class: class, | |
| ID: strfmt.UUID(ObjectUuid.String()), | |
| Vector: models.C11yVector(vectorFloat), | |
| Properties: map[string]interface{}{ | |
| "counter": i, | |
| }, | |
| } | |
| objects = append(objects, object) | |
| if i >= maxObjects { | |
| break | |
| } | |
| } | |
| if len(objects) < maxObjects { | |
| panic("Could not load all elements.") | |
| } | |
| return objects | |
| } | |
| func benchmarkSift(c *http.Client, url string, maxObjects, numBatches int) (map[string]int64, error) { | |
| clearExistingObjects(c, url) | |
| objects := readSiftFloat("sift_base.fvecs", maxObjects) | |
| queries := readSiftFloat("sift_query.fvecs", maxObjects/100) | |
| requestSchema := createSchemaSIFTRequest(url) | |
| passedTime := make(map[string]int64) | |
| // Add schema | |
| responseSchemaCode, _, timeSchema, err := performRequest(c, requestSchema) | |
| passedTime["AddSchema"] = timeSchema | |
| if err != nil { | |
| return nil, errors.Wrap(err, "Could not add schema, error: ") | |
| } else if responseSchemaCode != 200 { | |
| return nil, errors.Errorf("Could not add schma, http error code: %v", responseSchemaCode) | |
| } | |
| // Batch-add | |
| passedTime["BatchAdd"] = 0 | |
| wg := sync.WaitGroup{} | |
| batchSize := len(objects) / numBatches | |
| errorChan := make(chan error, numBatches) | |
| timeChan := make(chan int64, numBatches) | |
| for i := 0; i < numBatches; i++ { | |
| wg.Add(1) | |
| go func(batchId int, errChan chan<- error) { | |
| batchObjects := objects[batchId*batchSize : (batchId+1)*batchSize] | |
| requestAdd := createRequest(url+"batch/objects", "POST", batch{batchObjects}) | |
| responseAddCode, _, timeBatchAdd, err := performRequest(c, requestAdd) | |
| timeChan <- timeBatchAdd | |
| if err != nil { | |
| errChan <- errors.Wrap(err, "Could not add batch, error: ") | |
| } else if responseAddCode != 200 { | |
| errChan <- errors.Errorf("Could not add batch, http error code: %v", responseAddCode) | |
| } | |
| wg.Done() | |
| }(i, errorChan) | |
| } | |
| wg.Wait() | |
| close(errorChan) | |
| close(timeChan) | |
| for err := range errorChan { | |
| return nil, err | |
| } | |
| for timing := range timeChan { | |
| passedTime["BatchAdd"] += timing | |
| } | |
| // Read entries | |
| nrSearchResultsUse := nrSearchResults | |
| if maxObjects < nrSearchResultsUse { | |
| nrSearchResultsUse = maxObjects | |
| } | |
| requestRead := createRequest(url+"objects?limit="+fmt.Sprint(nrSearchResultsUse)+"&class="+class, "GET", nil) | |
| responseReadCode, body, timeGetObjects, err := performRequest(c, requestRead) | |
| passedTime["GetObjects"] = timeGetObjects | |
| if err != nil { | |
| return nil, errors.Wrap(err, "Could not read objects") | |
| } else if responseReadCode != 200 { | |
| return nil, errors.New("Could not read objects, http error code: " + fmt.Sprint(responseReadCode)) | |
| } | |
| var result map[string]interface{} | |
| if err := json.Unmarshal(body, &result); err != nil { | |
| return nil, errors.Wrap(err, "Could not unmarshal read response") | |
| } | |
| if int(result["totalResults"].(float64)) != nrSearchResultsUse { | |
| errString := "Found " + fmt.Sprint(int(result["totalResults"].(float64))) + | |
| " results. Expected " + fmt.Sprint(nrSearchResultsUse) + "." | |
| return nil, errors.New(errString) | |
| } | |
| // Use sample queries | |
| for _, query := range queries { | |
| queryString := "{Get{" + class + "(nearVector: {vector:" + fmt.Sprint(query.Vector) + " }){counter}}}" | |
| requestQuery := createRequest(url+"graphql", "POST", models.GraphQLQuery{ | |
| Query: queryString, | |
| }) | |
| responseQueryCode, body, timeQuery, err := performRequest(c, requestQuery) | |
| passedTime["Query"] += timeQuery | |
| if err != nil { | |
| return nil, errors.Wrap(err, "Could not query objects") | |
| } else if responseQueryCode != 200 { | |
| return nil, errors.Errorf("Could not query objects, http error code: %v", responseQueryCode) | |
| } | |
| var result map[string]interface{} | |
| if err := json.Unmarshal(body, &result); err != nil { | |
| return nil, errors.Wrap(err, "Could not unmarshal query response") | |
| } | |
| if result["data"] == nil || result["errors"] != nil { | |
| return nil, errors.New("GraphQL Error") | |
| } | |
| } | |
| // Delete class (with schema and all entries) to clear all entries so next round can start fresh | |
| requestDelete := createRequest(url+"schema/"+class, "DELETE", nil) | |
| responseDeleteCode, _, timeDelete, err := performRequest(c, requestDelete) | |
| passedTime["Delete"] += timeDelete | |
| if err != nil { | |
| return nil, errors.Wrap(err, "Could not delete class") | |
| } else if responseDeleteCode != 200 { | |
| return nil, errors.Errorf("Could not delete class, http error code: %v", responseDeleteCode) | |
| } | |
| return passedTime, nil | |
| } | |