Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package vectorizer | |
import ( | |
"testing" | |
"github.com/stretchr/testify/assert" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
) | |
func Test_classSettings_getPassageModel(t *testing.T) { | |
tests := []struct { | |
name string | |
cfg moduletools.ClassConfig | |
wantPassageModel string | |
wantQueryModel string | |
wantWaitForModel bool | |
wantUseGPU bool | |
wantUseCache bool | |
wantEndpointURL string | |
wantError error | |
}{ | |
{ | |
name: "CShorten/CORD-19-Title-Abstracts", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"model": "CShorten/CORD-19-Title-Abstracts", | |
"options": map[string]interface{}{ | |
"waitForModel": true, | |
"useGPU": false, | |
"useCache": false, | |
}, | |
}, | |
}, | |
wantPassageModel: "CShorten/CORD-19-Title-Abstracts", | |
wantQueryModel: "CShorten/CORD-19-Title-Abstracts", | |
wantWaitForModel: true, | |
wantUseGPU: false, | |
wantUseCache: false, | |
}, | |
{ | |
name: "sentence-transformers/all-MiniLM-L6-v2", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"model": "sentence-transformers/all-MiniLM-L6-v2", | |
}, | |
}, | |
wantPassageModel: "sentence-transformers/all-MiniLM-L6-v2", | |
wantQueryModel: "sentence-transformers/all-MiniLM-L6-v2", | |
wantWaitForModel: false, | |
wantUseGPU: false, | |
wantUseCache: true, | |
}, | |
{ | |
name: "DPR models", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"passageModel": "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base", | |
"queryModel": "sentence-transformers/facebook-dpr-question_encoder-single-nq-base", | |
}, | |
}, | |
wantPassageModel: "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base", | |
wantQueryModel: "sentence-transformers/facebook-dpr-question_encoder-single-nq-base", | |
wantWaitForModel: false, | |
wantUseGPU: false, | |
wantUseCache: true, | |
}, | |
{ | |
name: "Hugging Face Inference API - endpointURL", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"endpointURL": "http://endpoint.cloud", | |
}, | |
}, | |
wantPassageModel: "", | |
wantQueryModel: "", | |
wantWaitForModel: false, | |
wantUseGPU: false, | |
wantUseCache: true, | |
wantEndpointURL: "http://endpoint.cloud", | |
}, | |
{ | |
name: "Hugging Face Inference API - endpointUrl", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"endpointUrl": "http://endpoint.cloud", | |
}, | |
}, | |
wantPassageModel: "", | |
wantQueryModel: "", | |
wantWaitForModel: false, | |
wantUseGPU: false, | |
wantUseCache: true, | |
wantEndpointURL: "http://endpoint.cloud", | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
ic := NewClassSettings(tt.cfg) | |
assert.Equal(t, tt.wantPassageModel, ic.getPassageModel()) | |
assert.Equal(t, tt.wantQueryModel, ic.getQueryModel()) | |
assert.Equal(t, tt.wantWaitForModel, ic.OptionWaitForModel()) | |
assert.Equal(t, tt.wantUseGPU, ic.OptionUseGPU()) | |
assert.Equal(t, tt.wantUseCache, ic.OptionUseCache()) | |
assert.Equal(t, tt.wantEndpointURL, ic.EndpointURL()) | |
assert.Equal(t, tt.wantError, ic.validateClassSettings()) | |
}) | |
} | |
} | |