Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package config | |
import ( | |
"encoding/json" | |
"fmt" | |
"strings" | |
"github.com/pkg/errors" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
) | |
// Keys under which the individual settings are looked up in the module config.
const (
	// Settings that select the AWS service and what to invoke on it.
	serviceProperty       = "service"
	regionProperty        = "region"
	modelProperty         = "model"
	endpointProperty      = "endpoint"
	targetModelProperty   = "targetModel"
	targetVariantProperty = "targetVariant"

	// Generation parameters.
	maxTokenCountProperty     = "maxTokenCount"
	maxTokensToSampleProperty = "maxTokensToSample"
	stopSequencesProperty     = "stopSequences"
	temperatureProperty       = "temperature"
	topPProperty              = "topP"
	topKProperty              = "topK"
)
// DefaultService is the AWS service used when the "service" setting is absent.
var DefaultService = "bedrock"

// Fallback generation parameters applied to models whose name starts with
// "amazon".
var (
	DefaultTitanMaxTokens     = 8192
	DefaultTitanStopSequences = []string{}
	DefaultTitanTemperature   = 0.0
	DefaultTitanTopP          = 1.0
)
// Fallback generation parameters applied to models whose name starts with
// "anthropic".
var (
	DefaultAnthropicTemperature       = 1.0
	DefaultAnthropicTopP              = 0.999
	DefaultAnthropicTopK              = 250
	DefaultAnthropicMaxTokensToSample = 300
	DefaultAnthropicStopSequences     = []string{"\\n\\nHuman:"}
)
// Fallback generation parameters applied to models whose name starts with
// "ai21".
var (
	DefaultAI21MaxTokens   = 300
	DefaultAI21Temperature = 0.7
)

// Fallback generation parameters applied to models whose name starts with
// "cohere".
var (
	DefaultCohereMaxTokens   = 100
	DefaultCohereTemperature = 0.8
	DefaultCohereTopP        = 1.0
)
var availableAWSServices = []string{ | |
DefaultService, | |
} | |
// availableBedrockModels holds the model names printed as "available" in the
// Bedrock model validation message.
var availableBedrockModels = []string{
	"cohere.command-text-v14",
	"cohere.command-light-text-v14",
}
type classSettings struct { | |
cfg moduletools.ClassConfig | |
} | |
func NewClassSettings(cfg moduletools.ClassConfig) *classSettings { | |
return &classSettings{cfg: cfg} | |
} | |
func (ic *classSettings) Validate(class *models.Class) error { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return errors.New("empty config") | |
} | |
var errorMessages []string | |
service := ic.Service() | |
if service == "" || !ic.validatAvailableAWSSetting(service, availableAWSServices) { | |
errorMessages = append(errorMessages, fmt.Sprintf("wrong %s, available services are: %v", serviceProperty, availableAWSServices)) | |
} | |
region := ic.Region() | |
if region == "" { | |
errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", regionProperty)) | |
} | |
if isBedrock(service) { | |
model := ic.Model() | |
if model == "" && !ic.validateAWSSetting(model, availableBedrockModels) { | |
errorMessages = append(errorMessages, fmt.Sprintf("wrong %s: %s, available model names are: %v", modelProperty, model, availableBedrockModels)) | |
} | |
maxTokenCount := ic.MaxTokenCount() | |
if *maxTokenCount < 1 || *maxTokenCount > 8192 { | |
errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 1 and 8096", maxTokenCountProperty)) | |
} | |
temperature := ic.Temperature() | |
if *temperature < 0 || *temperature > 1 { | |
errorMessages = append(errorMessages, fmt.Sprintf("%s has to be float value between 0 and 1", temperatureProperty)) | |
} | |
topP := ic.TopP() | |
if topP != nil && (*topP < 0 || *topP > 1) { | |
errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 0 and 1", topPProperty)) | |
} | |
endpoint := ic.Endpoint() | |
if endpoint != "" { | |
errorMessages = append(errorMessages, fmt.Sprintf("wrong configuration: %s, not applicable to %s", endpoint, service)) | |
} | |
} | |
if isSagemaker(service) { | |
endpoint := ic.Endpoint() | |
if endpoint == "" { | |
errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", endpointProperty)) | |
} | |
model := ic.Model() | |
if model != "" { | |
errorMessages = append(errorMessages, fmt.Sprintf("wrong configuration: %s, not applicable to %s. did you mean %s", modelProperty, service, targetModelProperty)) | |
} | |
} | |
if len(errorMessages) > 0 { | |
return fmt.Errorf("%s", strings.Join(errorMessages, ", ")) | |
} | |
return nil | |
} | |
func (ic *classSettings) validatAvailableAWSSetting(value string, availableValues []string) bool { | |
for i := range availableValues { | |
if value == availableValues[i] { | |
return true | |
} | |
} | |
return false | |
} | |
func (ic *classSettings) validateAWSSetting(value string, availableValues []string) bool { | |
for i := range availableValues { | |
if value == availableValues[i] { | |
return true | |
} | |
} | |
return false | |
} | |
func (ic *classSettings) getStringProperty(name, defaultValue string) string { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return defaultValue | |
} | |
value, ok := ic.cfg.ClassByModuleName("generative-aws")[name] | |
if ok { | |
asString, ok := value.(string) | |
if ok { | |
return asString | |
} | |
} | |
return defaultValue | |
} | |
func (ic *classSettings) getFloatProperty(name string, defaultValue *float64) *float64 { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return defaultValue | |
} | |
val, ok := ic.cfg.ClassByModuleName("generative-aws")[name] | |
if ok { | |
asFloat, ok := val.(float64) | |
if ok { | |
return &asFloat | |
} | |
asNumber, ok := val.(json.Number) | |
if ok { | |
asFloat, _ := asNumber.Float64() | |
return &asFloat | |
} | |
asInt, ok := val.(int) | |
if ok { | |
asFloat := float64(asInt) | |
return &asFloat | |
} | |
} | |
return defaultValue | |
} | |
func (ic *classSettings) getIntProperty(name string, defaultValue *int) *int { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return defaultValue | |
} | |
val, ok := ic.cfg.ClassByModuleName("generative-cohere")[name] | |
if ok { | |
asInt, ok := val.(int) | |
if ok { | |
return &asInt | |
} | |
asFloat, ok := val.(float64) | |
if ok { | |
asInt := int(asFloat) | |
return &asInt | |
} | |
asNumber, ok := val.(json.Number) | |
if ok { | |
asFloat, _ := asNumber.Float64() | |
asInt := int(asFloat) | |
return &asInt | |
} | |
wrongVal := -1 | |
return &wrongVal | |
} | |
if defaultValue != nil { | |
return defaultValue | |
} | |
return nil | |
} | |
func (ic *classSettings) getListOfStringsProperty(name string, defaultValue []string) *[]string { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return &defaultValue | |
} | |
model, ok := ic.cfg.ClassByModuleName("generative-aws")[name] | |
if ok { | |
asStringList, ok := model.([]string) | |
if ok { | |
return &asStringList | |
} | |
var empty []string | |
return &empty | |
} | |
return &defaultValue | |
} | |
// AWS params | |
func (ic *classSettings) Service() string { | |
return ic.getStringProperty(serviceProperty, DefaultService) | |
} | |
func (ic *classSettings) Region() string { | |
return ic.getStringProperty(regionProperty, "") | |
} | |
func (ic *classSettings) Model() string { | |
return ic.getStringProperty(modelProperty, "") | |
} | |
func (ic *classSettings) MaxTokenCount() *int { | |
if isBedrock(ic.Service()) { | |
if isAmazonModel(ic.Model()) { | |
return ic.getIntProperty(maxTokenCountProperty, &DefaultTitanMaxTokens) | |
} | |
if isAnthropicModel(ic.Model()) { | |
return ic.getIntProperty(maxTokensToSampleProperty, &DefaultAnthropicMaxTokensToSample) | |
} | |
if isAI21Model(ic.Model()) { | |
return ic.getIntProperty(maxTokenCountProperty, &DefaultAI21MaxTokens) | |
} | |
if isCohereModel(ic.Model()) { | |
return ic.getIntProperty(maxTokenCountProperty, &DefaultCohereMaxTokens) | |
} | |
} | |
return ic.getIntProperty(maxTokenCountProperty, nil) | |
} | |
func (ic *classSettings) StopSequences() []string { | |
if isBedrock(ic.Service()) { | |
if isAmazonModel(ic.Model()) { | |
return *ic.getListOfStringsProperty(stopSequencesProperty, DefaultTitanStopSequences) | |
} | |
if isAnthropicModel(ic.Model()) { | |
return *ic.getListOfStringsProperty(stopSequencesProperty, DefaultAnthropicStopSequences) | |
} | |
} | |
return *ic.getListOfStringsProperty(stopSequencesProperty, nil) | |
} | |
func (ic *classSettings) Temperature() *float64 { | |
if isBedrock(ic.Service()) { | |
if isAmazonModel(ic.Model()) { | |
return ic.getFloatProperty(temperatureProperty, &DefaultTitanTemperature) | |
} | |
if isAnthropicModel(ic.Model()) { | |
return ic.getFloatProperty(temperatureProperty, &DefaultAnthropicTemperature) | |
} | |
if isCohereModel(ic.Model()) { | |
return ic.getFloatProperty(temperatureProperty, &DefaultCohereTemperature) | |
} | |
if isAI21Model(ic.Model()) { | |
return ic.getFloatProperty(temperatureProperty, &DefaultAI21Temperature) | |
} | |
} | |
return ic.getFloatProperty(temperatureProperty, nil) | |
} | |
func (ic *classSettings) TopP() *float64 { | |
if isBedrock(ic.Service()) { | |
if isAmazonModel(ic.Model()) { | |
return ic.getFloatProperty(topPProperty, &DefaultTitanTopP) | |
} | |
if isAnthropicModel(ic.Model()) { | |
return ic.getFloatProperty(topPProperty, &DefaultAnthropicTopP) | |
} | |
if isCohereModel(ic.Model()) { | |
return ic.getFloatProperty(topPProperty, &DefaultCohereTopP) | |
} | |
} | |
return ic.getFloatProperty(topPProperty, nil) | |
} | |
func (ic *classSettings) TopK() *int { | |
if isBedrock(ic.Service()) { | |
if isAnthropicModel(ic.Model()) { | |
return ic.getIntProperty(topKProperty, &DefaultAnthropicTopK) | |
} | |
} | |
return ic.getIntProperty(topKProperty, nil) | |
} | |
func (ic *classSettings) Endpoint() string { | |
return ic.getStringProperty(endpointProperty, "") | |
} | |
func (ic *classSettings) TargetModel() string { | |
return ic.getStringProperty(targetModelProperty, "") | |
} | |
func (ic *classSettings) TargetVariant() string { | |
return ic.getStringProperty(targetVariantProperty, "") | |
} | |
// isSagemaker reports whether service is exactly "sagemaker".
func isSagemaker(service string) bool {
	const sagemakerService = "sagemaker"
	return sagemakerService == service
}
// isBedrock reports whether service is exactly "bedrock".
func isBedrock(service string) bool {
	const bedrockService = "bedrock"
	return bedrockService == service
}
// isAmazonModel reports whether model starts with "amazon".
func isAmazonModel(model string) bool {
	const vendorPrefix = "amazon"
	return strings.HasPrefix(model, vendorPrefix)
}
// isAI21Model reports whether model starts with "ai21".
func isAI21Model(model string) bool {
	const vendorPrefix = "ai21"
	return strings.HasPrefix(model, vendorPrefix)
}
// isAnthropicModel reports whether model starts with "anthropic".
func isAnthropicModel(model string) bool {
	const vendorPrefix = "anthropic"
	return strings.HasPrefix(model, vendorPrefix)
}
// isCohereModel reports whether model starts with "cohere".
func isCohereModel(model string) bool {
	const vendorPrefix = "cohere"
	return strings.HasPrefix(model, vendorPrefix)
}