Spaces:
Configuration error
Configuration error
package openai | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"github.com/gofiber/fiber/v2" | |
"github.com/google/uuid" | |
"github.com/mudler/LocalAI/core/config" | |
fiberContext "github.com/mudler/LocalAI/core/http/ctx" | |
"github.com/mudler/LocalAI/core/schema" | |
"github.com/mudler/LocalAI/pkg/functions" | |
"github.com/mudler/LocalAI/pkg/model" | |
"github.com/mudler/LocalAI/pkg/templates" | |
"github.com/mudler/LocalAI/pkg/utils" | |
"github.com/rs/zerolog/log" | |
) | |
// correlationIDKeyType is an unexported type used for context keys, so that
// CorrelationIDKey can never compare equal to a context key defined in
// another package.
type correlationIDKeyType string

// CorrelationIDKey is the context.Context key under which a request's
// correlation ID is stored, so the request can be tracked across process
// boundaries.
const CorrelationIDKey = correlationIDKeyType("correlationID")
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { | |
input := new(schema.OpenAIRequest) | |
// Get input data from the request body | |
if err := c.BodyParser(input); err != nil { | |
return "", nil, fmt.Errorf("failed parsing request body: %w", err) | |
} | |
received, _ := json.Marshal(input) | |
// Extract or generate the correlation ID | |
correlationID := c.Get("X-Correlation-ID", uuid.New().String()) | |
ctx, cancel := context.WithCancel(o.Context) | |
// Add the correlation ID to the new context | |
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID) | |
input.Context = ctxWithCorrelationID | |
input.Cancel = cancel | |
log.Debug().Msgf("Request received: %s", string(received)) | |
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel) | |
return modelFile, input, err | |
} | |
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { | |
if input.Echo { | |
config.Echo = input.Echo | |
} | |
if input.TopK != nil { | |
config.TopK = input.TopK | |
} | |
if input.TopP != nil { | |
config.TopP = input.TopP | |
} | |
if input.Backend != "" { | |
config.Backend = input.Backend | |
} | |
if input.ClipSkip != 0 { | |
config.Diffusers.ClipSkip = input.ClipSkip | |
} | |
if input.ModelBaseName != "" { | |
config.AutoGPTQ.ModelBaseName = input.ModelBaseName | |
} | |
if input.NegativePromptScale != 0 { | |
config.NegativePromptScale = input.NegativePromptScale | |
} | |
if input.UseFastTokenizer { | |
config.UseFastTokenizer = input.UseFastTokenizer | |
} | |
if input.NegativePrompt != "" { | |
config.NegativePrompt = input.NegativePrompt | |
} | |
if input.RopeFreqBase != 0 { | |
config.RopeFreqBase = input.RopeFreqBase | |
} | |
if input.RopeFreqScale != 0 { | |
config.RopeFreqScale = input.RopeFreqScale | |
} | |
if input.Grammar != "" { | |
config.Grammar = input.Grammar | |
} | |
if input.Temperature != nil { | |
config.Temperature = input.Temperature | |
} | |
if input.Maxtokens != nil { | |
config.Maxtokens = input.Maxtokens | |
} | |
if input.ResponseFormat != nil { | |
switch responseFormat := input.ResponseFormat.(type) { | |
case string: | |
config.ResponseFormat = responseFormat | |
case map[string]interface{}: | |
config.ResponseFormatMap = responseFormat | |
} | |
} | |
switch stop := input.Stop.(type) { | |
case string: | |
if stop != "" { | |
config.StopWords = append(config.StopWords, stop) | |
} | |
case []interface{}: | |
for _, pp := range stop { | |
if s, ok := pp.(string); ok { | |
config.StopWords = append(config.StopWords, s) | |
} | |
} | |
} | |
if len(input.Tools) > 0 { | |
for _, tool := range input.Tools { | |
input.Functions = append(input.Functions, tool.Function) | |
} | |
} | |
if input.ToolsChoice != nil { | |
var toolChoice functions.Tool | |
switch content := input.ToolsChoice.(type) { | |
case string: | |
_ = json.Unmarshal([]byte(content), &toolChoice) | |
case map[string]interface{}: | |
dat, _ := json.Marshal(content) | |
_ = json.Unmarshal(dat, &toolChoice) | |
} | |
input.FunctionCall = map[string]interface{}{ | |
"name": toolChoice.Function.Name, | |
} | |
} | |
// Decode each request's message content | |
imgIndex, vidIndex, audioIndex := 0, 0, 0 | |
for i, m := range input.Messages { | |
nrOfImgsInMessage := 0 | |
nrOfVideosInMessage := 0 | |
nrOfAudiosInMessage := 0 | |
switch content := m.Content.(type) { | |
case string: | |
input.Messages[i].StringContent = content | |
case []interface{}: | |
dat, _ := json.Marshal(content) | |
c := []schema.Content{} | |
json.Unmarshal(dat, &c) | |
textContent := "" | |
// we will template this at the end | |
CONTENT: | |
for _, pp := range c { | |
switch pp.Type { | |
case "text": | |
textContent += pp.Text | |
//input.Messages[i].StringContent = pp.Text | |
case "video", "video_url": | |
// Decode content as base64 either if it's an URL or base64 text | |
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) | |
if err != nil { | |
log.Error().Msgf("Failed encoding video: %s", err) | |
continue CONTENT | |
} | |
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff | |
vidIndex++ | |
nrOfVideosInMessage++ | |
case "audio_url", "audio": | |
// Decode content as base64 either if it's an URL or base64 text | |
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) | |
if err != nil { | |
log.Error().Msgf("Failed encoding image: %s", err) | |
continue CONTENT | |
} | |
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff | |
audioIndex++ | |
nrOfAudiosInMessage++ | |
case "image_url", "image": | |
// Decode content as base64 either if it's an URL or base64 text | |
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) | |
if err != nil { | |
log.Error().Msgf("Failed encoding image: %s", err) | |
continue CONTENT | |
} | |
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff | |
imgIndex++ | |
nrOfImgsInMessage++ | |
} | |
} | |
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ | |
TotalImages: imgIndex, | |
TotalVideos: vidIndex, | |
TotalAudios: audioIndex, | |
ImagesInMessage: nrOfImgsInMessage, | |
VideosInMessage: nrOfVideosInMessage, | |
AudiosInMessage: nrOfAudiosInMessage, | |
}, textContent) | |
} | |
} | |
if input.RepeatPenalty != 0 { | |
config.RepeatPenalty = input.RepeatPenalty | |
} | |
if input.FrequencyPenalty != 0 { | |
config.FrequencyPenalty = input.FrequencyPenalty | |
} | |
if input.PresencePenalty != 0 { | |
config.PresencePenalty = input.PresencePenalty | |
} | |
if input.Keep != 0 { | |
config.Keep = input.Keep | |
} | |
if input.Batch != 0 { | |
config.Batch = input.Batch | |
} | |
if input.IgnoreEOS { | |
config.IgnoreEOS = input.IgnoreEOS | |
} | |
if input.Seed != nil { | |
config.Seed = input.Seed | |
} | |
if input.TypicalP != nil { | |
config.TypicalP = input.TypicalP | |
} | |
switch inputs := input.Input.(type) { | |
case string: | |
if inputs != "" { | |
config.InputStrings = append(config.InputStrings, inputs) | |
} | |
case []interface{}: | |
for _, pp := range inputs { | |
switch i := pp.(type) { | |
case string: | |
config.InputStrings = append(config.InputStrings, i) | |
case []interface{}: | |
tokens := []int{} | |
for _, ii := range i { | |
tokens = append(tokens, int(ii.(float64))) | |
} | |
config.InputToken = append(config.InputToken, tokens) | |
} | |
} | |
} | |
// Can be either a string or an object | |
switch fnc := input.FunctionCall.(type) { | |
case string: | |
if fnc != "" { | |
config.SetFunctionCallString(fnc) | |
} | |
case map[string]interface{}: | |
var name string | |
n, exists := fnc["name"] | |
if exists { | |
nn, e := n.(string) | |
if e { | |
name = nn | |
} | |
} | |
config.SetFunctionCallNameString(name) | |
} | |
switch p := input.Prompt.(type) { | |
case string: | |
config.PromptStrings = append(config.PromptStrings, p) | |
case []interface{}: | |
for _, pp := range p { | |
if s, ok := pp.(string); ok { | |
config.PromptStrings = append(config.PromptStrings, s) | |
} | |
} | |
} | |
} | |
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) { | |
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath, | |
config.LoadOptionDebug(debug), | |
config.LoadOptionThreads(threads), | |
config.LoadOptionContextSize(ctx), | |
config.LoadOptionF16(f16), | |
config.ModelPath(loader.ModelPath), | |
) | |
// Set the parameters for the language model prediction | |
updateRequestConfig(cfg, input) | |
if !cfg.Validate() { | |
return nil, nil, fmt.Errorf("failed to validate config") | |
} | |
return cfg, input, err | |
} | |