de / main.go
caidaohz's picture
fix: 修复模型不存在时的返回值,添加错误处理逻辑
04cfbbc
package main
import (
	"bufio"
	"bytes"
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/joho/godotenv"
)
// Timing and networking constants.
const (
	// BadKeyRetryInterval is how long a key marked bad stays disabled before
	// KeyManager.Get may try it again.
	BadKeyRetryInterval = 10 * time.Minute
	// SessionTimeout is the idle period after which the current key/session
	// pair is abandoned and a new one is chosen.
	SessionTimeout = 10 * time.Minute
	// DefaultPort is used when the PORT environment variable is not set.
	DefaultPort = 7860
)

// Package-level configuration. privateKey and ondemandAPIKeys are filled in
// by initConfig; the rest are fixed values.
var (
	privateKey      string   // secret clients must present to use the proxy
	ondemandAPIKeys []string // upstream OnDemand API keys
	// safeHeaders lists, in lookup order, the headers checked for the client secret.
	safeHeaders     = []string{"Authorization", "X-API-KEY"}
	ondemandAPIBase = "https://api.on-demand.io/chat/v1"
	// NOTE(review): defaultModel is not referenced anywhere else in this file.
	defaultModel = "predefined-openai-gpt4o"
)
// modelMap translates the OpenAI-style model names accepted by this proxy
// into OnDemand endpoint IDs. Keys are lowercase with no spaces because
// getEndpointID normalizes the requested name that way before the lookup.
var modelMap = map[string]string{
	// OpenAI
	"gpt-4o":       "predefined-openai-gpt4o",
	"gpt-4o-mini":  "predefined-openai-gpt4o-mini",
	"gpt-4.1":      "predefined-openai-gpt4.1",
	"gpt-4.1-mini": "predefined-openai-gpt4.1-mini",
	"o3-mini":      "predefined-openai-gpto3-mini",
	"o4-mini":      "predefined-openai-gpto4-mini",

	// DeepSeek
	"deepseek-v3": "predefined-deepseek-v3",
	"deepseek-r1": "predefined-deepseek-r1",

	// Anthropic
	"claude-4-sonnet": "predefined-claude-4-sonnet",
	"claude-4-opus":   "predefined-claude-4-opus",
}
// KeyStatus records whether an API key is currently disabled.
type KeyStatus struct {
Bad bool `json:"bad"` // set by MarkBad; cleared by Get after BadKeyRetryInterval, or when every key is bad
BadTS time.Time `json:"bad_ts"` // when MarkBad disabled the key; zero while the key is healthy
}
// KeyManager rotates through the configured OnDemand API keys, tracks which
// ones are disabled, and holds the single key/session pair that consecutive
// requests share. Get, MarkBad and GetSession hold mu for their whole body.
type KeyManager struct {
keyList []string // private copy of the configured keys; not modified after NewKeyManager (withValidKey reads its length without locking)
mu sync.RWMutex // guards keyStatus, idx, currentKey, currentSession, lastUsedTime; only the exclusive Lock is used in this file
keyStatus map[string]*KeyStatus // status for each key in keyList
idx int // next position in keyList to try during round-robin selection
currentKey string // key used for the ongoing conversation; "" when none is selected
currentSession string // OnDemand session ID cached by GetSession; cleared whenever currentKey changes
lastUsedTime time.Time // last time Get/GetSession handed out the key/session; compared against SessionTimeout
}
// NewKeyManager returns a KeyManager that owns a private copy of keys, with
// every key starting out healthy (not marked bad).
func NewKeyManager(keys []string) *KeyManager {
	statuses := make(map[string]*KeyStatus)
	for _, k := range keys {
		statuses[k] = &KeyStatus{}
	}

	// Copy so later changes to the caller's slice cannot affect rotation.
	owned := make([]string, len(keys))
	copy(owned, keys)

	return &KeyManager{
		keyList:   owned,
		keyStatus: statuses,
	}
}
// displayKey renders key for log output. Keys longer than 10 bytes are shown
// as their first 6 and last 4 bytes joined by "..."; shorter keys are returned
// unchanged.
// NOTE(review): short keys are therefore logged in full.
func (km *KeyManager) displayKey(key string) string {
	if n := len(key); n > 10 {
		return key[:6] + "..." + key[n-4:]
	}
	return key
}
// Get returns the API key to use for the next request and records the time of
// use. Selection happens in this order:
//
//  1. If the current key has been idle longer than SessionTimeout, it is
//     dropped together with its session so a new conversation starts.
//  2. A current key that is not marked bad keeps being used.
//  3. Otherwise keys are tried round-robin starting at idx. The first healthy
//     key wins; a bad key whose BadKeyRetryInterval has elapsed is revived
//     and used.
//  4. If every key is still bad, all keys are reset to healthy and the first
//     key is forced.
//
// Whenever the key changes, currentSession is cleared so GetSession creates a
// fresh session for it. If the manager holds no keys, Get logs a warning and
// returns "" instead of panicking on keyList[0].
func (km *KeyManager) Get() string {
	km.mu.Lock()
	defer km.mu.Unlock()

	total := len(km.keyList)
	if total == 0 {
		log.Printf("【警告】KEY池为空,没有可用的API KEY")
		return ""
	}
	now := time.Now()

	// 1. Abandon a key/session pair that has been idle for too long.
	if km.currentKey != "" && !km.lastUsedTime.IsZero() &&
		now.Sub(km.lastUsedTime) > SessionTimeout {
		log.Printf("【对话超时】上次使用时间: %s", km.lastUsedTime.Format("2006-01-02 15:04:05"))
		log.Printf("【对话超时】当前时间: %s", now.Format("2006-01-02 15:04:05"))
		log.Printf("【对话超时】超时%d分钟,切换新会话", int(SessionTimeout.Minutes()))
		km.currentKey = ""
		km.currentSession = ""
	}

	// 2. Stick with the current key while it is healthy.
	if km.currentKey != "" {
		if !km.keyStatus[km.currentKey].Bad {
			log.Printf("【对话请求】【继续使用API KEY: %s】【状态:正常】", km.displayKey(km.currentKey))
			km.lastUsedTime = now
			return km.currentKey
		}
		// The current key was marked bad; fall through and pick another one.
		km.currentKey = ""
		km.currentSession = ""
	}

	// 3. Round-robin over the pool, starting at idx.
	for i := 0; i < total; i++ {
		key := km.keyList[km.idx]
		km.idx = (km.idx + 1) % total
		status := km.keyStatus[key]

		switch {
		case !status.Bad:
			log.Printf("【对话请求】【使用新API KEY: %s】【状态:正常】", km.displayKey(key))
		case !status.BadTS.IsZero() && now.Sub(status.BadTS) >= BadKeyRetryInterval:
			log.Printf("【KEY自动尝试恢复】API KEY: %s 满足重试周期,标记为正常", km.displayKey(key))
			status.Bad = false
			status.BadTS = time.Time{}
		default:
			continue // still inside its retry interval
		}

		km.currentKey = key
		km.currentSession = ""
		km.lastUsedTime = now
		return key
	}

	// 4. Every key is disabled: reset them all and force the first one.
	log.Printf("【警告】全部KEY已被禁用,强制选用第一个KEY继续尝试: %s", km.displayKey(km.keyList[0]))
	for _, key := range km.keyList {
		km.keyStatus[key].Bad = false
		km.keyStatus[key].BadTS = time.Time{}
	}
	km.idx = 0
	km.currentKey = km.keyList[0]
	km.currentSession = ""
	km.lastUsedTime = now
	log.Printf("【对话请求】【使用API KEY: %s】【状态:强制尝试(全部异常)】", km.displayKey(km.currentKey))
	return km.currentKey
}
// MarkBad disables key; Get will not hand it out again until
// BadKeyRetryInterval has passed (or every key is bad). Unknown keys and keys
// that are already bad are ignored. If key is the one currently in use, the
// current key and session are cleared so the next Get selects another key.
func (km *KeyManager) MarkBad(key string) {
	km.mu.Lock()
	defer km.mu.Unlock()

	status, known := km.keyStatus[key]
	if !known || status.Bad {
		return
	}

	log.Printf("【禁用KEY】API KEY: %s,接口返回无效(将在%d分钟后自动重试)",
		km.displayKey(key), int(BadKeyRetryInterval.Minutes()))
	status.Bad, status.BadTS = true, time.Now()

	if key == km.currentKey {
		km.currentKey, km.currentSession = "", ""
	}
}
// GetSession returns the OnDemand session ID to use with apikey. It creates
// one through createSession when none is cached.
//
// The cached session (currentSession) is only reused or stored for the key
// that Get handed out as currentKey. Another request can rotate the key
// between the caller's Get and this call (MarkBad, or a session timeout). In
// that case a one-off session is created for apikey and NOT cached. Caching
// it would record it as the session of the new currentKey, pairing that key
// with a session owned by a different key on later requests.
//
// NOTE(review): mu is held across createSession's network call (up to its
// 20s client timeout), which blocks Get/MarkBad callers meanwhile.
func (km *KeyManager) GetSession(ctx context.Context, apikey string) (string, error) {
	km.mu.Lock()
	defer km.mu.Unlock()

	isCurrent := apikey == km.currentKey
	if isCurrent && km.currentSession != "" {
		km.lastUsedTime = time.Now()
		return km.currentSession, nil
	}

	session, err := createSession(ctx, apikey, "", nil)
	if err != nil {
		log.Printf("【创建会话失败】错误: %v", err)
		return "", err
	}
	log.Printf("【创建新会话】SESSION ID: %s", session)

	if isCurrent {
		km.currentSession = session
		km.lastUsedTime = time.Now()
	}
	return session, nil
}
// keyManager is the process-wide key pool. initConfig assigns it only when at
// least one API key is configured, so it is nil in test mode without keys.
var keyManager *KeyManager
// OpenAI-compatible request/response types exchanged with clients.
// ChatCompletionRequest is the subset of an OpenAI chat completion request
// that this proxy reads.
type ChatCompletionRequest struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
Stream bool `json:"stream"`
}
// Message is one chat message (also used as a stream delta).
// NOTE(review): Content is a plain string, so requests that send content as
// an array of parts fail to bind.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
// ChatCompletionResponse is used for both full responses (Object
// "chat.completion") and streamed chunks (Object "chat.completion.chunk").
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
}
// Choice carries Message in full responses and Delta in stream chunks.
// FinishReason has no omitempty, so a nil value is serialized as null.
type Choice struct {
Index int `json:"index"`
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
FinishReason *string `json:"finish_reason"`
}
// Usage has no fields and always serializes as {}; token counts are not reported.
type Usage struct{}
// ModelsResponse is the body returned by GET /v1/models.
type ModelsResponse struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
// Model is one entry in the model list.
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
}
// OnDemand upstream API types.
// CreateSessionRequest is the body of POST {ondemandAPIBase}/sessions.
type CreateSessionRequest struct {
ExternalUserID string `json:"externalUserId"`
PluginIds []string `json:"pluginIds,omitempty"`
}
// CreateSessionResponse extracts the new session's ID from the create-session response.
type CreateSessionResponse struct {
Data struct {
ID string `json:"id"`
} `json:"data"`
}
// QueryRequest is the body of POST {ondemandAPIBase}/sessions/{id}/query;
// this file sends ResponseMode "stream" or "sync".
type QueryRequest struct {
Query string `json:"query"`
EndpointID string `json:"endpointId"`
PluginIds []string `json:"pluginIds"`
ResponseMode string `json:"responseMode"`
}
// QueryResponse extracts the answer from a "sync" query response.
type QueryResponse struct {
Data struct {
Answer string `json:"answer"`
} `json:"data"`
}
// init loads variables from a .env file when one can be loaded (otherwise only
// the system environment is used) and then builds the global configuration.
// NOTE(review): godotenv.Load also fails on a malformed .env, which the
// warning below reports as "not found".
func init() {
	if loadErr := godotenv.Load(); loadErr != nil {
		log.Println("警告:没有找到 .env 文件,将仅使用系统环境变量")
	}
	initConfig()
}
// initConfig populates privateKey, ondemandAPIKeys and keyManager from the
// environment.
//
// PRIVATE_KEY is the secret clients must send (checked by checkPrivateKey).
// ONDEMAND_APIKEYS is a comma-separated list of upstream keys. Entries are
// whitespace-trimmed and empty entries are dropped, so "k1, k2," yields
// ["k1", "k2"] instead of ["k1", " k2", ""]. Those malformed keys could never
// authenticate and would only be rotated through and marked bad. Outside
// test mode an empty key list is fatal; in test mode keyManager is left unset.
func initConfig() {
	// SECURITY: the fallback secret is visible in source; real deployments
	// must set PRIVATE_KEY.
	privateKey = getEnv("PRIVATE_KEY", "testofli")
	if os.Getenv("PRIVATE_KEY") == "" {
		log.Println("警告:未设置 PRIVATE_KEY 环境变量,正在使用默认访问密钥")
	}

	var keys []string
	for _, k := range strings.Split(os.Getenv("ONDEMAND_APIKEYS"), ",") {
		if k = strings.TrimSpace(k); k != "" {
			keys = append(keys, k)
		}
	}
	if len(keys) > 0 {
		ondemandAPIKeys = keys
	}

	if len(ondemandAPIKeys) == 0 && !isTestMode() {
		log.Fatal("ONDEMAND_APIKEYS 环境变量为空,请设置API密钥")
	}
	if len(ondemandAPIKeys) > 0 {
		keyManager = NewKeyManager(ondemandAPIKeys)
	}
}
// isTestMode reports whether the process runs under `go test` or with
// GIN_MODE=test. In test mode initConfig tolerates missing API keys.
//
// Binaries built by `go test` are named "<pkg>.test" ("<pkg>.test.exe" on
// Windows) and receive "-test.*" flags. The previous check matched "test"
// anywhere in any argument. An unrelated path such as /home/tester/app or a
// flag like --env=latest therefore enabled test mode, skipped the fatal
// missing-key check, and left keyManager nil.
func isTestMode() bool {
	for i, arg := range os.Args {
		if i == 0 && (strings.HasSuffix(arg, ".test") || strings.HasSuffix(arg, ".test.exe")) {
			return true
		}
		if strings.HasPrefix(arg, "-test.") {
			return true
		}
	}
	return os.Getenv("GIN_MODE") == "test"
}
// getEnv returns the value of environment variable key, or defaultValue when
// the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	if v, ok := os.LookupEnv(key); ok && v != "" {
		return v
	}
	return defaultValue
}
// checkPrivateKey returns middleware that requires the client to present
// privateKey. "/" (health check) and "/favicon.ico" are exempt.
//
// The secret is taken from the first non-empty header in safeHeaders. For
// "Authorization: Bearer <key>" the prefix is stripped; any other
// Authorization value is used verbatim. On mismatch the request is aborted
// with 401.
//
// The comparison is constant-time so response timing does not leak how much
// of a guessed key matched. The 401 body no longer echoes the request headers,
// which exposed the submitted credential and any cookies.
func checkPrivateKey() gin.HandlerFunc {
	return func(c *gin.Context) {
		path := c.Request.URL.Path
		if path == "/" || path == "/favicon.ico" {
			c.Next()
			return
		}

		var key string
		for _, header := range safeHeaders {
			value := c.GetHeader(header)
			if value == "" {
				continue
			}
			key = value
			if header == "Authorization" && strings.HasPrefix(value, "Bearer ") {
				key = strings.TrimSpace(strings.TrimPrefix(value, "Bearer "))
			}
			break
		}

		if key == "" || subtle.ConstantTimeCompare([]byte(key), []byte(privateKey)) != 1 {
			c.JSON(http.StatusUnauthorized, gin.H{
				"error": "Unauthorized, must provide correct Authorization or X-API-KEY",
			})
			c.Abort()
			return
		}
		c.Next()
	}
}
// getEndpointID maps an OpenAI-style model name to its OnDemand endpoint ID.
// The name is stripped of spaces and lowercased before the modelMap lookup;
// unknown models yield "".
func getEndpointID(openaiModel string) string {
	normalized := strings.ToLower(strings.ReplaceAll(openaiModel, " ", ""))
	return modelMap[normalized] // a missing key yields the zero value ""
}
// createSession opens a new OnDemand chat session authenticated with apikey
// and returns its ID. A random UUID is used when externalUserID is empty;
// pluginIds is omitted from the payload when empty (omitempty tag).
//
// Any status other than 200/201 yields
// "create session failed with status: <code>". isAuthError inspects that
// form to decide whether to disable the key. A 2xx response without a session
// ID is also an error; returning "" would make callers query
// ".../sessions//query".
func createSession(ctx context.Context, apikey, externalUserID string, pluginIds []string) (string, error) {
	if externalUserID == "" {
		externalUserID = uuid.New().String()
	}
	jsonData, err := json.Marshal(CreateSessionRequest{
		ExternalUserID: externalUserID,
		PluginIds:      pluginIds,
	})
	if err != nil {
		return "", err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, ondemandAPIBase+"/sessions", bytes.NewReader(jsonData))
	if err != nil {
		return "", err
	}
	req.Header.Set("apikey", apikey)
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 20 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		return "", fmt.Errorf("create session failed with status: %d", resp.StatusCode)
	}

	var sessionResp CreateSessionResponse
	if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil {
		return "", fmt.Errorf("decoding create-session response: %w", err)
	}
	if sessionResp.Data.ID == "" {
		return "", fmt.Errorf("create session returned an empty session id")
	}
	return sessionResp.Data.ID, nil
}
// withValidKey calls fn with keys from keyManager. When fn fails with an
// error isAuthError classifies as key-related, that key is marked bad and fn
// is retried with another key. Any other error is returned immediately. At
// most 2×len(keys) key-related failures are tolerated before giving up.
//
// It fails fast when there is no key manager: in test mode initConfig
// tolerates missing keys and leaves keyManager nil, which previously caused a
// nil-pointer panic here. It also stops retrying once ctx is cancelled.
func withValidKey(ctx context.Context, fn func(ctx context.Context, key string) (any, error)) (any, error) {
	if keyManager == nil || len(keyManager.keyList) == 0 {
		return nil, fmt.Errorf("没有可用API KEY,请补充新KEY或联系技术支持")
	}

	maxRetry := len(keyManager.keyList) * 2
	for badCount := 0; badCount < maxRetry; badCount++ {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		key := keyManager.Get()
		result, err := fn(ctx, key)
		if err == nil {
			return result, nil
		}
		if !isAuthError(err) {
			return nil, err
		}
		keyManager.MarkBad(key)
	}
	return nil, fmt.Errorf("没有可用API KEY,请补充新KEY或联系技术支持")
}
// isAuthError reports whether err is an upstream HTTP failure that should
// disable the key that produced it: status 401, 403, 429 or 500.
//
// The upstream helpers report HTTP failures as "... failed with status: <code>",
// so only the number right after the last "status: " marker is inspected.
// Searching the whole message for "401"/"403"/"429"/"500" also matched
// transport errors whose text merely contains those digits (for example a
// request URL with a session ID like ".../sessions/65004a.../query"). That
// disabled healthy keys.
func isAuthError(err error) bool {
	if err == nil {
		return false
	}
	const marker = "status: "
	msg := err.Error()
	i := strings.LastIndex(msg, marker)
	if i < 0 {
		return false
	}

	rest := msg[i+len(marker):]
	end := 0
	for end < len(rest) && rest[end] >= '0' && rest[end] <= '9' {
		end++
	}
	code, convErr := strconv.Atoi(rest[:end])
	if convErr != nil {
		return false
	}

	switch code {
	case http.StatusUnauthorized, http.StatusForbidden, http.StatusTooManyRequests, http.StatusInternalServerError:
		return true
	default:
		return false
	}
}
// chatCompletions serves POST /v1/chat/completions.
//
// Only the content of the LAST message with role "user" is forwarded
// upstream; other messages are not sent (conversation context presumably
// lives in the cached OnDemand session — confirm). The model name is mapped
// with getEndpointID, and unknown models get an OpenAI-style 400
// "model_not_found" error. Depending on "stream", the reply is sent as SSE or
// as a single JSON completion.
func chatCompletions(c *gin.Context) {
	var req ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		// Surface the real decode problem (malformed JSON, or e.g. content sent
		// as an array of parts, which Message.Content cannot hold) instead of
		// claiming the messages field is missing.
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("请求体解析失败: %v", err)})
		return
	}
	if len(req.Messages) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求缺少messages字段"})
		return
	}

	// Take the most recent user message.
	var userMsg string
	for i := len(req.Messages) - 1; i >= 0; i-- {
		if req.Messages[i].Role == "user" {
			userMsg = req.Messages[i].Content
			break
		}
	}
	if userMsg == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "未找到用户消息"})
		return
	}

	endpointID := getEndpointID(req.Model)
	if endpointID == "" {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": map[string]any{
				"message": fmt.Sprintf("The model '%s' does not exist", req.Model),
				"type":    "invalid_request_error",
				"param":   "model",
				"code":    "model_not_found",
			},
		})
		return
	}

	log.Printf("【模型请求】模型: %s, 端点: %s, 流式: %t", req.Model, endpointID, req.Stream)
	if req.Stream {
		handleStreamResponse(c, userMsg, endpointID, req.Model)
		return
	}
	handleNonStreamResponse(c, userMsg, endpointID, req.Model)
}
// handleStreamResponse relays an upstream streaming query to the client as
// Server-Sent Events ("data: <payload>\n\n" frames).
//
// A producer goroutine runs streamQuery (with key rotation via withValidKey)
// and pushes encoded chunks into resultChan. This function is the only
// consumer and writes them out in order. The producer closes errorChan and
// then resultChan, so once resultChan is closed the producer's error (if any)
// is already settled in errorChan.
//
//   - resultChan is read until it is closed before the error is inspected.
//     Previously a select raced the closed errorChan against still-buffered
//     chunks and could return early, dropping the tail of the answer
//     (including the final "[DONE]").
//   - The upstream work is bound to the client's request context and
//     cancelled when this handler returns. resultChan is then drained in the
//     background, so the producer cannot block forever on a full channel after
//     the client disconnects.
func handleStreamResponse(c *gin.Context, userMsg, endpointID, model string) {
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	ctx, cancel := context.WithCancel(c.Request.Context())
	resultChan := make(chan string, 100)
	errorChan := make(chan error, 1) // buffered: the producer never blocks on its final error

	go func() {
		defer close(resultChan) // runs last (LIFO): errorChan is settled before this
		defer close(errorChan)
		_, err := withValidKey(ctx, func(ctx context.Context, apikey string) (any, error) {
			return streamQuery(ctx, apikey, userMsg, endpointID, model, resultChan)
		})
		if err != nil {
			errorChan <- err
		}
	}()

	defer func() {
		cancel()
		// Keep the producer from blocking on sends while it winds down.
		go func() {
			for range resultChan {
			}
		}()
	}()

	for {
		select {
		case <-ctx.Done():
			return // client went away
		case chunk, ok := <-resultChan:
			if !ok {
				if err := <-errorChan; err != nil {
					// Marshalling a map of strings cannot fail.
					errorJSON, _ := json.Marshal(map[string]any{"error": err.Error()})
					_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(errorJSON))
					c.Writer.Flush()
				}
				return
			}
			if chunk == "data: [DONE]" {
				_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
				c.Writer.Flush()
				return
			}
			_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", chunk)
			c.Writer.Flush()
		}
	}
}
// streamQuery sends userMsg to the OnDemand endpoint in "stream" mode and
// forwards the reply on resultChan as OpenAI-style chat.completion.chunk JSON
// strings.
//   - The literal "data: [DONE]" is sent when the upstream signals completion.
//   - An upstream "[ERROR]:" event is forwarded as {"error":"..."} and ends the
//     stream.
//   - On success it returns (nil, nil).
//
// Errors:
//   - A non-200 upstream status yields "stream query failed with status: <code>"
//     (inspected by isAuthError).
//   - A cancelled ctx while delivering a chunk yields ctx.Err().
//   - Read failures are returned as-is.
func streamQuery(ctx context.Context, apikey, userMsg, endpointID, model string, resultChan chan<- string) (any, error) {
	sessionID, err := keyManager.GetSession(ctx, apikey)
	if err != nil {
		return nil, err
	}

	jsonData, err := json.Marshal(QueryRequest{
		Query:        userMsg,
		EndpointID:   endpointID,
		PluginIds:    []string{},
		ResponseMode: "stream",
	})
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
		fmt.Sprintf("%s/sessions/%s/query", ondemandAPIBase, sessionID),
		bytes.NewReader(jsonData))
	if err != nil {
		return nil, err
	}
	req.Header.Set("apikey", apikey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")

	// NOTE(review): this timeout bounds the whole stream, not just the connect.
	client := &http.Client{Timeout: 300 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("stream query failed with status: %d", resp.StatusCode)
	}

	// send delivers one item to the consumer and gives up if ctx is cancelled,
	// so this goroutine cannot block forever after the consumer is gone.
	send := func(s string) bool {
		select {
		case resultChan <- s:
			return true
		case <-ctx.Done():
			return false
		}
	}

	// As in OpenAI's streaming format, every chunk of one completion shares an ID.
	chunkID := "chatcmpl-" + uuid.New().String()[:8]
	firstChunk := true

	scanner := bufio.NewScanner(resp.Body)
	// A single SSE line may exceed the 64 KiB default token size, which would
	// abort the stream with bufio.ErrTooLong; allow lines up to 1 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		dataPart := strings.TrimSpace(strings.TrimPrefix(line, "data:"))

		if dataPart == "[DONE]" {
			if !send("data: [DONE]") {
				return nil, ctx.Err()
			}
			break
		}
		if strings.HasPrefix(dataPart, "[ERROR]:") {
			// Encode properly: the upstream error text may contain quotes or
			// JSON, which a Sprintf-built object left unescaped (invalid JSON).
			// Marshalling a map of strings cannot fail.
			payload, _ := json.Marshal(map[string]string{
				"error": strings.TrimSpace(strings.TrimPrefix(dataPart, "[ERROR]:")),
			})
			if !send(string(payload)) {
				return nil, ctx.Err()
			}
			break
		}

		var eventData map[string]any
		if err := json.Unmarshal([]byte(dataPart), &eventData); err != nil {
			continue // ignore payloads that are not JSON objects
		}
		eventType, ok := eventData["eventType"].(string)
		if !ok {
			continue
		}
		content, ok := streamEventContent(eventType, eventData)
		if !ok {
			continue
		}

		role := ""
		if firstChunk {
			role = "assistant" // OpenAI sends the role only on the first delta
		}
		chunkJSON, _ := json.Marshal(ChatCompletionResponse{
			ID:      chunkID,
			Object:  "chat.completion.chunk",
			Created: time.Now().Unix(),
			Model:   model,
			Choices: []Choice{{
				Index:        0,
				Delta:        &Message{Role: role, Content: content},
				FinishReason: nil,
			}},
		})
		if !send(string(chunkJSON)) {
			return nil, ctx.Err()
		}
		firstChunk = false
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return nil, nil
}

// streamEventContent extracts the text carried by one upstream stream event.
// The first string-valued field found is used:
//   - "fulfillment": "answer".
//   - "stream" and the reasoning-style types ("thinking", "reasoning",
//     "thoughts"): "answer", "text", "data", "thoughts".
//   - Any other type: "answer", "text", "thoughts".
//
// ok is false when none of those fields holds a string.
func streamEventContent(eventType string, eventData map[string]any) (content string, ok bool) {
	var fields []string
	switch eventType {
	case "fulfillment":
		fields = []string{"answer"}
	case "stream", "thinking", "reasoning", "thoughts":
		fields = []string{"answer", "text", "data", "thoughts"}
	default:
		fields = []string{"answer", "text", "thoughts"}
	}
	for _, f := range fields {
		if s, isStr := eventData[f].(string); isStr {
			return s, true
		}
	}
	return "", false
}
// handleNonStreamResponse runs a synchronous upstream query (rotating keys via
// withValidKey) under the client's request context. On success it writes the
// OpenAI-style completion as JSON. Any failure is reported as HTTP 500 with
// {"error": "<message>"}.
func handleNonStreamResponse(c *gin.Context, userMsg, endpointID, model string) {
	query := func(ctx context.Context, apikey string) (any, error) {
		return nonStreamQuery(ctx, apikey, userMsg, endpointID, model)
	}

	result, err := withValidKey(c.Request.Context(), query)
	if err == nil {
		c.JSON(http.StatusOK, result)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// nonStreamQuery sends userMsg to the OnDemand endpoint in "sync" mode and
// wraps the answer in an OpenAI-style chat.completion with finish_reason
// "stop". A non-200 upstream status yields
// "non-stream query failed with status: <code>" (inspected by isAuthError).
func nonStreamQuery(ctx context.Context, apikey, userMsg, endpointID, model string) (any, error) {
	sessionID, err := keyManager.GetSession(ctx, apikey)
	if err != nil {
		return nil, err
	}

	body, err := json.Marshal(QueryRequest{
		Query:        userMsg,
		EndpointID:   endpointID,
		PluginIds:    []string{},
		ResponseMode: "sync",
	})
	if err != nil {
		return nil, err
	}

	endpoint := ondemandAPIBase + "/sessions/" + sessionID + "/query"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("apikey", apikey)
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := (&http.Client{Timeout: 300 * time.Second}).Do(httpReq)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("non-stream query failed with status: %d", resp.StatusCode)
	}

	var parsed QueryResponse
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		return nil, err
	}

	finish := "stop"
	return ChatCompletionResponse{
		ID:      "chatcmpl-" + uuid.New().String()[:8],
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   model,
		Choices: []Choice{{
			Index:        0,
			Message:      &Message{Role: "assistant", Content: parsed.Data.Answer},
			FinishReason: &finish,
		}},
		Usage: Usage{},
	}, nil
}
// models serves GET /v1/models. It lists every key of modelMap as an
// OpenAI-style model object, sorted by ID. Go randomizes map iteration order,
// so the listing previously changed between requests. An empty map yields
// "data": [] rather than null.
func models(c *gin.Context) {
	ids := make([]string, 0, len(modelMap))
	for id := range modelMap {
		ids = append(ids, id)
	}
	sort.Strings(ids)

	modelList := make([]Model, 0, len(ids))
	for _, id := range ids {
		modelList = append(modelList, Model{
			ID:      id,
			Object:  "model",
			OwnedBy: "ondemand-proxy",
		})
	}

	c.JSON(http.StatusOK, ModelsResponse{
		Object: "list",
		Data:   modelList,
	})
}
// health serves GET / with {"keys": <configured API key count>, "status": "ok"}.
// checkPrivateKey exempts "/", so this endpoint is reachable without the secret.
// NOTE(review): that means the key-pool size is publicly visible.
func health(c *gin.Context) {
	payload := gin.H{
		"keys":   len(ondemandAPIKeys),
		"status": "ok",
	}
	c.JSON(http.StatusOK, payload)
}
// main configures logging and gin, registers the routes behind the
// private-key middleware, and serves on PORT (DefaultPort when PORT is unset
// or invalid).
//
// NOTE(review): router.Run offers no way to set http.Server timeouts. Consider
// an explicit http.Server with ReadHeaderTimeout; a WriteTimeout would cut
// long SSE streams.
func main() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)

	// Default to release mode unless GIN_MODE was set explicitly.
	if os.Getenv("GIN_MODE") == "" {
		gin.SetMode(gin.ReleaseMode)
	}

	router := gin.New()
	router.Use(gin.Logger(), gin.Recovery(), checkPrivateKey())

	router.GET("/", health)
	router.POST("/v1/chat/completions", chatCompletions)
	router.GET("/v1/models", models)

	port := DefaultPort
	if portStr := os.Getenv("PORT"); portStr != "" {
		p, err := strconv.Atoi(portStr)
		if err != nil || p <= 0 || p > 65535 {
			// Previously an invalid PORT was ignored without any trace.
			log.Printf("警告:PORT=%q 无效,使用默认端口 %d", portStr, DefaultPort)
		} else {
			port = p
		}
	}

	log.Printf("======== OnDemand KEY池数量:%d ========", len(ondemandAPIKeys))
	log.Printf("服务器启动在端口:%d", port)

	if err := router.Run(fmt.Sprintf(":%d", port)); err != nil {
		log.Fatal("启动服务器失败:", err)
	}
}