Spaces:
Running
Running
//                           _       _
//  __      _____  __ ___   ___  __ _| |_ ___
//  \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//   \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//    \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package config | |
import ( | |
"fmt" | |
"testing" | |
"github.com/pkg/errors" | |
"github.com/stretchr/testify/assert" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
) | |
func Test_classSettings_Validate(t *testing.T) { | |
tests := []struct { | |
name string | |
cfg moduletools.ClassConfig | |
wantApiEndpoint string | |
wantProjectID string | |
wantModelID string | |
wantTemperature float64 | |
wantTokenLimit int | |
wantTopK int | |
wantTopP float64 | |
wantErr error | |
}{ | |
{ | |
name: "happy flow", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"projectId": "projectId", | |
}, | |
}, | |
wantApiEndpoint: "us-central1-aiplatform.googleapis.com", | |
wantProjectID: "projectId", | |
wantModelID: "chat-bison", | |
wantTemperature: 0.2, | |
wantTokenLimit: 256, | |
wantTopK: 40, | |
wantTopP: 0.95, | |
wantErr: nil, | |
}, | |
{ | |
name: "custom values", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"apiEndpoint": "google.com", | |
"projectId": "cloud-project", | |
"modelId": "model-id", | |
"temperature": 0.25, | |
"tokenLimit": 254, | |
"topK": 30, | |
"topP": 0.97, | |
}, | |
}, | |
wantApiEndpoint: "google.com", | |
wantProjectID: "cloud-project", | |
wantModelID: "model-id", | |
wantTemperature: 0.25, | |
wantTokenLimit: 254, | |
wantTopK: 30, | |
wantTopP: 0.97, | |
wantErr: nil, | |
}, | |
{ | |
name: "wrong temperature", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"projectId": "cloud-project", | |
"temperature": 2, | |
}, | |
}, | |
wantErr: errors.Errorf("temperature has to be float value between 0 and 1"), | |
}, | |
{ | |
name: "wrong tokenLimit", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"projectId": "cloud-project", | |
"tokenLimit": 2000, | |
}, | |
}, | |
wantErr: errors.Errorf("tokenLimit has to be an integer value between 1 and 1024"), | |
}, | |
{ | |
name: "wrong topK", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"projectId": "cloud-project", | |
"topK": 2000, | |
}, | |
}, | |
wantErr: errors.Errorf("topK has to be an integer value between 1 and 40"), | |
}, | |
{ | |
name: "wrong topP", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"projectId": "cloud-project", | |
"topP": 3, | |
}, | |
}, | |
wantErr: errors.Errorf("topP has to be float value between 0 and 1"), | |
}, | |
{ | |
name: "wrong all", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"projectId": "", | |
"temperature": 2, | |
"tokenLimit": 2000, | |
"topK": 2000, | |
"topP": 3, | |
}, | |
}, | |
wantErr: errors.Errorf("projectId cannot be empty, " + | |
"temperature has to be float value between 0 and 1, " + | |
"tokenLimit has to be an integer value between 1 and 1024, " + | |
"topK has to be an integer value between 1 and 40, " + | |
"topP has to be float value between 0 and 1"), | |
}, | |
{ | |
name: "Generative AI", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"apiEndpoint": "generativelanguage.googleapis.com", | |
}, | |
}, | |
wantApiEndpoint: "generativelanguage.googleapis.com", | |
wantProjectID: "", | |
wantModelID: "chat-bison-001", | |
wantTemperature: 0.2, | |
wantTokenLimit: 256, | |
wantTopK: 40, | |
wantTopP: 0.95, | |
wantErr: nil, | |
}, | |
{ | |
name: "Generative AI with model", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"apiEndpoint": "generativelanguage.googleapis.com", | |
"modelId": "chat-bison-001", | |
}, | |
}, | |
wantApiEndpoint: "generativelanguage.googleapis.com", | |
wantProjectID: "", | |
wantModelID: "chat-bison-001", | |
wantTemperature: 0.2, | |
wantTokenLimit: 256, | |
wantTopK: 40, | |
wantTopP: 0.95, | |
wantErr: nil, | |
}, | |
{ | |
name: "Generative AI with not supported model", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"apiEndpoint": "generativelanguage.googleapis.com", | |
"modelId": "unsupported-model", | |
}, | |
}, | |
wantErr: fmt.Errorf("unsupported-model is not supported available models are: [chat-bison-001 gemini-pro]"), | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
ic := NewClassSettings(tt.cfg) | |
if tt.wantErr != nil { | |
assert.EqualError(t, ic.Validate(nil), tt.wantErr.Error()) | |
} else { | |
assert.Equal(t, tt.wantApiEndpoint, ic.ApiEndpoint()) | |
assert.Equal(t, tt.wantProjectID, ic.ProjectID()) | |
assert.Equal(t, tt.wantModelID, ic.ModelID()) | |
assert.Equal(t, tt.wantTemperature, ic.Temperature()) | |
assert.Equal(t, tt.wantTokenLimit, ic.TokenLimit()) | |
assert.Equal(t, tt.wantTopK, ic.TopK()) | |
assert.Equal(t, tt.wantTopP, ic.TopP()) | |
} | |
}) | |
} | |
} | |
// fakeClassConfig is a minimal in-memory stand-in for a class configuration,
// used by the tests in this file. All class-level lookups are answered from
// the single classConfig map.
type fakeClassConfig struct {
	classConfig map[string]interface{}
}

// Class returns the configured settings map unchanged.
func (c fakeClassConfig) Class() map[string]interface{} {
	return c.classConfig
}

// Tenant always reports an empty tenant name.
func (c fakeClassConfig) Tenant() string {
	return ""
}

// ClassByModuleName ignores the requested module and returns the same map as
// Class.
func (c fakeClassConfig) ClassByModuleName(_ string) map[string]interface{} {
	return c.Class()
}

// Property has no per-property settings in this fake and always returns nil.
func (c fakeClassConfig) Property(_ string) map[string]interface{} {
	return nil
}