KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.74 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package config
import (
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/weaviate/weaviate/entities/moduletools"
)
// Test_classSettings_Validate builds settings via NewClassSettings from
// Bedrock and SageMaker module configurations. Cases that set wantErr only
// compare the exact message returned by Validate(nil); every other case
// compares the values reported by the settings getters.
//
// NOTE(review): the whole test is currently disabled by t.Skip, and
// wantMaxTokenCount is filled in by several cases but never asserted —
// confirm whether a max-token getter should be checked here.
func Test_classSettings_Validate(t *testing.T) {
	t.Skip("Skipping this test for now")

	cases := []struct {
		name              string
		cfg               moduletools.ClassConfig
		wantService       string
		wantRegion        string
		wantModel         string
		wantEndpoint      string
		wantTargetModel   string
		wantTargetVariant string
		wantMaxTokenCount int
		wantStopSequences []string
		wantTemperature   float64
		wantTopP          int
		wantErr           error
	}{
		{
			name: "happy flow - Bedrock",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"service": "bedrock",
					"region":  "us-east-1",
					"model":   "amazon.titan-tg1-large",
				},
			},
			wantService:       "bedrock",
			wantRegion:        "us-east-1",
			wantModel:         "amazon.titan-tg1-large",
			wantMaxTokenCount: 8192,
			wantStopSequences: []string{},
			wantTemperature:   0,
			wantTopP:          1,
		},
		{
			name: "happy flow - Sagemaker",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"service":       "sagemaker",
					"region":        "us-east-1",
					"endpoint":      "my-endpoint-deployment",
					"targetModel":   "model",
					"targetVariant": "variant-1",
				},
			},
			wantService:       "sagemaker",
			wantRegion:        "us-east-1",
			wantEndpoint:      "my-endpoint-deployment",
			wantTargetModel:   "model",
			wantTargetVariant: "variant-1",
		},
		{
			name: "custom values - Bedrock",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"service":       "bedrock",
					"region":        "us-east-1",
					"model":         "amazon.titan-tg1-large",
					"maxTokenCount": 1,
					"stopSequences": []string{"test", "test2"},
					"temperature":   0.2,
					"topP":          0,
				},
			},
			wantService:       "bedrock",
			wantRegion:        "us-east-1",
			wantModel:         "amazon.titan-tg1-large",
			wantMaxTokenCount: 1,
			wantStopSequences: []string{"test", "test2"},
			wantTemperature:   0.2,
			wantTopP:          0,
		},
		{
			name: "custom values - Sagemaker",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"service":       "sagemaker",
					"region":        "us-east-1",
					"endpoint":      "this-is-my-endpoint",
					"targetModel":   "my-target-model",
					"targetVariant": "my-target¬variant",
				},
			},
			wantService:       "sagemaker",
			wantRegion:        "us-east-1",
			wantEndpoint:      "this-is-my-endpoint",
			wantTargetModel:   "my-target-model",
			wantTargetVariant: "my-target¬variant",
		},
		{
			name: "wrong temperature",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"service":     "bedrock",
					"region":      "us-east-1",
					"model":       "amazon.titan-tg1-large",
					"temperature": 2,
				},
			},
			wantErr: errors.Errorf("temperature has to be float value between 0 and 1"),
		},
		{
			name: "wrong maxTokenCount",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"service":       "bedrock",
					"region":        "us-east-1",
					"model":         "amazon.titan-tg1-large",
					"maxTokenCount": 9000,
				},
			},
			wantErr: errors.Errorf("maxTokenCount has to be an integer value between 1 and 8096"),
		},
		{
			name: "wrong topP",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"service": "bedrock",
					"region":  "us-east-1",
					"model":   "amazon.titan-tg1-large",
					"topP":    2000,
				},
			},
			wantErr: errors.Errorf("topP has to be an integer value between 0 and 1"),
		},
		{
			name: "wrong all",
			cfg: fakeClassConfig{
				classConfig: map[string]interface{}{
					"maxTokenCount": 9000,
					"temperature":   2,
					"topP":          3,
				},
			},
			wantErr: errors.Errorf("wrong service, " +
				"available services are: [bedrock sagemaker], " +
				"region cannot be empty",
			),
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			settings := NewClassSettings(tc.cfg)

			// Failure cases only pin the validation message.
			if tc.wantErr != nil {
				assert.EqualError(t, settings.Validate(nil), tc.wantErr.Error())
				return
			}

			assert.Equal(t, tc.wantService, settings.Service())
			assert.Equal(t, tc.wantRegion, settings.Region())
			assert.Equal(t, tc.wantModel, settings.Model())
			assert.Equal(t, tc.wantEndpoint, settings.Endpoint())
			assert.Equal(t, tc.wantTargetModel, settings.TargetModel())
			assert.Equal(t, tc.wantTargetVariant, settings.TargetVariant())

			// Temperature and TopP come back as pointers; their values are
			// compared only when a non-nil pointer is returned.
			if temperature := settings.Temperature(); temperature != nil {
				assert.Equal(t, tc.wantTemperature, *temperature)
			}
			assert.Equal(t, tc.wantStopSequences, settings.StopSequences())
			if topP := settings.TopP(); topP != nil {
				assert.Equal(t, tc.wantTopP, *topP)
			}
		})
	}
}
// fakeClassConfig is a minimal in-memory stand-in for moduletools.ClassConfig,
// used by the tests above to hand module settings to NewClassSettings.
type fakeClassConfig struct {
	// classConfig is returned verbatim by both Class and ClassByModuleName.
	classConfig map[string]interface{}
}

// Class returns the configured settings map unchanged.
func (fc fakeClassConfig) Class() map[string]interface{} {
	return fc.classConfig
}

// ClassByModuleName ignores moduleName and returns the same map as Class.
func (fc fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
	return fc.Class()
}

// Property always reports no per-property settings.
func (fc fakeClassConfig) Property(propName string) map[string]interface{} {
	return nil
}

// Tenant always reports an empty tenant name.
func (fc fakeClassConfig) Tenant() string {
	return ""
}
}