File size: 2,311 Bytes
b110593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: [email protected]
//

package clients

import (
	"reflect"
	"testing"
)

func Test_bertEmbeddingsDecoder_calculateVector(t *testing.T) {
	tests := []struct {
		name       string
		embeddings [][]float32
		want       []float32
		wantErr    bool
	}{
		{
			name:       "nil",
			embeddings: nil,
			wantErr:    true,
		},
		{
			name:       "empty",
			embeddings: [][]float32{},
			wantErr:    true,
		},
		{
			name:       "just one vector",
			embeddings: [][]float32{{-0.17978577315807343}},
			want:       []float32{-0.17978577315807343},
		},
		{
			name: "distilbert-base-uncased",
			embeddings: [][]float32{
				{-0.17978577315807343, -0.0678672045469284, 0.1706605851650238, -0.1639413982629776, -0.12804915010929108, 0.017568372189998627, 0.1610901951789856, 0.19909054040908813, -0.26103103160858154, -0.14505508542060852},
				{-0.25516796112060547, -0.054695576429367065, 0.13527897000312805, -0.3919253945350647, 0.1900954395532608, 0.5994636416435242, 0.5798457264900208, 0.6522972583770752, -0.08617493510246277, -0.35053199529647827},
				{0.930827260017395, 0.3315476179122925, -0.323006272315979, 0.18198077380657196, -0.3299236297607422, -0.5998684763908386, 0.3299814462661743, -0.6352149844169617, 0.5154204368591309, 0.11740084737539291},
			},
			want: []float32{0.1652911752462387, 0.06966160982847214, -0.005688905715942383, -0.12462866306304932, -0.08929244428873062, 0.005721171852201223, 0.35697245597839355, 0.07205760478973389, 0.05607149004936218, -0.1260620802640915},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			d := bertEmbeddingsDecoder{}
			got, err := d.calculateVector(tt.embeddings)
			if (err != nil) != tt.wantErr {
				t.Errorf("bertEmbeddingsDecoder.calculateVector() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("bertEmbeddingsDecoder.calculateVector() = %v, want %v", got, tt.want)
			}
		})
	}
}