KevinStephenson
Adding in weaviate code
b110593
raw
history blame
3.63 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package vectorizer
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/weaviate/weaviate/entities/moduletools"
)
// Test_classSettings_getPassageModel verifies what NewClassSettings derives from
// a class config:
//   - the passage and query model names,
//   - the waitForModel / useGPU / useCache options,
//   - the endpoint URL.
//
// It also checks that validateClassSettings returns the expected error, which
// is nil for every case below.
//
// Covered configurations:
//   - a single "model" with explicit options,
//   - a single "model" with no options (option defaults),
//   - separate DPR "passageModel"/"queryModel" entries,
//   - an endpoint URL supplied under both the "endpointURL" and "endpointUrl"
//     keys.
func Test_classSettings_getPassageModel(t *testing.T) {
	// expected bundles every value asserted for a single case.
	type expected struct {
		passageModel string
		queryModel   string
		waitForModel bool
		useGPU       bool
		useCache     bool
		endpointURL  string
		err          error
	}

	const (
		cordModel   = "CShorten/CORD-19-Title-Abstracts"
		miniLMModel = "sentence-transformers/all-MiniLM-L6-v2"
		dprPassage  = "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base"
		dprQuery    = "sentence-transformers/facebook-dpr-question_encoder-single-nq-base"
		endpoint    = "http://endpoint.cloud"
	)

	cases := []struct {
		name string
		cfg  moduletools.ClassConfig
		want expected
	}{
		{
			// A single "model" is used for both passages and queries; the
			// explicit options must be reported back unchanged.
			name: cordModel,
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"model": cordModel,
					"options": map[string]interface{}{
						"waitForModel": true,
						"useGPU":       false,
						"useCache":     false,
					},
				},
			},
			want: expected{
				passageModel: cordModel,
				queryModel:   cordModel,
				waitForModel: true,
				useGPU:       false,
				useCache:     false,
			},
		},
		{
			// No "options" entry: the expectations below pin the defaults
			// (waitForModel=false, useGPU=false, useCache=true).
			name: miniLMModel,
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"model": miniLMModel,
				},
			},
			want: expected{
				passageModel: miniLMModel,
				queryModel:   miniLMModel,
				waitForModel: false,
				useGPU:       false,
				useCache:     true,
			},
		},
		{
			// Distinct passage and query encoders are kept apart.
			name: "DPR models",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"passageModel": dprPassage,
					"queryModel":   dprQuery,
				},
			},
			want: expected{
				passageModel: dprPassage,
				queryModel:   dprQuery,
				waitForModel: false,
				useGPU:       false,
				useCache:     true,
			},
		},
		{
			// Endpoint only: no model names are reported.
			name: "Hugging Face Inference API - endpointURL",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"endpointURL": endpoint,
				},
			},
			want: expected{
				passageModel: "",
				queryModel:   "",
				waitForModel: false,
				useGPU:       false,
				useCache:     true,
				endpointURL:  endpoint,
			},
		},
		{
			// Same as above, using the alternative "endpointUrl" spelling.
			name: "Hugging Face Inference API - endpointUrl",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"endpointUrl": endpoint,
				},
			},
			want: expected{
				passageModel: "",
				queryModel:   "",
				waitForModel: false,
				useGPU:       false,
				useCache:     true,
				endpointURL:  endpoint,
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			settings := NewClassSettings(tc.cfg)
			want := tc.want

			assert.Equal(t, want.passageModel, settings.getPassageModel())
			assert.Equal(t, want.queryModel, settings.getQueryModel())
			assert.Equal(t, want.waitForModel, settings.OptionWaitForModel())
			assert.Equal(t, want.useGPU, settings.OptionUseGPU())
			assert.Equal(t, want.useCache, settings.OptionUseCache())
			assert.Equal(t, want.endpointURL, settings.EndpointURL())
			assert.Equal(t, want.err, settings.validateClassSettings())
		})
	}
}