//                           _       _
//  __      _____  __ ___   ___  __ _| |_ ___
//  \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//   \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//    \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package config

import (
	"testing"

	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/weaviate/weaviate/entities/moduletools"
)

func Test_classSettings_Validate(t *testing.T) { | |
tests := []struct { | |
name string | |
cfg moduletools.ClassConfig | |
wantModel string | |
wantMaxTokens int | |
wantTemperature int | |
wantK int | |
wantStopSequences []string | |
wantReturnLikelihoods string | |
wantBaseURL string | |
wantErr error | |
}{ | |
{ | |
name: "default settings", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{}, | |
}, | |
wantModel: "command-nightly", | |
wantMaxTokens: 2048, | |
wantTemperature: 0, | |
wantK: 0, | |
wantStopSequences: []string{}, | |
wantReturnLikelihoods: "NONE", | |
wantBaseURL: "https://api.cohere.ai", | |
wantErr: nil, | |
}, | |
{ | |
name: "everything non default configured", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"model": "command-xlarge", | |
"maxTokens": 2048, | |
"temperature": 1, | |
"k": 2, | |
"stopSequences": []string{"stop1", "stop2"}, | |
"returnLikelihoods": "NONE", | |
}, | |
}, | |
wantModel: "command-xlarge", | |
wantMaxTokens: 2048, | |
wantTemperature: 1, | |
wantK: 2, | |
wantStopSequences: []string{"stop1", "stop2"}, | |
wantReturnLikelihoods: "NONE", | |
wantBaseURL: "https://api.cohere.ai", | |
wantErr: nil, | |
}, | |
{ | |
name: "wrong model configured", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"model": "wrong-model", | |
}, | |
}, | |
wantErr: errors.Errorf("wrong Cohere model name, available model names are: " + | |
"[command-xlarge-beta command-xlarge command-medium command-xlarge-nightly " + | |
"command-medium-nightly xlarge medium command command-light command-nightly command-light-nightly base base-light]"), | |
}, | |
{ | |
name: "default settings with command-light-nightly", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"model": "command-light-nightly", | |
}, | |
}, | |
wantModel: "command-light-nightly", | |
wantMaxTokens: 2048, | |
wantTemperature: 0, | |
wantK: 0, | |
wantStopSequences: []string{}, | |
wantReturnLikelihoods: "NONE", | |
wantBaseURL: "https://api.cohere.ai", | |
wantErr: nil, | |
}, | |
{ | |
name: "default settings with command-light-nightly and baseURL", | |
cfg: fakeClassConfig{ | |
classConfig: map[string]interface{}{ | |
"model": "command-light-nightly", | |
"baseURL": "http://custom-url.com", | |
}, | |
}, | |
wantModel: "command-light-nightly", | |
wantMaxTokens: 2048, | |
wantTemperature: 0, | |
wantK: 0, | |
wantStopSequences: []string{}, | |
wantReturnLikelihoods: "NONE", | |
wantBaseURL: "http://custom-url.com", | |
wantErr: nil, | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
ic := NewClassSettings(tt.cfg) | |
if tt.wantErr != nil { | |
assert.Equal(t, tt.wantErr.Error(), ic.Validate(nil).Error()) | |
} else { | |
assert.Equal(t, tt.wantModel, ic.Model()) | |
assert.Equal(t, tt.wantMaxTokens, ic.MaxTokens()) | |
assert.Equal(t, tt.wantTemperature, ic.Temperature()) | |
assert.Equal(t, tt.wantK, ic.K()) | |
assert.Equal(t, tt.wantStopSequences, ic.StopSequences()) | |
assert.Equal(t, tt.wantReturnLikelihoods, ic.ReturnLikelihoods()) | |
} | |
}) | |
} | |
} | |
// fakeClassConfig is a minimal stand-in for moduletools.ClassConfig used by the
// tests in this file. It serves a single fixed settings map for both the
// class-level and the module-scoped lookup, and has no tenant or per-property
// settings.
type fakeClassConfig struct {
	classConfig map[string]interface{}
}

// Class returns the stored settings map as-is.
func (c fakeClassConfig) Class() map[string]interface{} {
	return c.classConfig
}

// ClassByModuleName ignores the module name and returns the same map as Class.
func (c fakeClassConfig) ClassByModuleName(_ string) map[string]interface{} {
	return c.Class()
}

// Tenant always reports an empty tenant.
func (c fakeClassConfig) Tenant() string {
	return ""
}

// Property reports no settings for any property.
func (c fakeClassConfig) Property(_ string) map[string]interface{} {
	return nil
}