Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: hello@weaviate.io
// | |
package config | |
import ( | |
"encoding/json" | |
"fmt" | |
"strings" | |
"github.com/pkg/errors" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
) | |
// Keys looked up in the class's "generative-palm" module configuration.
const (
	// Module settings.
	apiEndpointProperty = "apiEndpoint"
	projectIDProperty   = "projectId"
	endpointIDProperty  = "endpointId"
	modelIDProperty     = "modelId"

	// Generation parameters.
	temperatureProperty = "temperature"
	tokenLimitProperty  = "tokenLimit"
	topKProperty        = "topK"
	topPProperty        = "topP"
)
// Default values used when a property is not configured on the class.
//
// The DefaultPaLM* values (and DefaultTokenLimit) apply to the default
// apiEndpoint. The DefaulGenerativeAI* values describe the Google MakerSuite
// endpoint; the "Defaul" spelling is part of the exported API and is kept for
// compatibility.
var (
	DefaultPaLMApiEndpoint = "us-central1-aiplatform.googleapis.com"
	DefaultPaLMModel       = "chat-bison"
	DefaultPaLMTemperature = 0.2
	DefaultTokenLimit      = 256
	DefaultPaLMTopP        = 0.95
	DefaultPaLMTopK        = 40

	DefaulGenerativeAIApiEndpoint = "generativelanguage.googleapis.com"
	DefaulGenerativeAIModelID     = "chat-bison-001"
)

// supportedGenerativeAIModels lists the modelId values that Validate accepts
// when apiEndpoint is DefaulGenerativeAIApiEndpoint.
var supportedGenerativeAIModels = []string{
	DefaulGenerativeAIModelID,
	"gemini-pro",
}
type ClassSettings interface { | |
Validate(class *models.Class) error | |
// Module settings | |
ApiEndpoint() string | |
ProjectID() string | |
EndpointID() string | |
ModelID() string | |
// parameters | |
// 0.0 - 1.0 | |
Temperature() float64 | |
// 1 - 1024 | |
TokenLimit() int | |
// 1 - 40 | |
TopK() int | |
// 0.0 - 1.0 | |
TopP() float64 | |
} | |
type classSettings struct { | |
cfg moduletools.ClassConfig | |
} | |
func NewClassSettings(cfg moduletools.ClassConfig) ClassSettings { | |
return &classSettings{cfg: cfg} | |
} | |
// Validate checks the class's "generative-palm" settings. It returns one
// error listing every violation (joined with ", "), or nil when all checks
// pass. The class argument is not inspected.
//
// Rules:
//   - projectId must be non-empty unless apiEndpoint is
//     DefaulGenerativeAIApiEndpoint (Google MakerSuite);
//   - temperature and topP must lie within [0, 1];
//   - tokenLimit must lie within [1, 1024];
//   - topK must lie within [1, 40];
//   - on the MakerSuite endpoint, modelId must be one of
//     supportedGenerativeAIModels.
func (ic *classSettings) Validate(class *models.Class) error {
	if ic.cfg == nil {
		// we would receive a nil-config on cross-class requests, such as Explore{}
		return errors.New("empty config")
	}

	var violations []string
	addViolation := func(format string, args ...interface{}) {
		violations = append(violations, fmt.Sprintf(format, args...))
	}

	endpoint := ic.ApiEndpoint()
	onMakerSuite := endpoint == DefaulGenerativeAIApiEndpoint
	project := ic.ProjectID()

	if !onMakerSuite && project == "" {
		addViolation("%s cannot be empty", projectIDProperty)
	}
	if temp := ic.Temperature(); temp < 0 || temp > 1 {
		addViolation("%s has to be float value between 0 and 1", temperatureProperty)
	}
	if limit := ic.TokenLimit(); limit < 1 || limit > 1024 {
		addViolation("%s has to be an integer value between 1 and 1024", tokenLimitProperty)
	}
	if k := ic.TopK(); k < 1 || k > 40 {
		addViolation("%s has to be an integer value between 1 and 40", topKProperty)
	}
	if p := ic.TopP(); p < 0 || p > 1 {
		addViolation("%s has to be float value between 0 and 1", topPProperty)
	}
	// On the Google MakerSuite endpoint only the models listed in
	// supportedGenerativeAIModels are accepted.
	if model := ic.ModelID(); onMakerSuite && !contains(supportedGenerativeAIModels, model) {
		addViolation("%s is not supported available models are: %+v", model, supportedGenerativeAIModels)
	}

	if len(violations) == 0 {
		return nil
	}
	return fmt.Errorf("%s", strings.Join(violations, ", "))
}
func (ic *classSettings) getStringProperty(name, defaultValue string) string { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return defaultValue | |
} | |
value, ok := ic.cfg.ClassByModuleName("generative-palm")[name] | |
if ok { | |
asString, ok := value.(string) | |
if ok { | |
return asString | |
} | |
} | |
return defaultValue | |
} | |
func (ic *classSettings) getFloatProperty(name string, defaultValue float64) float64 { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return defaultValue | |
} | |
val, ok := ic.cfg.ClassByModuleName("generative-palm")[name] | |
if ok { | |
asFloat, ok := val.(float64) | |
if ok { | |
return asFloat | |
} | |
asNumber, ok := val.(json.Number) | |
if ok { | |
asFloat, _ := asNumber.Float64() | |
return asFloat | |
} | |
asInt, ok := val.(int) | |
if ok { | |
asFloat := float64(asInt) | |
return asFloat | |
} | |
} | |
return defaultValue | |
} | |
func (ic *classSettings) getIntProperty(name string, defaultValue int) int { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return defaultValue | |
} | |
val, ok := ic.cfg.ClassByModuleName("generative-palm")[name] | |
if ok { | |
asFloat, ok := val.(float64) | |
if ok { | |
return int(asFloat) | |
} | |
asNumber, ok := val.(json.Number) | |
if ok { | |
asInt64, _ := asNumber.Int64() | |
return int(asInt64) | |
} | |
asInt, ok := val.(int) | |
if ok { | |
return asInt | |
} | |
} | |
return defaultValue | |
} | |
func (ic *classSettings) getDefaultModel(apiEndpoint string) string { | |
if apiEndpoint == DefaulGenerativeAIApiEndpoint { | |
return DefaulGenerativeAIModelID | |
} | |
return DefaultPaLMModel | |
} | |
// PaLM params | |
func (ic *classSettings) ApiEndpoint() string { | |
return ic.getStringProperty(apiEndpointProperty, DefaultPaLMApiEndpoint) | |
} | |
func (ic *classSettings) ProjectID() string { | |
return ic.getStringProperty(projectIDProperty, "") | |
} | |
func (ic *classSettings) EndpointID() string { | |
return ic.getStringProperty(endpointIDProperty, "") | |
} | |
func (ic *classSettings) ModelID() string { | |
return ic.getStringProperty(modelIDProperty, ic.getDefaultModel(ic.ApiEndpoint())) | |
} | |
// parameters | |
// 0.0 - 1.0 | |
func (ic *classSettings) Temperature() float64 { | |
return ic.getFloatProperty(temperatureProperty, DefaultPaLMTemperature) | |
} | |
// 1 - 1024 | |
func (ic *classSettings) TokenLimit() int { | |
return ic.getIntProperty(tokenLimitProperty, DefaultTokenLimit) | |
} | |
// 1 - 40 | |
func (ic *classSettings) TopK() int { | |
return ic.getIntProperty(topKProperty, DefaultPaLMTopK) | |
} | |
// 0.0 - 1.0 | |
func (ic *classSettings) TopP() float64 { | |
return ic.getFloatProperty(topPProperty, DefaultPaLMTopP) | |
} | |
// contains reports whether e equals (==) any element of s. A nil or empty
// slice contains nothing.
func contains[T comparable](s []T, e T) bool {
	for i := 0; i < len(s); i++ {
		if s[i] == e {
			return true
		}
	}
	return false
}