//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package clients | |
import ( | |
"bytes" | |
"context" | |
"encoding/json" | |
"fmt" | |
"io" | |
"net/http" | |
"regexp" | |
"strings" | |
"time" | |
"github.com/weaviate/weaviate/usecases/modulecomponents" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
"github.com/weaviate/weaviate/modules/generative-palm/config" | |
generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" | |
) | |
// harmCategory enumerates the content-safety categories recognized by the
// PaLM/Gemini safety-setting APIs. The values are fixed API identifiers, so
// they are declared as constants rather than mutable package-level vars.
type harmCategory string

const (
	// Category is unspecified.
	HarmCategoryUnspecified harmCategory = "HARM_CATEGORY_UNSPECIFIED"
	// Negative or harmful comments targeting identity and/or protected attribute.
	HarmCategoryDerogatory harmCategory = "HARM_CATEGORY_DEROGATORY"
	// Content that is rude, disrespectful, or profane.
	HarmCategoryToxicity harmCategory = "HARM_CATEGORY_TOXICITY"
	// Describes scenarios depicting violence against an individual or group, or general descriptions of gore.
	HarmCategoryViolence harmCategory = "HARM_CATEGORY_VIOLENCE"
	// Contains references to sexual acts or other lewd content.
	HarmCategorySexual harmCategory = "HARM_CATEGORY_SEXUAL"
	// Promotes unchecked medical advice.
	HarmCategoryMedical harmCategory = "HARM_CATEGORY_MEDICAL"
	// Dangerous content that promotes, facilitates, or encourages harmful acts.
	HarmCategoryDangerous harmCategory = "HARM_CATEGORY_DANGEROUS"
	// Harassment content.
	HarmCategoryHarassment harmCategory = "HARM_CATEGORY_HARASSMENT"
	// Hate speech and content.
	HarmCategoryHate_speech harmCategory = "HARM_CATEGORY_HATE_SPEECH"
	// Sexually explicit content.
	HarmCategorySexually_explicit harmCategory = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
	// Dangerous content.
	HarmCategoryDangerous_content harmCategory = "HARM_CATEGORY_DANGEROUS_CONTENT"
)
// harmBlockThreshold enumerates the blocking thresholds accepted by the
// PaLM/Gemini safety-setting APIs. The values are fixed API identifiers, so
// they are declared as constants rather than mutable package-level vars.
type harmBlockThreshold string

const (
	// Threshold is unspecified.
	HarmBlockThresholdUnspecified harmBlockThreshold = "HARM_BLOCK_THRESHOLD_UNSPECIFIED"
	// Content with NEGLIGIBLE will be allowed.
	BlockLowAndAbove harmBlockThreshold = "BLOCK_LOW_AND_ABOVE"
	// Content with NEGLIGIBLE and LOW will be allowed.
	BlockMediumAndAbove harmBlockThreshold = "BLOCK_MEDIUM_AND_ABOVE"
	// Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
	BlockOnlyHigh harmBlockThreshold = "BLOCK_ONLY_HIGH"
	// All content will be allowed.
	BlockNone harmBlockThreshold = "BLOCK_NONE"
)
// harmProbability enumerates the probability levels the API assigns when
// rating content safety. The values are fixed API identifiers, so they are
// declared as constants rather than mutable package-level vars.
// (Names are kept as-is for backward compatibility even though Go style
// would prefer MixedCaps.)
type harmProbability string

const (
	// Probability is unspecified.
	HARM_PROBABILITY_UNSPECIFIED harmProbability = "HARM_PROBABILITY_UNSPECIFIED"
	// Content has a negligible chance of being unsafe.
	NEGLIGIBLE harmProbability = "NEGLIGIBLE"
	// Content has a low chance of being unsafe.
	LOW harmProbability = "LOW"
	// Content has a medium chance of being unsafe.
	MEDIUM harmProbability = "MEDIUM"
	// Content has a high chance of being unsafe.
	HIGH harmProbability = "HIGH"
)
var compile, _ = regexp.Compile(`{([\w\s]*?)}`) | |
// buildURL returns the endpoint URL for a generation request. When
// useGenerativeAI is true it targets the public Generative Language API
// (Gemini's generateContent for "gemini*" models, otherwise the legacy
// chat-bison generateMessage endpoint); otherwise it targets the Vertex AI
// :predict endpoint under the given API host and project.
// (Fixes the "apiEndoint" parameter-name typo.)
func buildURL(useGenerativeAI bool, apiEndpoint, projectID, modelID string) string {
	if useGenerativeAI {
		// Generative AI endpoints, for more context check out this link:
		// https://developers.generativeai.google/models/language#model_variations
		// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
		if strings.HasPrefix(modelID, "gemini") {
			return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent", modelID)
		}
		return "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
	}
	urlTemplate := "https://%s/v1/projects/%s/locations/us-central1/publishers/google/models/%s:predict"
	return fmt.Sprintf(urlTemplate, apiEndpoint, projectID, modelID)
}
// palm is a client for Google PaLM/Gemini generation. It supports both the
// public Generative Language endpoint (API-key auth) and Vertex AI
// (OAuth bearer-token auth); buildUrlFn is injectable for testing.
type palm struct {
	apiKey     string
	buildUrlFn func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string
	httpClient *http.Client
	logger     logrus.FieldLogger
}
func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *palm { | |
return &palm{ | |
apiKey: apiKey, | |
httpClient: &http.Client{ | |
Timeout: timeout, | |
}, | |
buildUrlFn: buildURL, | |
logger: logger, | |
} | |
} | |
func (v *palm) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { | |
forPrompt, err := v.generateForPrompt(textProperties, prompt) | |
if err != nil { | |
return nil, err | |
} | |
return v.Generate(ctx, cfg, forPrompt) | |
} | |
func (v *palm) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { | |
forTask, err := v.generatePromptForTask(textProperties, task) | |
if err != nil { | |
return nil, err | |
} | |
return v.Generate(ctx, cfg, forTask) | |
} | |
func (v *palm) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { | |
settings := config.NewClassSettings(cfg) | |
useGenerativeAIEndpoint := v.useGenerativeAIEndpoint(settings.ApiEndpoint()) | |
modelID := settings.ModelID() | |
if settings.EndpointID() != "" { | |
modelID = settings.EndpointID() | |
} | |
endpointURL := v.buildUrlFn(useGenerativeAIEndpoint, settings.ApiEndpoint(), settings.ProjectID(), modelID) | |
input := v.getPayload(useGenerativeAIEndpoint, prompt, settings) | |
body, err := json.Marshal(input) | |
if err != nil { | |
return nil, errors.Wrap(err, "marshal body") | |
} | |
req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, | |
bytes.NewReader(body)) | |
if err != nil { | |
return nil, errors.Wrap(err, "create POST request") | |
} | |
apiKey, err := v.getApiKey(ctx) | |
if err != nil { | |
return nil, errors.Wrapf(err, "PaLM API Key") | |
} | |
req.Header.Add("Content-Type", "application/json") | |
if useGenerativeAIEndpoint { | |
req.Header.Add("x-goog-api-key", apiKey) | |
} else { | |
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) | |
} | |
res, err := v.httpClient.Do(req) | |
if err != nil { | |
return nil, errors.Wrap(err, "send POST request") | |
} | |
defer res.Body.Close() | |
bodyBytes, err := io.ReadAll(res.Body) | |
if err != nil { | |
return nil, errors.Wrap(err, "read response body") | |
} | |
if useGenerativeAIEndpoint { | |
if strings.HasPrefix(modelID, "gemini") { | |
return v.parseGenerateContentResponse(res.StatusCode, bodyBytes) | |
} | |
return v.parseGenerateMessageResponse(res.StatusCode, bodyBytes) | |
} | |
return v.parseResponse(res.StatusCode, bodyBytes) | |
} | |
func (v *palm) parseGenerateMessageResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) { | |
var resBody generateMessageResponse | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if err := v.checkResponse(statusCode, resBody.Error); err != nil { | |
return nil, err | |
} | |
if len(resBody.Candidates) > 0 { | |
return v.getGenerateResponse(resBody.Candidates[0].Content) | |
} | |
return &generativemodels.GenerateResponse{ | |
Result: nil, | |
}, nil | |
} | |
func (v *palm) parseGenerateContentResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) { | |
var resBody generateContentResponse | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if err := v.checkResponse(statusCode, resBody.Error); err != nil { | |
return nil, err | |
} | |
if len(resBody.Candidates) > 0 && len(resBody.Candidates[0].Content.Parts) > 0 { | |
return v.getGenerateResponse(resBody.Candidates[0].Content.Parts[0].Text) | |
} | |
return &generativemodels.GenerateResponse{ | |
Result: nil, | |
}, nil | |
} | |
func (v *palm) parseResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) { | |
var resBody generateResponse | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if err := v.checkResponse(statusCode, resBody.Error); err != nil { | |
return nil, err | |
} | |
if len(resBody.Predictions) > 0 && len(resBody.Predictions[0].Candidates) > 0 { | |
return v.getGenerateResponse(resBody.Predictions[0].Candidates[0].Content) | |
} | |
return &generativemodels.GenerateResponse{ | |
Result: nil, | |
}, nil | |
} | |
func (v *palm) getGenerateResponse(content string) (*generativemodels.GenerateResponse, error) { | |
if content != "" { | |
trimmedResponse := strings.Trim(content, "\n") | |
return &generativemodels.GenerateResponse{ | |
Result: &trimmedResponse, | |
}, nil | |
} | |
return &generativemodels.GenerateResponse{ | |
Result: nil, | |
}, nil | |
} | |
func (v *palm) checkResponse(statusCode int, palmApiError *palmApiError) error { | |
if statusCode != 200 || palmApiError != nil { | |
if palmApiError != nil { | |
return fmt.Errorf("connection to Google PaLM failed with status: %v error: %v", | |
statusCode, palmApiError.Message) | |
} | |
return fmt.Errorf("connection to Google PaLM failed with status: %d", statusCode) | |
} | |
return nil | |
} | |
func (v *palm) useGenerativeAIEndpoint(apiEndpoint string) bool { | |
return apiEndpoint == "generativelanguage.googleapis.com" | |
} | |
func (v *palm) getPayload(useGenerativeAI bool, prompt string, settings config.ClassSettings) any { | |
if useGenerativeAI { | |
if strings.HasPrefix(settings.ModelID(), "gemini") { | |
input := generateContentRequest{ | |
Contents: []content{ | |
{ | |
Role: "user", | |
Parts: []part{ | |
{ | |
Text: prompt, | |
}, | |
}, | |
}, | |
}, | |
GenerationConfig: &generationConfig{ | |
Temperature: settings.Temperature(), | |
TopP: settings.TopP(), | |
TopK: settings.TopK(), | |
CandidateCount: 1, | |
}, | |
SafetySettings: []safetySetting{ | |
{ | |
Category: HarmCategoryHarassment, | |
Threshold: BlockMediumAndAbove, | |
}, | |
{ | |
Category: HarmCategoryHate_speech, | |
Threshold: BlockMediumAndAbove, | |
}, | |
{ | |
Category: HarmCategoryDangerous_content, | |
Threshold: BlockMediumAndAbove, | |
}, | |
{ | |
Category: HarmCategoryDangerous_content, | |
Threshold: BlockMediumAndAbove, | |
}, | |
}, | |
} | |
return input | |
} | |
input := generateMessageRequest{ | |
Prompt: &generateMessagePrompt{ | |
Messages: []generateMessage{ | |
{ | |
Content: prompt, | |
}, | |
}, | |
}, | |
Temperature: settings.Temperature(), | |
TopP: settings.TopP(), | |
TopK: settings.TopK(), | |
CandidateCount: 1, | |
} | |
return input | |
} | |
input := generateInput{ | |
Instances: []instance{ | |
{ | |
Messages: []message{ | |
{ | |
Content: prompt, | |
}, | |
}, | |
}, | |
}, | |
Parameters: parameters{ | |
Temperature: settings.Temperature(), | |
MaxOutputTokens: settings.TokenLimit(), | |
TopP: settings.TopP(), | |
TopK: settings.TopK(), | |
}, | |
} | |
return input | |
} | |
func (v *palm) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { | |
marshal, err := json.Marshal(textProperties) | |
if err != nil { | |
return "", err | |
} | |
return fmt.Sprintf(`'%v: | |
%v`, task, string(marshal)), nil | |
} | |
func (v *palm) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { | |
all := compile.FindAll([]byte(prompt), -1) | |
for _, match := range all { | |
originalProperty := string(match) | |
replacedProperty := compile.FindStringSubmatch(originalProperty)[1] | |
replacedProperty = strings.TrimSpace(replacedProperty) | |
value := textProperties[replacedProperty] | |
if value == "" { | |
return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) | |
} | |
prompt = strings.ReplaceAll(prompt, originalProperty, value) | |
} | |
return prompt, nil | |
} | |
func (v *palm) getApiKey(ctx context.Context) (string, error) { | |
if apiKeyValue := v.getValueFromContext(ctx, "X-Palm-Api-Key"); apiKeyValue != "" { | |
return apiKeyValue, nil | |
} | |
if len(v.apiKey) > 0 { | |
return v.apiKey, nil | |
} | |
return "", errors.New("no api key found " + | |
"neither in request header: X-Palm-Api-Key " + | |
"nor in environment variable under PALM_APIKEY") | |
} | |
func (v *palm) getValueFromContext(ctx context.Context, key string) string { | |
if value := ctx.Value(key); value != nil { | |
if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { | |
return keyHeader[0] | |
} | |
} | |
// try getting header from GRPC if not successful | |
if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { | |
return apiKey[0] | |
} | |
return "" | |
} | |
// generateInput is the request body for the Vertex AI :predict endpoint.
type generateInput struct {
	Instances  []instance `json:"instances,omitempty"`
	Parameters parameters `json:"parameters"`
}

// instance is a single conversation to complete, with optional context and
// few-shot examples.
type instance struct {
	Context  string    `json:"context,omitempty"`
	Messages []message `json:"messages,omitempty"`
	Examples []example `json:"examples,omitempty"`
}

// message is one conversation turn.
type message struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}

// example is an input/output pair used for few-shot prompting.
type example struct {
	Input  string `json:"input"`
	Output string `json:"output"`
}

// parameters holds the sampling settings sent with a Vertex AI prediction.
type parameters struct {
	Temperature     float64 `json:"temperature"`
	MaxOutputTokens int     `json:"maxOutputTokens"`
	TopP            float64 `json:"topP"`
	TopK            int     `json:"topK"`
}

// generateResponse is the Vertex AI :predict response envelope.
type generateResponse struct {
	Predictions      []prediction  `json:"predictions,omitempty"`
	Error            *palmApiError `json:"error,omitempty"`
	DeployedModelId  string        `json:"deployedModelId,omitempty"`
	Model            string        `json:"model,omitempty"`
	ModelDisplayName string        `json:"modelDisplayName,omitempty"`
	ModelVersionId   string        `json:"modelVersionId,omitempty"`
}

// prediction is one prediction result with optional safety metadata.
type prediction struct {
	Candidates       []candidate         `json:"candidates,omitempty"`
	SafetyAttributes *[]safetyAttributes `json:"safetyAttributes,omitempty"`
}

// candidate is a single generated reply.
type candidate struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}

// safetyAttributes describes the safety scoring of a prediction.
type safetyAttributes struct {
	Scores     []float64 `json:"scores,omitempty"`
	Blocked    *bool     `json:"blocked,omitempty"`
	Categories []string  `json:"categories,omitempty"`
}

// palmApiError is the standard Google API error payload.
type palmApiError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  string `json:"status"`
}
// generateMessageRequest is the request body for the v1beta2 chat-bison
// generateMessage endpoint.
type generateMessageRequest struct {
	Prompt         *generateMessagePrompt `json:"prompt,omitempty"`
	Temperature    float64                `json:"temperature,omitempty"`
	CandidateCount int                    `json:"candidateCount,omitempty"` // default 1
	TopP           float64                `json:"topP"`
	TopK           int                    `json:"topK"`
}
type generateMessagePrompt struct { | |
Context string `json:"prompt,omitempty"` | |
Examples []generateExample `json:"examples,omitempty"` | |
Messages []generateMessage `json:"messages,omitempty"` | |
} | |
// generateMessage is a single chat message, optionally carrying citation
// metadata on responses.
type generateMessage struct {
	Author           string                    `json:"author,omitempty"`
	Content          string                    `json:"content,omitempty"`
	CitationMetadata *generateCitationMetadata `json:"citationMetadata,omitempty"`
}

// generateCitationMetadata lists the sources cited in a generated message.
type generateCitationMetadata struct {
	CitationSources []generateCitationSource `json:"citationSources,omitempty"`
}

// generateCitationSource attributes a span of the generated content to an
// external source.
type generateCitationSource struct {
	StartIndex int    `json:"startIndex,omitempty"`
	EndIndex   int    `json:"endIndex,omitempty"`
	URI        string `json:"uri,omitempty"`
	License    string `json:"license,omitempty"`
}

// generateExample is a few-shot input/output message pair.
type generateExample struct {
	Input  *generateMessage `json:"input,omitempty"`
	Output *generateMessage `json:"output,omitempty"`
}

// generateMessageResponse is the generateMessage response envelope.
type generateMessageResponse struct {
	Candidates []generateMessage `json:"candidates,omitempty"`
	Messages   []generateMessage `json:"messages,omitempty"`
	Filters    []contentFilter   `json:"filters,omitempty"`
	Error      *palmApiError     `json:"error,omitempty"`
}

// contentFilter explains why parts of the response were filtered.
type contentFilter struct {
	Reason  string `json:"reason,omitempty"`
	Message string `json:"message,omitempty"`
}
// generateContentRequest is the request body for the Gemini v1beta
// generateContent endpoint.
type generateContentRequest struct {
	Contents         []content         `json:"contents,omitempty"`
	SafetySettings   []safetySetting   `json:"safetySettings,omitempty"`
	GenerationConfig *generationConfig `json:"generationConfig,omitempty"`
}

// content is a multi-part message with an optional role (e.g. "user").
type content struct {
	Parts []part `json:"parts,omitempty"`
	Role  string `json:"role,omitempty"`
}

// part is one piece of a content message. Only Text is populated by this
// module; InlineData is declared as a plain string here — NOTE(review): the
// API defines inline_data as a structured Blob, so confirm before using it.
type part struct {
	Text       string `json:"text,omitempty"`
	InlineData string `json:"inline_data,omitempty"`
}

// safetySetting pairs a harm category with the threshold at which content in
// that category is blocked.
type safetySetting struct {
	Category  harmCategory       `json:"category,omitempty"`
	Threshold harmBlockThreshold `json:"threshold,omitempty"`
}

// generationConfig holds the sampling settings for generateContent.
type generationConfig struct {
	StopSequences   []string `json:"stopSequences,omitempty"`
	CandidateCount  int      `json:"candidateCount,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
	Temperature     float64  `json:"temperature,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            int      `json:"topK,omitempty"`
}

// generateContentResponse is the generateContent response envelope.
type generateContentResponse struct {
	Candidates     []generateContentCandidate `json:"candidates,omitempty"`
	PromptFeedback *promptFeedback            `json:"promptFeedback,omitempty"`
	Error          *palmApiError              `json:"error,omitempty"`
}

// generateContentCandidate is one generated answer with safety ratings.
type generateContentCandidate struct {
	Content       contentResponse `json:"content,omitempty"`
	FinishReason  string          `json:"finishReason,omitempty"`
	Index         int             `json:"index,omitempty"`
	SafetyRatings []safetyRating  `json:"safetyRatings,omitempty"`
}

// contentResponse mirrors content for responses.
type contentResponse struct {
	Parts []part `json:"parts,omitempty"`
	Role  string `json:"role,omitempty"`
}

// promptFeedback reports safety ratings for the prompt itself.
type promptFeedback struct {
	SafetyRatings []safetyRating `json:"safetyRatings,omitempty"`
}

// safetyRating is the model's safety assessment for one harm category.
type safetyRating struct {
	Category    harmCategory    `json:"category,omitempty"`
	Probability harmProbability `json:"probability,omitempty"`
	Blocked     *bool           `json:"blocked,omitempty"`
}