KevinStephenson
Adding in weaviate code
b110593
raw
history blame
10.3 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package config
import (
"encoding/json"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/moduletools"
)
const (
serviceProperty = "service"
regionProperty = "region"
modelProperty = "model"
endpointProperty = "endpoint"
targetModelProperty = "targetModel"
targetVariantProperty = "targetVariant"
maxTokenCountProperty = "maxTokenCount"
maxTokensToSampleProperty = "maxTokensToSample"
stopSequencesProperty = "stopSequences"
temperatureProperty = "temperature"
topPProperty = "topP"
topKProperty = "topK"
)
var (
DefaultTitanMaxTokens = 8192
DefaultTitanStopSequences = []string{}
DefaultTitanTemperature = 0.0
DefaultTitanTopP = 1.0
DefaultService = "bedrock"
)
var (
DefaultAnthropicMaxTokensToSample = 300
DefaultAnthropicStopSequences = []string{"\\n\\nHuman:"}
DefaultAnthropicTemperature = 1.0
DefaultAnthropicTopK = 250
DefaultAnthropicTopP = 0.999
)
var DefaultAI21MaxTokens = 300
var (
DefaultCohereMaxTokens = 100
DefaultCohereTemperature = 0.8
DefaultAI21Temperature = 0.7
DefaultCohereTopP = 1.0
)
var availableAWSServices = []string{
DefaultService,
}
var availableBedrockModels = []string{
"cohere.command-text-v14",
"cohere.command-light-text-v14",
}
type classSettings struct {
cfg moduletools.ClassConfig
}
func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
return &classSettings{cfg: cfg}
}
func (ic *classSettings) Validate(class *models.Class) error {
if ic.cfg == nil {
// we would receive a nil-config on cross-class requests, such as Explore{}
return errors.New("empty config")
}
var errorMessages []string
service := ic.Service()
if service == "" || !ic.validatAvailableAWSSetting(service, availableAWSServices) {
errorMessages = append(errorMessages, fmt.Sprintf("wrong %s, available services are: %v", serviceProperty, availableAWSServices))
}
region := ic.Region()
if region == "" {
errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", regionProperty))
}
if isBedrock(service) {
model := ic.Model()
if model == "" && !ic.validateAWSSetting(model, availableBedrockModels) {
errorMessages = append(errorMessages, fmt.Sprintf("wrong %s: %s, available model names are: %v", modelProperty, model, availableBedrockModels))
}
maxTokenCount := ic.MaxTokenCount()
if *maxTokenCount < 1 || *maxTokenCount > 8192 {
errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 1 and 8096", maxTokenCountProperty))
}
temperature := ic.Temperature()
if *temperature < 0 || *temperature > 1 {
errorMessages = append(errorMessages, fmt.Sprintf("%s has to be float value between 0 and 1", temperatureProperty))
}
topP := ic.TopP()
if topP != nil && (*topP < 0 || *topP > 1) {
errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 0 and 1", topPProperty))
}
endpoint := ic.Endpoint()
if endpoint != "" {
errorMessages = append(errorMessages, fmt.Sprintf("wrong configuration: %s, not applicable to %s", endpoint, service))
}
}
if isSagemaker(service) {
endpoint := ic.Endpoint()
if endpoint == "" {
errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", endpointProperty))
}
model := ic.Model()
if model != "" {
errorMessages = append(errorMessages, fmt.Sprintf("wrong configuration: %s, not applicable to %s. did you mean %s", modelProperty, service, targetModelProperty))
}
}
if len(errorMessages) > 0 {
return fmt.Errorf("%s", strings.Join(errorMessages, ", "))
}
return nil
}
func (ic *classSettings) validatAvailableAWSSetting(value string, availableValues []string) bool {
for i := range availableValues {
if value == availableValues[i] {
return true
}
}
return false
}
func (ic *classSettings) validateAWSSetting(value string, availableValues []string) bool {
for i := range availableValues {
if value == availableValues[i] {
return true
}
}
return false
}
func (ic *classSettings) getStringProperty(name, defaultValue string) string {
if ic.cfg == nil {
// we would receive a nil-config on cross-class requests, such as Explore{}
return defaultValue
}
value, ok := ic.cfg.ClassByModuleName("generative-aws")[name]
if ok {
asString, ok := value.(string)
if ok {
return asString
}
}
return defaultValue
}
func (ic *classSettings) getFloatProperty(name string, defaultValue *float64) *float64 {
if ic.cfg == nil {
// we would receive a nil-config on cross-class requests, such as Explore{}
return defaultValue
}
val, ok := ic.cfg.ClassByModuleName("generative-aws")[name]
if ok {
asFloat, ok := val.(float64)
if ok {
return &asFloat
}
asNumber, ok := val.(json.Number)
if ok {
asFloat, _ := asNumber.Float64()
return &asFloat
}
asInt, ok := val.(int)
if ok {
asFloat := float64(asInt)
return &asFloat
}
}
return defaultValue
}
func (ic *classSettings) getIntProperty(name string, defaultValue *int) *int {
if ic.cfg == nil {
// we would receive a nil-config on cross-class requests, such as Explore{}
return defaultValue
}
val, ok := ic.cfg.ClassByModuleName("generative-cohere")[name]
if ok {
asInt, ok := val.(int)
if ok {
return &asInt
}
asFloat, ok := val.(float64)
if ok {
asInt := int(asFloat)
return &asInt
}
asNumber, ok := val.(json.Number)
if ok {
asFloat, _ := asNumber.Float64()
asInt := int(asFloat)
return &asInt
}
wrongVal := -1
return &wrongVal
}
if defaultValue != nil {
return defaultValue
}
return nil
}
func (ic *classSettings) getListOfStringsProperty(name string, defaultValue []string) *[]string {
if ic.cfg == nil {
// we would receive a nil-config on cross-class requests, such as Explore{}
return &defaultValue
}
model, ok := ic.cfg.ClassByModuleName("generative-aws")[name]
if ok {
asStringList, ok := model.([]string)
if ok {
return &asStringList
}
var empty []string
return &empty
}
return &defaultValue
}
// AWS params
func (ic *classSettings) Service() string {
return ic.getStringProperty(serviceProperty, DefaultService)
}
func (ic *classSettings) Region() string {
return ic.getStringProperty(regionProperty, "")
}
func (ic *classSettings) Model() string {
return ic.getStringProperty(modelProperty, "")
}
func (ic *classSettings) MaxTokenCount() *int {
if isBedrock(ic.Service()) {
if isAmazonModel(ic.Model()) {
return ic.getIntProperty(maxTokenCountProperty, &DefaultTitanMaxTokens)
}
if isAnthropicModel(ic.Model()) {
return ic.getIntProperty(maxTokensToSampleProperty, &DefaultAnthropicMaxTokensToSample)
}
if isAI21Model(ic.Model()) {
return ic.getIntProperty(maxTokenCountProperty, &DefaultAI21MaxTokens)
}
if isCohereModel(ic.Model()) {
return ic.getIntProperty(maxTokenCountProperty, &DefaultCohereMaxTokens)
}
}
return ic.getIntProperty(maxTokenCountProperty, nil)
}
func (ic *classSettings) StopSequences() []string {
if isBedrock(ic.Service()) {
if isAmazonModel(ic.Model()) {
return *ic.getListOfStringsProperty(stopSequencesProperty, DefaultTitanStopSequences)
}
if isAnthropicModel(ic.Model()) {
return *ic.getListOfStringsProperty(stopSequencesProperty, DefaultAnthropicStopSequences)
}
}
return *ic.getListOfStringsProperty(stopSequencesProperty, nil)
}
func (ic *classSettings) Temperature() *float64 {
if isBedrock(ic.Service()) {
if isAmazonModel(ic.Model()) {
return ic.getFloatProperty(temperatureProperty, &DefaultTitanTemperature)
}
if isAnthropicModel(ic.Model()) {
return ic.getFloatProperty(temperatureProperty, &DefaultAnthropicTemperature)
}
if isCohereModel(ic.Model()) {
return ic.getFloatProperty(temperatureProperty, &DefaultCohereTemperature)
}
if isAI21Model(ic.Model()) {
return ic.getFloatProperty(temperatureProperty, &DefaultAI21Temperature)
}
}
return ic.getFloatProperty(temperatureProperty, nil)
}
func (ic *classSettings) TopP() *float64 {
if isBedrock(ic.Service()) {
if isAmazonModel(ic.Model()) {
return ic.getFloatProperty(topPProperty, &DefaultTitanTopP)
}
if isAnthropicModel(ic.Model()) {
return ic.getFloatProperty(topPProperty, &DefaultAnthropicTopP)
}
if isCohereModel(ic.Model()) {
return ic.getFloatProperty(topPProperty, &DefaultCohereTopP)
}
}
return ic.getFloatProperty(topPProperty, nil)
}
func (ic *classSettings) TopK() *int {
if isBedrock(ic.Service()) {
if isAnthropicModel(ic.Model()) {
return ic.getIntProperty(topKProperty, &DefaultAnthropicTopK)
}
}
return ic.getIntProperty(topKProperty, nil)
}
func (ic *classSettings) Endpoint() string {
return ic.getStringProperty(endpointProperty, "")
}
func (ic *classSettings) TargetModel() string {
return ic.getStringProperty(targetModelProperty, "")
}
func (ic *classSettings) TargetVariant() string {
return ic.getStringProperty(targetVariantProperty, "")
}
func isSagemaker(service string) bool {
return service == "sagemaker"
}
func isBedrock(service string) bool {
return service == "bedrock"
}
func isAmazonModel(model string) bool {
return strings.HasPrefix(model, "amazon")
}
func isAI21Model(model string) bool {
return strings.HasPrefix(model, "ai21")
}
func isAnthropicModel(model string) bool {
return strings.HasPrefix(model, "anthropic")
}
func isCohereModel(model string) bool {
return strings.HasPrefix(model, "cohere")
}