Spaces:
Sleeping
Sleeping
| // _ _ | |
| // __ _____ __ ___ ___ __ _| |_ ___ | |
| // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
| // \ V V / __/ (_| |\ V /| | (_| | || __/ | |
| // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
| // | |
| // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
| // | |
| // CONTACT: [email protected] | |
| // | |
| package clients | |
| import ( | |
| "bytes" | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "regexp" | |
| "strings" | |
| "time" | |
| "github.com/weaviate/weaviate/usecases/modulecomponents" | |
| "github.com/pkg/errors" | |
| "github.com/sirupsen/logrus" | |
| "github.com/weaviate/weaviate/entities/moduletools" | |
| "github.com/weaviate/weaviate/modules/generative-anyscale/config" | |
| generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" | |
| ) | |
// compile matches {property} placeholders in single-result prompts and
// captures the text between the braces. In Go's RE2 syntax \w and \s are
// ASCII-only, so placeholder names are limited to [0-9A-Za-z_] plus whitespace.
//
// MustCompile is used so an invalid pattern fails loudly at package
// initialization instead of silently leaving compile nil and panicking on
// first use.
var compile = regexp.MustCompile(`{([\w\s]*?)}`)
| type anyscale struct { | |
| apiKey string | |
| httpClient *http.Client | |
| logger logrus.FieldLogger | |
| } | |
| func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *anyscale { | |
| return &anyscale{ | |
| apiKey: apiKey, | |
| httpClient: &http.Client{ | |
| Timeout: timeout, | |
| }, | |
| logger: logger, | |
| } | |
| } | |
| func (v *anyscale) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { | |
| forPrompt, err := v.generateForPrompt(textProperties, prompt) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return v.Generate(ctx, cfg, forPrompt) | |
| } | |
| func (v *anyscale) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { | |
| forTask, err := v.generatePromptForTask(textProperties, task) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return v.Generate(ctx, cfg, forTask) | |
| } | |
| func (v *anyscale) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { | |
| settings := config.NewClassSettings(cfg) | |
| anyscaleUrl := v.getAnyscaleUrl(ctx, settings.BaseURL()) | |
| anyscalePrompt := []map[string]string{ | |
| {"role": "system", "content": "You are a helpful assistant."}, | |
| {"role": "user", "content": prompt}, | |
| } | |
| input := generateInput{ | |
| Messages: anyscalePrompt, | |
| Model: settings.Model(), | |
| Temperature: settings.Temperature(), | |
| } | |
| body, err := json.Marshal(input) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "marshal body") | |
| } | |
| req, err := http.NewRequestWithContext(ctx, "POST", anyscaleUrl, | |
| bytes.NewReader(body)) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "create POST request") | |
| } | |
| apiKey, err := v.getApiKey(ctx) | |
| if err != nil { | |
| return nil, errors.Wrapf(err, "Anyscale (OpenAI) API Key") | |
| } | |
| req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) | |
| req.Header.Add("Content-Type", "application/json") | |
| res, err := v.httpClient.Do(req) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "send POST request") | |
| } | |
| defer res.Body.Close() | |
| bodyBytes, err := io.ReadAll(res.Body) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "read response body") | |
| } | |
| var resBody generateResponse | |
| if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
| return nil, errors.Wrap(err, "unmarshal response body") | |
| } | |
| if res.StatusCode != 200 || resBody.Error != nil { | |
| if resBody.Error != nil { | |
| return nil, errors.Errorf("connection to Anyscale API failed with status: %d error: %v", res.StatusCode, resBody.Error.Message) | |
| } | |
| return nil, errors.Errorf("connection to Anyscale API failed with status: %d", res.StatusCode) | |
| } | |
| textResponse := resBody.Choices[0].Message.Content | |
| return &generativemodels.GenerateResponse{ | |
| Result: &textResponse, | |
| }, nil | |
| } | |
| func (v *anyscale) getAnyscaleUrl(ctx context.Context, baseURL string) string { | |
| passedBaseURL := baseURL | |
| if headerBaseURL := v.getValueFromContext(ctx, "X-Anyscale-Baseurl"); headerBaseURL != "" { | |
| passedBaseURL = headerBaseURL | |
| } | |
| return fmt.Sprintf("%s/v1/chat/completions", passedBaseURL) | |
| } | |
| func (v *anyscale) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { | |
| marshal, err := json.Marshal(textProperties) | |
| if err != nil { | |
| return "", err | |
| } | |
| return fmt.Sprintf(`'%v: | |
| %v`, task, string(marshal)), nil | |
| } | |
| func (v *anyscale) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { | |
| all := compile.FindAll([]byte(prompt), -1) | |
| for _, match := range all { | |
| originalProperty := string(match) | |
| replacedProperty := compile.FindStringSubmatch(originalProperty)[1] | |
| replacedProperty = strings.TrimSpace(replacedProperty) | |
| value := textProperties[replacedProperty] | |
| if value == "" { | |
| return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) | |
| } | |
| prompt = strings.ReplaceAll(prompt, originalProperty, value) | |
| } | |
| return prompt, nil | |
| } | |
| func (v *anyscale) getValueFromContext(ctx context.Context, key string) string { | |
| if value := ctx.Value(key); value != nil { | |
| if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { | |
| return keyHeader[0] | |
| } | |
| } | |
| // try getting header from GRPC if not successful | |
| if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { | |
| return apiKey[0] | |
| } | |
| return "" | |
| } | |
| func (v *anyscale) getApiKey(ctx context.Context) (string, error) { | |
| // note Anyscale uses the OpenAI API Key in it's requests. | |
| if apiKey := v.getValueFromContext(ctx, "X-Anyscale-Api-Key"); apiKey != "" { | |
| return apiKey, nil | |
| } | |
| if v.apiKey != "" { | |
| return v.apiKey, nil | |
| } | |
| return "", errors.New("no api key found " + | |
| "neither in request header: X-Anyscale-Api-Key " + | |
| "nor in environment variable under ANYSCALE_APIKEY") | |
| } | |
// generateInput is the JSON request body posted to the chat-completions
// endpoint. The field order below determines the key order of the encoded
// JSON.
type generateInput struct {
	// Model is the model identifier taken from the class settings.
	Model string `json:"model"`
	// Messages is the chat history; each entry is a {"role", "content"} map.
	Messages []map[string]string `json:"messages"`
	// Temperature is the sampling temperature from the class settings.
	// NOTE(review): the API may accept fractional temperatures, but this
	// stays an int to match settings.Temperature(); confirm before changing.
	Temperature int `json:"temperature"`
}
// Response payload types for the Anyscale chat-completions endpoint. Keys in
// the response body that are not declared here are ignored by encoding/json.
type (
	// Message is a single chat message as returned inside a Choice.
	Message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// Choice is one completion candidate; Generate reads only the first one.
	Choice struct {
		Message      Message `json:"message"`
		Index        int     `json:"index"`
		FinishReason string  `json:"finish_reason"`
	}

	// generateResponse is the decoded response body. Successful responses fill
	// the completion fields. Failed responses are expected to carry an "error"
	// object instead, leaving Error non-nil and the other fields presumably
	// zero. The omitempty tag only affects encoding; decoding either shape into
	// this struct needs no extra tags.
	generateResponse struct {
		ID      string            `json:"id"`
		Object  string            `json:"object"`
		Created int64             `json:"created"`
		Model   string            `json:"model"`
		Choices []Choice          `json:"choices"`
		Usage   map[string]int    `json:"usage"`
		Error   *anyscaleApiError `json:"error,omitempty"`
	}

	// anyscaleApiError is the "error" object of a failed response.
	anyscaleApiError struct {
		Message string `json:"message"`
	}
)