// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/weaviate/weaviate/usecases/modulecomponents"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/modules/generative-palm/config"
generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
)
type harmCategory string
var (
// Category is unspecified.
HarmCategoryUnspecified harmCategory = "HARM_CATEGORY_UNSPECIFIED"
// Negative or harmful comments targeting identity and/or protected attribute.
HarmCategoryDerogatory harmCategory = "HARM_CATEGORY_DEROGATORY"
// Content that is rude, disrespectful, or profane.
HarmCategoryToxicity harmCategory = "HARM_CATEGORY_TOXICITY"
// Describes scenarios depicting violence against an individual or group, or general descriptions of gore.
HarmCategoryViolence harmCategory = "HARM_CATEGORY_VIOLENCE"
// Contains references to sexual acts or other lewd content.
HarmCategorySexual harmCategory = "HARM_CATEGORY_SEXUAL"
// Promotes unchecked medical advice.
HarmCategoryMedical harmCategory = "HARM_CATEGORY_MEDICAL"
// Dangerous content that promotes, facilitates, or encourages harmful acts.
HarmCategoryDangerous harmCategory = "HARM_CATEGORY_DANGEROUS"
// Harassment content.
HarmCategoryHarassment harmCategory = "HARM_CATEGORY_HARASSMENT"
// Hate speech and content.
HarmCategoryHate_speech harmCategory = "HARM_CATEGORY_HATE_SPEECH"
// Sexually explicit content.
HarmCategorySexually_explicit harmCategory = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
// Dangerous content.
HarmCategoryDangerous_content harmCategory = "HARM_CATEGORY_DANGEROUS_CONTENT"
)
type harmBlockThreshold string
var (
// Threshold is unspecified.
HarmBlockThresholdUnspecified harmBlockThreshold = "HARM_BLOCK_THRESHOLD_UNSPECIFIED"
// Content with NEGLIGIBLE will be allowed.
BlockLowAndAbove harmBlockThreshold = "BLOCK_LOW_AND_ABOVE"
// Content with NEGLIGIBLE and LOW will be allowed.
BlockMediumAndAbove harmBlockThreshold = "BLOCK_MEDIUM_AND_ABOVE"
// Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
BlockOnlyHigh harmBlockThreshold = "BLOCK_ONLY_HIGH"
// All content will be allowed.
BlockNone harmBlockThreshold = "BLOCK_NONE"
)
type harmProbability string
var (
// Probability is unspecified.
HARM_PROBABILITY_UNSPECIFIED harmProbability = "HARM_PROBABILITY_UNSPECIFIED"
// Content has a negligible chance of being unsafe.
NEGLIGIBLE harmProbability = "NEGLIGIBLE"
// Content has a low chance of being unsafe.
LOW harmProbability = "LOW"
// Content has a medium chance of being unsafe.
MEDIUM harmProbability = "MEDIUM"
// Content has a high chance of being unsafe.
HIGH harmProbability = "HIGH"
)
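// compile matches single-brace placeholders such as {propertyName} that reference object properties inside a prompt.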
var compile = regexp.MustCompile(`{([\w\s]*?)}`)
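// buildURL returns the generation endpoint to call. On the public Generative Language
// API it selects generateContent for gemini models and the chat-bison-001 generateMessage
// endpoint otherwise; for any other API endpoint it builds a project-scoped :predict URL
// pinned to the us-central1 location.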
func buildURL(useGenerativeAI bool, apiEndpoint, projectID, modelID string) string {
if useGenerativeAI {
// Generative AI endpoints, for more context check out these links:
// https://developers.generativeai.google/models/language#model_variations
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
if strings.HasPrefix(modelID, "gemini") {
return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent", modelID)
}
return "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
}
urlTemplate := "https://%s/v1/projects/%s/locations/us-central1/publishers/google/models/%s:predict"
return fmt.Sprintf(urlTemplate, apiEndpoint, projectID, modelID)
}
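// palm is the HTTP client the generative-palm module uses to call Google's PaLM and
// Gemini generation endpoints.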
type palm struct {
apiKey string
buildUrlFn func(useGenerativeAI bool, apiEndpoint, projectID, modelID string) string
httpClient *http.Client
logger logrus.FieldLogger
}
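// New constructs a palm client with a statically configured API key, an HTTP timeout, and a logger.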
func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *palm {
return &palm{
apiKey: apiKey,
httpClient: &http.Client{
Timeout: timeout,
},
buildUrlFn: buildURL,
logger: logger,
}
}
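// GenerateSingleResult resolves the {property} placeholders in prompt against the given
// text properties and generates a completion for the resulting prompt.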
func (v *palm) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
forPrompt, err := v.generateForPrompt(textProperties, prompt)
if err != nil {
return nil, err
}
return v.Generate(ctx, cfg, forPrompt)
}
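// GenerateAllResults builds a single prompt from the task description followed by the
// JSON-encoded text properties of all objects and generates one completion for it.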
func (v *palm) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
forTask, err := v.generatePromptForTask(textProperties, task)
if err != nil {
return nil, err
}
return v.Generate(ctx, cfg, forTask)
}
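// Generate builds the payload for the configured endpoint and model, sends it with the
// resolved API key, and parses the response in the format matching that endpoint:
// generateContent for gemini models, generateMessage for other Generative Language
// models, and the :predict response format otherwise.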
func (v *palm) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
settings := config.NewClassSettings(cfg)
useGenerativeAIEndpoint := v.useGenerativeAIEndpoint(settings.ApiEndpoint())
modelID := settings.ModelID()
if settings.EndpointID() != "" {
modelID = settings.EndpointID()
}
endpointURL := v.buildUrlFn(useGenerativeAIEndpoint, settings.ApiEndpoint(), settings.ProjectID(), modelID)
input := v.getPayload(useGenerativeAIEndpoint, prompt, settings)
body, err := json.Marshal(input)
if err != nil {
return nil, errors.Wrap(err, "marshal body")
}
req, err := http.NewRequestWithContext(ctx, "POST", endpointURL,
bytes.NewReader(body))
if err != nil {
return nil, errors.Wrap(err, "create POST request")
}
apiKey, err := v.getApiKey(ctx)
if err != nil {
return nil, errors.Wrap(err, "PaLM API Key")
}
req.Header.Add("Content-Type", "application/json")
if useGenerativeAIEndpoint {
req.Header.Add("x-goog-api-key", apiKey)
} else {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
}
res, err := v.httpClient.Do(req)
if err != nil {
return nil, errors.Wrap(err, "send POST request")
}
defer res.Body.Close()
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
return nil, errors.Wrap(err, "read response body")
}
if useGenerativeAIEndpoint {
if strings.HasPrefix(modelID, "gemini") {
return v.parseGenerateContentResponse(res.StatusCode, bodyBytes)
}
return v.parseGenerateMessageResponse(res.StatusCode, bodyBytes)
}
return v.parseResponse(res.StatusCode, bodyBytes)
}
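// parseGenerateMessageResponse decodes a generateMessage response and returns the first
// candidate's content, or a nil result when there are no candidates.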
func (v *palm) parseGenerateMessageResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) {
var resBody generateMessageResponse
if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
return nil, errors.Wrap(err, "unmarshal response body")
}
if err := v.checkResponse(statusCode, resBody.Error); err != nil {
return nil, err
}
if len(resBody.Candidates) > 0 {
return v.getGenerateResponse(resBody.Candidates[0].Content)
}
return &generativemodels.GenerateResponse{
Result: nil,
}, nil
}
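// parseGenerateContentResponse decodes a generateContent response and returns the text of
// the first part of the first candidate, or a nil result when none is present.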
func (v *palm) parseGenerateContentResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) {
var resBody generateContentResponse
if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
return nil, errors.Wrap(err, "unmarshal response body")
}
if err := v.checkResponse(statusCode, resBody.Error); err != nil {
return nil, err
}
if len(resBody.Candidates) > 0 && len(resBody.Candidates[0].Content.Parts) > 0 {
return v.getGenerateResponse(resBody.Candidates[0].Content.Parts[0].Text)
}
return &generativemodels.GenerateResponse{
Result: nil,
}, nil
}
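// parseResponse decodes a :predict response and returns the first candidate of the first
// prediction, or a nil result when none is present.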
func (v *palm) parseResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) {
var resBody generateResponse
if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
return nil, errors.Wrap(err, "unmarshal response body")
}
if err := v.checkResponse(statusCode, resBody.Error); err != nil {
return nil, err
}
if len(resBody.Predictions) > 0 && len(resBody.Predictions[0].Candidates) > 0 {
return v.getGenerateResponse(resBody.Predictions[0].Candidates[0].Content)
}
return &generativemodels.GenerateResponse{
Result: nil,
}, nil
}
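// getGenerateResponse wraps non-empty content, trimmed of leading and trailing newlines,
// in a GenerateResponse; empty content yields a nil result.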
func (v *palm) getGenerateResponse(content string) (*generativemodels.GenerateResponse, error) {
if content != "" {
trimmedResponse := strings.Trim(content, "\n")
return &generativemodels.GenerateResponse{
Result: &trimmedResponse,
}, nil
}
return &generativemodels.GenerateResponse{
Result: nil,
}, nil
}
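// checkResponse turns a non-200 status code or an error payload returned by the API into a Go error.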
func (v *palm) checkResponse(statusCode int, palmApiError *palmApiError) error {
if statusCode != 200 || palmApiError != nil {
if palmApiError != nil {
return fmt.Errorf("connection to Google PaLM failed with status: %v error: %v",
statusCode, palmApiError.Message)
}
return fmt.Errorf("connection to Google PaLM failed with status: %d", statusCode)
}
return nil
}
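// useGenerativeAIEndpoint reports whether the configured API endpoint is the public
// Generative Language API (generativelanguage.googleapis.com).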
func (v *palm) useGenerativeAIEndpoint(apiEndpoint string) bool {
return apiEndpoint == "generativelanguage.googleapis.com"
}
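// getPayload builds the request body matching the target endpoint: a generateContent
// request with default safety settings for gemini models, a generateMessage request for
// other Generative Language models, and a :predict request otherwise. Temperature, topP,
// and topK come from the class settings; the :predict payload additionally sets
// maxOutputTokens from the configured token limit.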
func (v *palm) getPayload(useGenerativeAI bool, prompt string, settings config.ClassSettings) any {
if useGenerativeAI {
if strings.HasPrefix(settings.ModelID(), "gemini") {
input := generateContentRequest{
Contents: []content{
{
Role: "user",
Parts: []part{
{
Text: prompt,
},
},
},
},
GenerationConfig: &generationConfig{
Temperature: settings.Temperature(),
TopP: settings.TopP(),
TopK: settings.TopK(),
CandidateCount: 1,
},
SafetySettings: []safetySetting{
{
Category: HarmCategoryHarassment,
Threshold: BlockMediumAndAbove,
},
{
Category: HarmCategoryHate_speech,
Threshold: BlockMediumAndAbove,
},
{
Category: HarmCategorySexually_explicit,
Threshold: BlockMediumAndAbove,
},
{
Category: HarmCategoryDangerous_content,
Threshold: BlockMediumAndAbove,
},
},
}
return input
}
input := generateMessageRequest{
Prompt: &generateMessagePrompt{
Messages: []generateMessage{
{
Content: prompt,
},
},
},
Temperature: settings.Temperature(),
TopP: settings.TopP(),
TopK: settings.TopK(),
CandidateCount: 1,
}
return input
}
input := generateInput{
Instances: []instance{
{
Messages: []message{
{
Content: prompt,
},
},
},
},
Parameters: parameters{
Temperature: settings.Temperature(),
MaxOutputTokens: settings.TokenLimit(),
TopP: settings.TopP(),
TopK: settings.TopK(),
},
}
return input
}
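// generatePromptForTask prefixes the JSON-encoded text properties with the task
// description to form a single prompt.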
func (v *palm) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
marshal, err := json.Marshal(textProperties)
if err != nil {
return "", err
}
return fmt.Sprintf(`'%v:
%v`, task, string(marshal)), nil
}
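// generateForPrompt replaces each {property} placeholder in the prompt with its value from
// textProperties and returns an error when a referenced property is missing or empty.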
func (v *palm) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
all := compile.FindAll([]byte(prompt), -1)
for _, match := range all {
originalProperty := string(match)
replacedProperty := compile.FindStringSubmatch(originalProperty)[1]
replacedProperty = strings.TrimSpace(replacedProperty)
value := textProperties[replacedProperty]
if value == "" {
return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty)
}
prompt = strings.ReplaceAll(prompt, originalProperty, value)
}
return prompt, nil
}
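// getApiKey prefers a per-request key supplied via the X-Palm-Api-Key header (HTTP or
// gRPC metadata) and falls back to the key the client was constructed with.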
func (v *palm) getApiKey(ctx context.Context) (string, error) {
if apiKeyValue := v.getValueFromContext(ctx, "X-Palm-Api-Key"); apiKeyValue != "" {
return apiKeyValue, nil
}
if len(v.apiKey) > 0 {
return v.apiKey, nil
}
return "", errors.New("no api key found " +
"neither in request header: X-Palm-Api-Key " +
"nor in environment variable under PALM_APIKEY")
}
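// getValueFromContext looks a header value up in the request context first and in gRPC
// metadata second, returning an empty string when neither is set.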
func (v *palm) getValueFromContext(ctx context.Context, key string) string {
if value := ctx.Value(key); value != nil {
if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
return keyHeader[0]
}
}
// try getting header from GRPC if not successful
if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
return apiKey[0]
}
return ""
}
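// Request and response types for the project-scoped :predict endpoint.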
type generateInput struct {
Instances []instance `json:"instances,omitempty"`
Parameters parameters `json:"parameters"`
}
type instance struct {
Context string `json:"context,omitempty"`
Messages []message `json:"messages,omitempty"`
Examples []example `json:"examples,omitempty"`
}
type message struct {
Author string `json:"author"`
Content string `json:"content"`
}
type example struct {
Input string `json:"input"`
Output string `json:"output"`
}
type parameters struct {
Temperature float64 `json:"temperature"`
MaxOutputTokens int `json:"maxOutputTokens"`
TopP float64 `json:"topP"`
TopK int `json:"topK"`
}
type generateResponse struct {
Predictions []prediction `json:"predictions,omitempty"`
Error *palmApiError `json:"error,omitempty"`
DeployedModelId string `json:"deployedModelId,omitempty"`
Model string `json:"model,omitempty"`
ModelDisplayName string `json:"modelDisplayName,omitempty"`
ModelVersionId string `json:"modelVersionId,omitempty"`
}
type prediction struct {
Candidates []candidate `json:"candidates,omitempty"`
SafetyAttributes *[]safetyAttributes `json:"safetyAttributes,omitempty"`
}
type candidate struct {
Author string `json:"author"`
Content string `json:"content"`
}
type safetyAttributes struct {
Scores []float64 `json:"scores,omitempty"`
Blocked *bool `json:"blocked,omitempty"`
Categories []string `json:"categories,omitempty"`
}
type palmApiError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
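// Request and response types for the Generative Language generateMessage endpoint.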
type generateMessageRequest struct {
Prompt *generateMessagePrompt `json:"prompt,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"` // default 1
TopP float64 `json:"topP"`
TopK int `json:"topK"`
}
type generateMessagePrompt struct {
Context string `json:"context,omitempty"`
Examples []generateExample `json:"examples,omitempty"`
Messages []generateMessage `json:"messages,omitempty"`
}
type generateMessage struct {
Author string `json:"author,omitempty"`
Content string `json:"content,omitempty"`
CitationMetadata *generateCitationMetadata `json:"citationMetadata,omitempty"`
}
type generateCitationMetadata struct {
CitationSources []generateCitationSource `json:"citationSources,omitempty"`
}
type generateCitationSource struct {
StartIndex int `json:"startIndex,omitempty"`
EndIndex int `json:"endIndex,omitempty"`
URI string `json:"uri,omitempty"`
License string `json:"license,omitempty"`
}
type generateExample struct {
Input *generateMessage `json:"input,omitempty"`
Output *generateMessage `json:"output,omitempty"`
}
type generateMessageResponse struct {
Candidates []generateMessage `json:"candidates,omitempty"`
Messages []generateMessage `json:"messages,omitempty"`
Filters []contentFilter `json:"filters,omitempty"`
Error *palmApiError `json:"error,omitempty"`
}
type contentFilter struct {
Reason string `json:"reason,omitempty"`
Message string `json:"message,omitempty"`
}
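// Request and response types for the Generative Language generateContent endpoint used by gemini models.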
type generateContentRequest struct {
Contents []content `json:"contents,omitempty"`
SafetySettings []safetySetting `json:"safetySettings,omitempty"`
GenerationConfig *generationConfig `json:"generationConfig,omitempty"`
}
type content struct {
Parts []part `json:"parts,omitempty"`
Role string `json:"role,omitempty"`
}
type part struct {
Text string `json:"text,omitempty"`
InlineData string `json:"inline_data,omitempty"`
}
type safetySetting struct {
Category harmCategory `json:"category,omitempty"`
Threshold harmBlockThreshold `json:"threshold,omitempty"`
}
type generationConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}
type generateContentResponse struct {
Candidates []generateContentCandidate `json:"candidates,omitempty"`
PromptFeedback *promptFeedback `json:"promptFeedback,omitempty"`
Error *palmApiError `json:"error,omitempty"`
}
type generateContentCandidate struct {
Content contentResponse `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"`
SafetyRatings []safetyRating `json:"safetyRatings,omitempty"`
}
type contentResponse struct {
Parts []part `json:"parts,omitempty"`
Role string `json:"role,omitempty"`
}
type promptFeedback struct {
SafetyRatings []safetyRating `json:"safetyRatings,omitempty"`
}
type safetyRating struct {
Category harmCategory `json:"category,omitempty"`
Probability harmProbability `json:"probability,omitempty"`
Blocked *bool `json:"blocked,omitempty"`
}