KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.74 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package config
import (
"fmt"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/weaviate/weaviate/entities/moduletools"
)
// Test_classSettings_Validate is a table-driven test for the module's class
// settings. Each case builds settings from a fakeClassConfig via
// NewClassSettings and then either:
//   - expects Validate to fail with exactly the message of wantErr, or
//   - expects Validate to succeed and every getter (ApiEndpoint, ProjectID,
//     ModelID, Temperature, TokenLimit, TopK, TopP) to return the expected
//     value, including the values expected when a key is omitted from the
//     config (see the "happy flow" and "Generative AI" cases).
func Test_classSettings_Validate(t *testing.T) {
	tests := []struct {
		name            string
		cfg             moduletools.ClassConfig
		wantApiEndpoint string
		wantProjectID   string
		wantModelID     string
		wantTemperature float64
		wantTokenLimit  int
		wantTopK        int
		wantTopP        float64
		wantErr         error // non-nil: Validate must fail; only the message is compared
	}{
		{
			name: "happy flow",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"projectId": "projectId",
				},
			},
			wantApiEndpoint: "us-central1-aiplatform.googleapis.com",
			wantProjectID:   "projectId",
			wantModelID:     "chat-bison",
			wantTemperature: 0.2,
			wantTokenLimit:  256,
			wantTopK:        40,
			wantTopP:        0.95,
			wantErr:         nil,
		},
		{
			name: "custom values",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"apiEndpoint": "google.com",
					"projectId":   "cloud-project",
					"modelId":     "model-id",
					"temperature": 0.25,
					"tokenLimit":  254,
					"topK":        30,
					"topP":        0.97,
				},
			},
			wantApiEndpoint: "google.com",
			wantProjectID:   "cloud-project",
			wantModelID:     "model-id",
			wantTemperature: 0.25,
			wantTokenLimit:  254,
			wantTopK:        30,
			wantTopP:        0.97,
			wantErr:         nil,
		},
		{
			name: "wrong temperature",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"projectId":   "cloud-project",
					"temperature": 2,
				},
			},
			wantErr: errors.Errorf("temperature has to be float value between 0 and 1"),
		},
		{
			name: "wrong tokenLimit",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"projectId":  "cloud-project",
					"tokenLimit": 2000,
				},
			},
			wantErr: errors.Errorf("tokenLimit has to be an integer value between 1 and 1024"),
		},
		{
			name: "wrong topK",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"projectId": "cloud-project",
					"topK":      2000,
				},
			},
			wantErr: errors.Errorf("topK has to be an integer value between 1 and 40"),
		},
		{
			name: "wrong topP",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"projectId": "cloud-project",
					"topP":      3,
				},
			},
			wantErr: errors.Errorf("topP has to be float value between 0 and 1"),
		},
		{
			name: "wrong all",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"projectId":   "",
					"temperature": 2,
					"tokenLimit":  2000,
					"topK":        2000,
					"topP":        3,
				},
			},
			wantErr: errors.Errorf("projectId cannot be empty, " +
				"temperature has to be float value between 0 and 1, " +
				"tokenLimit has to be an integer value between 1 and 1024, " +
				"topK has to be an integer value between 1 and 40, " +
				"topP has to be float value between 0 and 1"),
		},
		{
			name: "Generative AI",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"apiEndpoint": "generativelanguage.googleapis.com",
				},
			},
			wantApiEndpoint: "generativelanguage.googleapis.com",
			wantProjectID:   "",
			wantModelID:     "chat-bison-001",
			wantTemperature: 0.2,
			wantTokenLimit:  256,
			wantTopK:        40,
			wantTopP:        0.95,
			wantErr:         nil,
		},
		{
			name: "Generative AI with model",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"apiEndpoint": "generativelanguage.googleapis.com",
					"modelId":     "chat-bison-001",
				},
			},
			wantApiEndpoint: "generativelanguage.googleapis.com",
			wantProjectID:   "",
			wantModelID:     "chat-bison-001",
			wantTemperature: 0.2,
			wantTokenLimit:  256,
			wantTopK:        40,
			wantTopP:        0.95,
			wantErr:         nil,
		},
		{
			name: "Generative AI with not supported model",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"apiEndpoint": "generativelanguage.googleapis.com",
					"modelId":     "unsupported-model",
				},
			},
			wantErr: fmt.Errorf("unsupported-model is not supported available models are: [chat-bison-001 gemini-pro]"),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ic := NewClassSettings(tt.cfg)
			err := ic.Validate(nil)
			if tt.wantErr != nil {
				assert.EqualError(t, err, tt.wantErr.Error())
				return
			}
			// Cases without wantErr must genuinely pass validation, not merely
			// expose the expected getter values.
			assert.NoError(t, err)
			assert.Equal(t, tt.wantApiEndpoint, ic.ApiEndpoint())
			assert.Equal(t, tt.wantProjectID, ic.ProjectID())
			assert.Equal(t, tt.wantModelID, ic.ModelID())
			assert.Equal(t, tt.wantTemperature, ic.Temperature())
			assert.Equal(t, tt.wantTokenLimit, ic.TokenLimit())
			assert.Equal(t, tt.wantTopK, ic.TopK())
			assert.Equal(t, tt.wantTopP, ic.TopP())
		})
	}
}
// fakeClassConfig is a minimal in-memory class configuration used as the
// moduletools.ClassConfig input of the settings tests in this file. It serves
// one fixed settings map for both class-level and module-scoped lookups.
type fakeClassConfig struct {
	// classConfig is handed back verbatim by Class and ClassByModuleName.
	classConfig map[string]interface{}
}

// Class returns the configured settings map as-is (nil for the zero value).
func (c fakeClassConfig) Class() map[string]interface{} {
	return c.classConfig
}

// Tenant always returns the empty string.
func (c fakeClassConfig) Tenant() string {
	return ""
}

// ClassByModuleName ignores moduleName and returns the same map as Class.
func (c fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
	return c.Class()
}

// Property returns nil for every property name.
func (c fakeClassConfig) Property(propName string) map[string]interface{} {
	return nil
}