KevinStephenson
Adding in weaviate code
b110593
raw
history blame
2.31 kB
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package clients
import (
"reflect"
"testing"
)
// Test_bertEmbeddingsDecoder_calculateVector is a table-driven test for
// bertEmbeddingsDecoder.calculateVector. Every case feeds a set of embedding
// rows to the decoder and verifies two things: whether an error is reported,
// and whether the returned vector is exactly the expected one.
func Test_bertEmbeddingsDecoder_calculateVector(t *testing.T) {
	type testCase struct {
		name       string
		embeddings [][]float32
		want       []float32
		wantErr    bool
	}

	cases := []testCase{
		{
			// No input at all must be rejected.
			name:       "nil",
			embeddings: nil,
			wantErr:    true,
		},
		{
			// A non-nil but empty input must be rejected as well.
			name:       "empty",
			embeddings: [][]float32{},
			wantErr:    true,
		},
		{
			// A single row is expected to come back unchanged.
			name:       "just one vector",
			embeddings: [][]float32{{-0.17978577315807343}},
			want:       []float32{-0.17978577315807343},
		},
		{
			// Three 10-dimensional rows; the expected vector below equals
			// their element-wise average.
			name: "distilbert-base-uncased",
			embeddings: [][]float32{
				{
					-0.17978577315807343, -0.0678672045469284, 0.1706605851650238, -0.1639413982629776, -0.12804915010929108,
					0.017568372189998627, 0.1610901951789856, 0.19909054040908813, -0.26103103160858154, -0.14505508542060852,
				},
				{
					-0.25516796112060547, -0.054695576429367065, 0.13527897000312805, -0.3919253945350647, 0.1900954395532608,
					0.5994636416435242, 0.5798457264900208, 0.6522972583770752, -0.08617493510246277, -0.35053199529647827,
				},
				{
					0.930827260017395, 0.3315476179122925, -0.323006272315979, 0.18198077380657196, -0.3299236297607422,
					-0.5998684763908386, 0.3299814462661743, -0.6352149844169617, 0.5154204368591309, 0.11740084737539291,
				},
			},
			want: []float32{
				0.1652911752462387, 0.06966160982847214, -0.005688905715942383, -0.12462866306304932, -0.08929244428873062,
				0.005721171852201223, 0.35697245597839355, 0.07205760478973389, 0.05607149004936218, -0.1260620802640915,
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			decoder := bertEmbeddingsDecoder{}
			got, err := decoder.calculateVector(tc.embeddings)

			// Stop at an error-expectation mismatch; comparing the vector
			// afterwards would only add noise to the failure output.
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("bertEmbeddingsDecoder.calculateVector() error = %v, wantErr %v", err, tc.wantErr)
				return
			}

			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("bertEmbeddingsDecoder.calculateVector() = %v, want %v", got, tc.want)
		})
	}
}