depai / main.go
nbugs's picture
Update main.go
4c4cbf4 verified
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"regexp"
"strings"
"sync"
"syscall"
"time"
"strconv"
"github.com/google/uuid"
"github.com/spf13/viper"
"golang.org/x/net/proxy"
)
// ---------------------- 配置结构 ----------------------
type Config struct {
ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
Channels map[string]Channel `mapstructure:"channels"`
Global GlobalConfig `mapstructure:"global"`
}
type ThinkingService struct {
ID int `mapstructure:"id"`
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
BaseURL string `mapstructure:"base_url"`
APIPath string `mapstructure:"api_path"`
APIKey string `mapstructure:"api_key"`
Timeout int `mapstructure:"timeout"`
Retry int `mapstructure:"retry"`
Weight int `mapstructure:"weight"`
Proxy string `mapstructure:"proxy"`
Mode string `mapstructure:"mode"` // "standard" 或 "full"
ReasoningEffort string `mapstructure:"reasoning_effort"`
ReasoningFormat string `mapstructure:"reasoning_format"`
Temperature *float64 `mapstructure:"temperature"`
ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"` // 配置项:标准模式下遇到 content 时是否立即停止
}
func (s *ThinkingService) GetFullURL() string {
return s.BaseURL + s.APIPath
}
type Channel struct {
Name string `mapstructure:"name"`
BaseURL string `mapstructure:"base_url"`
APIPath string `mapstructure:"api_path"`
Timeout int `mapstructure:"timeout"`
Proxy string `mapstructure:"proxy"`
}
func (c *Channel) GetFullURL() string {
return c.BaseURL + c.APIPath
}
type LogConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
Output string `mapstructure:"output"`
FilePath string `mapstructure:"file_path"`
Debug DebugConfig `mapstructure:"debug"`
}
type DebugConfig struct {
Enabled bool `mapstructure:"enabled"`
PrintRequest bool `mapstructure:"print_request"`
PrintResponse bool `mapstructure:"print_response"`
MaxContentLength int `mapstructure:"max_content_length"`
}
type ProxyConfig struct {
Enabled bool `mapstructure:"enabled"`
Default string `mapstructure:"default"`
AllowInsecure bool `mapstructure:"allow_insecure"`
}
type GlobalConfig struct {
MaxRetries int `mapstructure:"max_retries"`
DefaultTimeout int `mapstructure:"default_timeout"`
ErrorCodes struct {
RetryOn []int `mapstructure:"retry_on"`
} `mapstructure:"error_codes"`
Log LogConfig `mapstructure:"log"`
Server ServerConfig `mapstructure:"server"`
Proxy ProxyConfig `mapstructure:"proxy"`
ConfigPaths []string `mapstructure:"config_paths"`
Thinking ThinkingConfig `mapstructure:"thinking"`
}
type ServerConfig struct {
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
ReadTimeout int `mapstructure:"read_timeout"`
WriteTimeout int `mapstructure:"write_timeout"`
IdleTimeout int `mapstructure:"idle_timeout"`
}
type ThinkingConfig struct {
Enabled bool `mapstructure:"enabled"`
AddToAllRequests bool `mapstructure:"add_to_all_requests"`
Timeout int `mapstructure:"timeout"`
}
// ---------------------- API 相关结构 ----------------------
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream,omitempty"`
APIKey string `json:"-"` // 内部传递,不序列化
}
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent interface{} `json:"reasoning_content,omitempty"`
}
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
}
type Choice struct {
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// ---------------------- 日志工具 ----------------------
type RequestLogger struct {
RequestID string
Model string
StartTime time.Time
logs []string
config *Config
}
func NewRequestLogger(config *Config) *RequestLogger {
return &RequestLogger{
RequestID: uuid.New().String(),
StartTime: time.Now(),
logs: make([]string, 0),
config: config,
}
}
func (l *RequestLogger) Log(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
log.Printf("[RequestID: %s] %s", l.RequestID, msg)
}
func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
if !l.config.Global.Log.Debug.Enabled {
return
}
sanitizedContent := sanitizeJSON(content)
truncatedContent := truncateContent(sanitizedContent, maxLength)
l.Log("%s Content:\n%s", contentType, truncatedContent)
}
func truncateContent(content string, maxLength int) string {
if len(content) <= maxLength {
return content
}
return content[:maxLength] + "... (truncated)"
}
func sanitizeJSON(data interface{}) string {
sanitized, err := json.Marshal(data)
if err != nil {
return "Failed to marshal JSON"
}
content := string(sanitized)
sensitivePattern := `"api_key":\s*"[^"]*"`
content = regexp.MustCompile(sensitivePattern).ReplaceAllString(content, `"api_key":"****"`)
return content
}
func extractRealAPIKey(fullKey string) string {
parts := strings.Split(fullKey, "-")
if len(parts) >= 3 && (parts[0] == "deep" || parts[0] == "openai") {
return strings.Join(parts[2:], "-")
}
return fullKey
}
func extractChannelID(fullKey string) string {
parts := strings.Split(fullKey, "-")
if len(parts) >= 2 && (parts[0] == "deep" || parts[0] == "openai") {
return parts[1]
}
return "1" // 默认渠道
}
func logAPIKey(key string) string {
if len(key) <= 8 {
return "****"
}
return key[:4] + "..." + key[len(key)-4:]
}
// ---------------------- Server 结构 ----------------------
type Server struct {
config *Config
srv *http.Server
}
var (
randMu sync.Mutex
randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
)
func NewServer(config *Config) *Server {
return &Server{
config: config,
}
}
func (s *Server) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
mux.HandleFunc("/health", s.handleHealth)
s.srv = &http.Server{
Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
Handler: mux,
ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
}
log.Printf("Server starting on %s\n", s.srv.Addr)
return s.srv.ListenAndServe()
}
func (s *Server) Shutdown(ctx context.Context) error {
return s.srv.Shutdown(ctx)
}
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
}
func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
logger := NewRequestLogger(s.config)
fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
apiKey := extractRealAPIKey(fullAPIKey)
channelID := extractChannelID(fullAPIKey)
logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
logger.Log("Extracted channel ID: %s", channelID)
logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))
targetChannel, ok := s.config.Channels[channelID]
if !ok {
http.Error(w, "Invalid channel", http.StatusBadRequest)
return
}
if r.URL.Path == "/v1/models" {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
req := &ChatCompletionRequest{APIKey: apiKey}
s.forwardModelsRequest(w, r.Context(), req, targetChannel)
return
}
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
logger.Log("Error reading request body: %v", err)
http.Error(w, "Failed to read request", http.StatusBadRequest)
return
}
r.Body.Close()
r.Body = io.NopCloser(bytes.NewBuffer(body))
if s.config.Global.Log.Debug.PrintRequest {
logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
}
var req ChatCompletionRequest
if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
req.APIKey = apiKey
thinkingService := s.getWeightedRandomThinkingService()
logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))
if req.Stream {
handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
if err != nil {
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
return
}
if err := handler.HandleRequest(r.Context(), &req); err != nil {
logger.Log("Stream handler error: %v", err)
}
} else {
thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
if err != nil {
logger.Log("Error processing thinking content: %v", err)
http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
return
}
enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
}
}
func (s *Server) getWeightedRandomThinkingService() ThinkingService {
thinkingServices := s.config.ThinkingServices
if len(thinkingServices) == 0 {
return ThinkingService{}
}
totalWeight := 0
for _, svc := range thinkingServices {
totalWeight += svc.Weight
}
if totalWeight <= 0 {
log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
return thinkingServices[0]
}
randMu.Lock()
randNum := randGen.Intn(totalWeight)
randMu.Unlock()
currentSum := 0
for _, svc := range thinkingServices {
currentSum += svc.Weight
if randNum < currentSum {
return svc
}
}
return thinkingServices[0]
}
// ---------------------- 非流式处理思考服务 ----------------------
type ThinkingResponse struct {
Content string
ReasoningContent string
}
func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
logger := NewRequestLogger(s.config)
log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)
thinkingReq := *req
thinkingReq.Model = svc.Model
thinkingReq.APIKey = svc.APIKey
var systemPrompt string
if svc.Mode == "full" {
systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
} else {
systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
}
thinkingReq.Messages = append([]ChatCompletionMessage{
{Role: "system", Content: systemPrompt},
}, thinkingReq.Messages...)
temp := 0.7
if svc.Temperature != nil {
temp = *svc.Temperature
}
payload := map[string]interface{}{
"model": svc.Model,
"messages": thinkingReq.Messages,
"stream": false,
"temperature": temp,
}
if isValidReasoningEffort(svc.ReasoningEffort) {
payload["reasoning_effort"] = svc.ReasoningEffort
}
if isValidReasoningFormat(svc.ReasoningFormat) {
payload["reasoning_format"] = svc.ReasoningFormat
}
if s.config.Global.Log.Debug.PrintRequest {
logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
}
jsonData, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
}
client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %v", err)
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send thinking request: %v", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err)
}
if s.config.Global.Log.Debug.PrintResponse {
logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
}
var thinkingResp ChatCompletionResponse
if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
}
if len(thinkingResp.Choices) == 0 {
return nil, fmt.Errorf("thinking service returned no choices")
}
result := &ThinkingResponse{}
choice := thinkingResp.Choices[0]
if svc.Mode == "full" {
result.ReasoningContent = choice.Message.Content
result.Content = "Based on the above detailed analysis."
} else {
if choice.Message.ReasoningContent != nil {
switch v := choice.Message.ReasoningContent.(type) {
case string:
result.ReasoningContent = v
case map[string]interface{}:
if j, err := json.Marshal(v); err == nil {
result.ReasoningContent = string(j)
}
}
}
if result.ReasoningContent == "" {
result.ReasoningContent = choice.Message.Content
}
result.Content = "Based on the above reasoning."
}
return result, nil
}
func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
newReq := *originalReq
var systemPrompt string
if svc.Mode == "full" {
systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
%s
Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
} else {
systemPrompt = fmt.Sprintf(`Previous thinking process:
%s
Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
}
newReq.Messages = append([]ChatCompletionMessage{
{Role: "system", Content: systemPrompt},
}, newReq.Messages...)
return &newReq
}
func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
logger := NewRequestLogger(s.config)
if s.config.Global.Log.Debug.PrintRequest {
logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
}
jsonData, err := json.Marshal(req)
if err != nil {
http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
return
}
client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
return
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
if err != nil {
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
resp, err := client.Do(httpReq)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "Failed to read response", http.StatusInternalServerError)
return
}
if s.config.Global.Log.Debug.PrintResponse {
logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
return
}
for k, vals := range resp.Header {
for _, v := range vals {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
w.Write(respBody)
}
func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
logger := NewRequestLogger(s.config)
if s.config.Global.Log.Debug.PrintRequest {
logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
}
fullChatURL := targetChannel.GetFullURL()
parsedURL, err := url.Parse(fullChatURL)
if err != nil {
http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
return
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"
client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
return
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
if err != nil {
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
resp, err := client.Do(httpReq)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "Failed to read response", http.StatusInternalServerError)
return
}
if s.config.Global.Log.Debug.PrintResponse {
logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
return
}
for k, vals := range resp.Header {
for _, v := range vals {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
w.Write(respBody)
}
// ---------------------- 流式处理 ----------------------
// collectedReasoningBuffer 用于收集思考服务返回的 reasoning_content(标准模式只收集 reasoning_content;full 模式收集全部)
type collectedReasoningBuffer struct {
builder strings.Builder
mode string
}
func (cb *collectedReasoningBuffer) append(text string) {
cb.builder.WriteString(text)
}
func (cb *collectedReasoningBuffer) get() string {
return cb.builder.String()
}
type StreamHandler struct {
thinkingService ThinkingService
targetChannel Channel
writer http.ResponseWriter
flusher http.Flusher
config *Config
}
func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
flusher, ok := w.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported by this writer")
}
return &StreamHandler{
thinkingService: thinkingService,
targetChannel: targetChannel,
writer: w,
flusher: flusher,
config: config,
}, nil
}
// HandleRequest 分两阶段:先流式接收思考服务 SSE 并转发给客户端(只转发包含 reasoning_content 的 chunk,标准模式遇到纯 content 则中断);
// 然后使用收集到的 reasoning_content 构造最终请求,发起目标 Channel 的 SSE 并转发给客户端。
func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
h.writer.Header().Set("Content-Type", "text/event-stream")
h.writer.Header().Set("Cache-Control", "no-cache")
h.writer.Header().Set("Connection", "keep-alive")
logger := NewRequestLogger(h.config)
reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}
// 阶段一:流式接收思考服务 SSE
if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
return err
}
// 阶段二:构造最终请求,并流式转发目标 Channel SSE
finalReq := h.prepareFinalRequest(req, reasonBuf.get())
if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
return err
}
return nil
}
// streamThinkingService 连接思考服务 SSE,原样转发包含 reasoning_content 的 SSE chunk给客户端,同时收集 reasoning_content。
// 在标准模式下,一旦遇到只返回 non-empty 的 content(且 reasoning_content 为空),则立即中断,不转发该 chunk。
func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
thinkingReq := *req
thinkingReq.Model = h.thinkingService.Model
thinkingReq.APIKey = h.thinkingService.APIKey
var systemPrompt string
if h.thinkingService.Mode == "full" {
systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
} else {
systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
}
messages := append([]ChatCompletionMessage{
{Role: "system", Content: systemPrompt},
}, thinkingReq.Messages...)
temp := 0.7
if h.thinkingService.Temperature != nil {
temp = *h.thinkingService.Temperature
}
bodyMap := map[string]interface{}{
"model": thinkingReq.Model,
"messages": messages,
"temperature": temp,
"stream": true,
}
if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
}
if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return err
}
client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
if err != nil {
return err
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)
log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
resp, err := client.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
}
reader := bufio.NewReader(resp.Body)
var lastLine string
// 标识是否需中断思考 SSE(标准模式下)
forceStop := false
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
return err
}
line = strings.TrimSpace(line)
if line == "" || line == lastLine {
continue
}
lastLine = line
if !strings.HasPrefix(line, "data: ") {
continue
}
dataPart := strings.TrimPrefix(line, "data: ")
if dataPart == "[DONE]" {
// 思考服务结束:直接中断(不发送特殊标记)
break
}
var chunk struct {
Choices []struct {
Delta struct {
Content string `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
} `json:"choices"`
}
if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
// 如果解析失败,直接原样转发
h.writer.Write([]byte("data: " + dataPart + "\n\n"))
h.flusher.Flush()
continue
}
if len(chunk.Choices) > 0 {
c := chunk.Choices[0]
if h.config.Global.Log.Debug.PrintResponse {
logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
}
if h.thinkingService.Mode == "full" {
// full 模式:收集所有内容
if c.Delta.ReasoningContent != "" {
reasonBuf.append(c.Delta.ReasoningContent)
}
if c.Delta.Content != "" {
reasonBuf.append(c.Delta.Content)
}
// 原样转发整 chunk
forwardLine := "data: " + dataPart + "\n\n"
h.writer.Write([]byte(forwardLine))
h.flusher.Flush()
} else {
// standard 模式:只收集 reasoning_content
if c.Delta.ReasoningContent != "" {
reasonBuf.append(c.Delta.ReasoningContent)
// 转发该 chunk
forwardLine := "data: " + dataPart + "\n\n"
h.writer.Write([]byte(forwardLine))
h.flusher.Flush()
}
// 如果遇到 chunk 中只有 content(且 reasoning_content 为空),认为思考链结束,不转发该 chunk
if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
forceStop = true
}
}
// 如果 finishReason 非空,也认为结束
if c.FinishReason != nil && *c.FinishReason != "" {
forceStop = true
}
}
if forceStop {
break
}
}
// 读空剩余数据
io.Copy(io.Discard, reader)
return nil
}
func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
req := *originalReq
var systemPrompt string
if h.thinkingService.Mode == "full" {
systemPrompt = fmt.Sprintf(
`Consider the following detailed analysis (not shown to user):
%s
Provide a clear, concise response that incorporates insights from this analysis.`,
reasoningCollected,
)
} else {
systemPrompt = fmt.Sprintf(
`Previous thinking process:
%s
Please consider the above thinking process in your response.`,
reasoningCollected,
)
}
req.Messages = append([]ChatCompletionMessage{
{Role: "system", Content: systemPrompt},
}, req.Messages...)
return &req
}
func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
if h.config.Global.Log.Debug.PrintRequest {
logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
}
jsonData, err := json.Marshal(req)
if err != nil {
return err
}
client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
if err != nil {
return err
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
resp, err := client.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
}
reader := bufio.NewReader(resp.Body)
var lastLine string
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
return err
}
line = strings.TrimSpace(line)
if line == "" || line == lastLine {
continue
}
lastLine = line
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
doneLine := "data: [DONE]\n\n"
h.writer.Write([]byte(doneLine))
h.flusher.Flush()
break
}
forwardLine := "data: " + data + "\n\n"
h.writer.Write([]byte(forwardLine))
h.flusher.Flush()
if h.config.Global.Log.Debug.PrintResponse {
logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
}
}
io.Copy(io.Discard, reader)
return nil
}
// ---------------------- HTTP Client 工具 ----------------------
func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
if proxyURL != "" {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %v", err)
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
}
}
return &http.Client{
Transport: transport,
Timeout: timeout,
}, nil
}
func maskSensitiveHeaders(headers http.Header) http.Header {
masked := make(http.Header)
for k, vals := range headers {
if strings.ToLower(k) == "authorization" {
masked[k] = []string{"Bearer ****"}
} else {
masked[k] = vals
}
}
return masked
}
func isValidReasoningEffort(effort string) bool {
switch strings.ToLower(effort) {
case "low", "medium", "high":
return true
}
return false
}
func isValidReasoningFormat(format string) bool {
switch strings.ToLower(format) {
case "parsed", "raw", "hidden":
return true
}
return false
}
// ---------------------- 配置加载与验证 ----------------------
// 首先添加这些辅助函数
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvIntOrDefault(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
if intValue, err := strconv.Atoi(value); err == nil {
return intValue
}
}
return defaultValue
}
func getEnvBoolOrDefault(key string, defaultValue bool) bool {
if value := os.Getenv(key); value != "" {
if boolValue, err := strconv.ParseBool(value); err == nil {
return boolValue
}
}
return defaultValue
}
func getEnvFloatPtr(key string, defaultValue float64) *float64 {
if value := os.Getenv(key); value != "" {
if floatValue, err := strconv.ParseFloat(value, 64); err == nil {
return &floatValue
}
}
return &defaultValue
}
// 然后替换原有的 loadConfig 函数
func loadConfig() (*Config, error) {
// 优先从环境变量加载完整配置
if envConfig := os.Getenv("CONFIG_JSON"); envConfig != "" {
var cfg Config
if err := json.Unmarshal([]byte(envConfig), &cfg); err == nil {
if err := validateConfig(&cfg); err == nil {
log.Println("Using configuration from CONFIG_JSON environment variable")
return &cfg, nil
}
}
}
// 如果没有完整的 JSON 配置,则从独立的环境变量构建配置
log.Println("Building configuration from individual environment variables")
cfg := &Config{
ThinkingServices: []ThinkingService{
{
ID: 1,
Name: getEnvOrDefault("THINKING_SERVICE_NAME", "modelscope-deepseek-thinking"),
Mode: getEnvOrDefault("THINKING_SERVICE_MODE", "standard"),
Model: getEnvOrDefault("THINKING_SERVICE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"),
BaseURL: getEnvOrDefault("THINKING_SERVICE_BASE_URL", "https://api-inference.modelscope.cn"),
APIPath: getEnvOrDefault("THINKING_SERVICE_API_PATH", "/v1/chat/completions"),
APIKey: os.Getenv("THINKING_SERVICE_API_KEY"), // 敏感信息必须从环境变量读取
Timeout: getEnvIntOrDefault("THINKING_SERVICE_TIMEOUT", 6000),
Weight: getEnvIntOrDefault("THINKING_SERVICE_WEIGHT", 100),
Proxy: getEnvOrDefault("THINKING_SERVICE_PROXY", ""),
ReasoningEffort: getEnvOrDefault("THINKING_SERVICE_REASONING_EFFORT", "high"),
ReasoningFormat: getEnvOrDefault("THINKING_SERVICE_REASONING_FORMAT", "parsed"),
Temperature: getEnvFloatPtr("THINKING_SERVICE_TEMPERATURE", 0.8),
ForceStopDeepThinking: getEnvBoolOrDefault("THINKING_SERVICE_FORCE_STOP", false),
},
},
Channels: map[string]Channel{
"1": {
Name: getEnvOrDefault("CHANNEL_NAME", "Gemini-channel"),
BaseURL: getEnvOrDefault("CHANNEL_BASE_URL", "https://Richardlsr-gemini-balance.hf.space/hf"),
APIPath: getEnvOrDefault("CHANNEL_API_PATH", "/v1/chat/completions"),
Timeout: getEnvIntOrDefault("CHANNEL_TIMEOUT", 600),
Proxy: getEnvOrDefault("CHANNEL_PROXY", ""),
},
},
Global: GlobalConfig{
Server: ServerConfig{
Port: getEnvIntOrDefault("SERVER_PORT", 7860), // HF 默认端口
Host: getEnvOrDefault("SERVER_HOST", "0.0.0.0"),
ReadTimeout: getEnvIntOrDefault("SERVER_READ_TIMEOUT", 600),
WriteTimeout: getEnvIntOrDefault("SERVER_WRITE_TIMEOUT", 600),
IdleTimeout: getEnvIntOrDefault("SERVER_IDLE_TIMEOUT", 600),
},
Log: LogConfig{
Level: getEnvOrDefault("LOG_LEVEL", "info"),
Format: getEnvOrDefault("LOG_FORMAT", "text"),
Output: getEnvOrDefault("LOG_OUTPUT", "console"), // HF 环境下默认输出到控制台
FilePath: getEnvOrDefault("LOG_FILE_PATH", "./logs/deepai.log"),
Debug: DebugConfig{
Enabled: getEnvBoolOrDefault("LOG_DEBUG_ENABLED", true),
PrintRequest: getEnvBoolOrDefault("LOG_PRINT_REQUEST", true),
PrintResponse: getEnvBoolOrDefault("LOG_PRINT_RESPONSE", true),
MaxContentLength: getEnvIntOrDefault("LOG_MAX_CONTENT_LENGTH", 1000),
},
},
},
}
// 验证必需的环境变量
if cfg.ThinkingServices[0].APIKey == "" {
return nil, fmt.Errorf("THINKING_SERVICE_API_KEY environment variable is required")
}
if err := validateConfig(cfg); err != nil {
return nil, fmt.Errorf("config validation error: %v", err)
}
return cfg, nil
}
func validateConfig(config *Config) error {
if len(config.ThinkingServices) == 0 {
return fmt.Errorf("no thinking services configured")
}
if len(config.Channels) == 0 {
return fmt.Errorf("no channels configured")
}
for i, svc := range config.ThinkingServices {
if svc.BaseURL == "" {
return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
}
if svc.APIKey == "" {
return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
}
if svc.Timeout <= 0 {
return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
}
if svc.Model == "" {
return fmt.Errorf("thinking service %s has empty model", svc.Name)
}
if svc.Mode == "" {
config.ThinkingServices[i].Mode = "standard"
} else if svc.Mode != "standard" && svc.Mode != "full" {
return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
}
}
return nil
}
func main() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
cfg, err := loadConfig()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
log.Printf("Using config file: %s", viper.ConfigFileUsed())
server := NewServer(cfg)
done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
go func() {
if err := server.Start(); err != nil && err != http.ErrServerClosed {
log.Fatalf("start server error: %v", err)
}
}()
log.Printf("Server started successfully")
<-done
log.Print("Server stopping...")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
log.Printf("Server forced to shutdown: %v", err)
}
log.Print("Server stopped")
}