nbugs commited on
Commit
59b3d84
·
verified ·
1 Parent(s): 924b5f2

Upload main.go

Browse files
Files changed (1) hide show
  1. main.go +1181 -0
main.go ADDED
@@ -0,0 +1,1181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "bytes"
6
+ "context"
7
+ "encoding/json"
8
+ "flag"
9
+ "fmt"
10
+ "io"
11
+ "log"
12
+ "math/rand"
13
+ "net"
14
+ "net/http"
15
+ "net/url"
16
+ "os"
17
+ "os/signal"
18
+ "path/filepath"
19
+ "regexp"
20
+ "strings"
21
+ "sync"
22
+ "syscall"
23
+ "time"
24
+
25
+ "github.com/google/uuid"
26
+ "github.com/spf13/viper"
27
+ "golang.org/x/net/proxy"
28
+ )
29
+
30
+ // ---------------------- 配置结构 ----------------------
31
+
32
+ type Config struct {
33
+ ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
34
+ Channels map[string]Channel `mapstructure:"channels"`
35
+ Global GlobalConfig `mapstructure:"global"`
36
+ }
37
+
38
+ type ThinkingService struct {
39
+ ID int `mapstructure:"id"`
40
+ Name string `mapstructure:"name"`
41
+ Model string `mapstructure:"model"`
42
+ BaseURL string `mapstructure:"base_url"`
43
+ APIPath string `mapstructure:"api_path"`
44
+ APIKey string `mapstructure:"api_key"`
45
+ Timeout int `mapstructure:"timeout"`
46
+ Retry int `mapstructure:"retry"`
47
+ Weight int `mapstructure:"weight"`
48
+ Proxy string `mapstructure:"proxy"`
49
+ Mode string `mapstructure:"mode"` // "standard" 或 "full"
50
+ ReasoningEffort string `mapstructure:"reasoning_effort"`
51
+ ReasoningFormat string `mapstructure:"reasoning_format"`
52
+ Temperature *float64 `mapstructure:"temperature"`
53
+ ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"` // 配置项:标准模式下遇到 content 时是否立即停止
54
+ }
55
+
56
+ func (s *ThinkingService) GetFullURL() string {
57
+ return s.BaseURL + s.APIPath
58
+ }
59
+
60
+ type Channel struct {
61
+ Name string `mapstructure:"name"`
62
+ BaseURL string `mapstructure:"base_url"`
63
+ APIPath string `mapstructure:"api_path"`
64
+ Timeout int `mapstructure:"timeout"`
65
+ Proxy string `mapstructure:"proxy"`
66
+ }
67
+
68
+ func (c *Channel) GetFullURL() string {
69
+ return c.BaseURL + c.APIPath
70
+ }
71
+
72
+ type LogConfig struct {
73
+ Level string `mapstructure:"level"`
74
+ Format string `mapstructure:"format"`
75
+ Output string `mapstructure:"output"`
76
+ FilePath string `mapstructure:"file_path"`
77
+ Debug DebugConfig `mapstructure:"debug"`
78
+ }
79
+
80
+ type DebugConfig struct {
81
+ Enabled bool `mapstructure:"enabled"`
82
+ PrintRequest bool `mapstructure:"print_request"`
83
+ PrintResponse bool `mapstructure:"print_response"`
84
+ MaxContentLength int `mapstructure:"max_content_length"`
85
+ }
86
+
87
+ type ProxyConfig struct {
88
+ Enabled bool `mapstructure:"enabled"`
89
+ Default string `mapstructure:"default"`
90
+ AllowInsecure bool `mapstructure:"allow_insecure"`
91
+ }
92
+
93
+ type GlobalConfig struct {
94
+ MaxRetries int `mapstructure:"max_retries"`
95
+ DefaultTimeout int `mapstructure:"default_timeout"`
96
+ ErrorCodes struct {
97
+ RetryOn []int `mapstructure:"retry_on"`
98
+ } `mapstructure:"error_codes"`
99
+ Log LogConfig `mapstructure:"log"`
100
+ Server ServerConfig `mapstructure:"server"`
101
+ Proxy ProxyConfig `mapstructure:"proxy"`
102
+ ConfigPaths []string `mapstructure:"config_paths"`
103
+ Thinking ThinkingConfig `mapstructure:"thinking"`
104
+ }
105
+
106
+ type ServerConfig struct {
107
+ Port int `mapstructure:"port"`
108
+ Host string `mapstructure:"host"`
109
+ ReadTimeout int `mapstructure:"read_timeout"`
110
+ WriteTimeout int `mapstructure:"write_timeout"`
111
+ IdleTimeout int `mapstructure:"idle_timeout"`
112
+ }
113
+
114
+ type ThinkingConfig struct {
115
+ Enabled bool `mapstructure:"enabled"`
116
+ AddToAllRequests bool `mapstructure:"add_to_all_requests"`
117
+ Timeout int `mapstructure:"timeout"`
118
+ }
119
+
120
+ // ---------------------- API 相关结构 ----------------------
121
+
122
+ type ChatCompletionRequest struct {
123
+ Model string `json:"model"`
124
+ Messages []ChatCompletionMessage `json:"messages"`
125
+ Temperature float64 `json:"temperature,omitempty"`
126
+ MaxTokens int `json:"max_tokens,omitempty"`
127
+ Stream bool `json:"stream,omitempty"`
128
+ APIKey string `json:"-"` // 内部传递,不序列化
129
+ }
130
+
131
+ type ChatCompletionMessage struct {
132
+ Role string `json:"role"`
133
+ Content string `json:"content"`
134
+ ReasoningContent interface{} `json:"reasoning_content,omitempty"`
135
+ }
136
+
137
+ type ChatCompletionResponse struct {
138
+ ID string `json:"id"`
139
+ Object string `json:"object"`
140
+ Created int64 `json:"created"`
141
+ Model string `json:"model"`
142
+ Choices []Choice `json:"choices"`
143
+ Usage Usage `json:"usage"`
144
+ }
145
+
146
+ type Choice struct {
147
+ Index int `json:"index"`
148
+ Message ChatCompletionMessage `json:"message"`
149
+ FinishReason string `json:"finish_reason"`
150
+ }
151
+
152
+ type Usage struct {
153
+ PromptTokens int `json:"prompt_tokens"`
154
+ CompletionTokens int `json:"completion_tokens"`
155
+ TotalTokens int `json:"total_tokens"`
156
+ }
157
+
158
+ // ---------------------- 日志工具 ----------------------
159
+
160
+ type RequestLogger struct {
161
+ RequestID string
162
+ Model string
163
+ StartTime time.Time
164
+ logs []string
165
+ config *Config
166
+ }
167
+
168
+ func NewRequestLogger(config *Config) *RequestLogger {
169
+ return &RequestLogger{
170
+ RequestID: uuid.New().String(),
171
+ StartTime: time.Now(),
172
+ logs: make([]string, 0),
173
+ config: config,
174
+ }
175
+ }
176
+
177
+ // 新增函数:替换环境变量
178
+ func replaceEnvVars(value string) string {
179
+ // 检查是否包含环境变量占位符 ${...}
180
+ if !strings.Contains(value, "${") {
181
+ return value
182
+ }
183
+
184
+ // 查找所有 ${...} 格式的环境变量并替换
185
+ result := value
186
+ for {
187
+ start := strings.Index(result, "${")
188
+ if start == -1 {
189
+ break
190
+ }
191
+
192
+ end := strings.Index(result[start:], "}")
193
+ if end == -1 {
194
+ break
195
+ }
196
+ end += start
197
+
198
+ // 提取环境变量名称
199
+ envName := result[start+2:end]
200
+
201
+ // 获取环境变量值,如果不存在则保持原样
202
+ envValue := os.Getenv(envName)
203
+ if envValue == "" {
204
+ // 可选:记录警告或使用默认值
205
+ // 这里我们保持原样,也可以设置为空字符串或其他默认值
206
+ }
207
+
208
+ // 替换环境变量
209
+ result = result[:start] + envValue + result[end+1:]
210
+ }
211
+
212
+ return result
213
+ }
214
+
215
+ // 处理配置结构体中的环境变量
216
+ func ProcessConfig(config *Config) {
217
+ // 处理 ThinkingServices
218
+ for i := range config.ThinkingServices {
219
+ config.ThinkingServices[i].Name = replaceEnvVars(config.ThinkingServices[i].Name)
220
+ config.ThinkingServices[i].Model = replaceEnvVars(config.ThinkingServices[i].Model)
221
+ config.ThinkingServices[i].BaseURL = replaceEnvVars(config.ThinkingServices[i].BaseURL)
222
+ config.ThinkingServices[i].APIKey = replaceEnvVars(config.ThinkingServices[i].APIKey)
223
+ config.ThinkingServices[i].Proxy = replaceEnvVars(config.ThinkingServices[i].Proxy)
224
+ }
225
+
226
+ // 处理 Channels
227
+ for key, channel := range config.Channels {
228
+ channel.Name = replaceEnvVars(channel.Name)
229
+ channel.BaseURL = replaceEnvVars(channel.BaseURL)
230
+ channel.Proxy = replaceEnvVars(channel.Proxy)
231
+ config.Channels[key] = channel
232
+ }
233
+
234
+ // 如果 Global 配置中也有环境变量,可以在这里添加处理逻辑
235
+ }
236
+
237
+ func (l *RequestLogger) Log(format string, args ...interface{}) {
238
+ msg := fmt.Sprintf(format, args...)
239
+ l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
240
+ log.Printf("[RequestID: %s] %s", l.RequestID, msg)
241
+ }
242
+
243
+ func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
244
+ if !l.config.Global.Log.Debug.Enabled {
245
+ return
246
+ }
247
+ sanitizedContent := sanitizeJSON(content)
248
+ truncatedContent := truncateContent(sanitizedContent, maxLength)
249
+ l.Log("%s Content:\n%s", contentType, truncatedContent)
250
+ }
251
+
252
+ func truncateContent(content string, maxLength int) string {
253
+ if len(content) <= maxLength {
254
+ return content
255
+ }
256
+ return content[:maxLength] + "... (truncated)"
257
+ }
258
+
259
+ func sanitizeJSON(data interface{}) string {
260
+ sanitized, err := json.Marshal(data)
261
+ if err != nil {
262
+ return "Failed to marshal JSON"
263
+ }
264
+ content := string(sanitized)
265
+ sensitivePattern := `"api_key":\s*"[^"]*"`
266
+ content = regexp.MustCompile(sensitivePattern).ReplaceAllString(content, `"api_key":"****"`)
267
+ return content
268
+ }
269
+
270
+ func extractRealAPIKey(fullKey string) string {
271
+ parts := strings.Split(fullKey, "-")
272
+ if len(parts) >= 3 && (parts[0] == "deep" || parts[0] == "openai") {
273
+ return strings.Join(parts[2:], "-")
274
+ }
275
+ return fullKey
276
+ }
277
+
278
+ func extractChannelID(fullKey string) string {
279
+ parts := strings.Split(fullKey, "-")
280
+ if len(parts) >= 2 && (parts[0] == "deep" || parts[0] == "openai") {
281
+ return parts[1]
282
+ }
283
+ return "1" // 默认渠道
284
+ }
285
+
286
+ func logAPIKey(key string) string {
287
+ if len(key) <= 8 {
288
+ return "****"
289
+ }
290
+ return key[:4] + "..." + key[len(key)-4:]
291
+ }
292
+
293
+ // ---------------------- Server 结构 ----------------------
294
+
295
+ type Server struct {
296
+ config *Config
297
+ srv *http.Server
298
+ }
299
+
300
+ var (
301
+ randMu sync.Mutex
302
+ randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
303
+ )
304
+
305
+ func NewServer(config *Config) *Server {
306
+ return &Server{
307
+ config: config,
308
+ }
309
+ }
310
+
311
+ func (s *Server) Start() error {
312
+ mux := http.NewServeMux()
313
+ mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
314
+ mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
315
+ mux.HandleFunc("/health", s.handleHealth)
316
+
317
+ s.srv = &http.Server{
318
+ Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
319
+ Handler: mux,
320
+ ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
321
+ WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
322
+ IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
323
+ }
324
+
325
+ log.Printf("Server starting on %s\n", s.srv.Addr)
326
+ return s.srv.ListenAndServe()
327
+ }
328
+
329
+ func (s *Server) Shutdown(ctx context.Context) error {
330
+ return s.srv.Shutdown(ctx)
331
+ }
332
+
333
+ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
334
+ w.WriteHeader(http.StatusOK)
335
+ json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
336
+ }
337
+
338
+ func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
339
+ logger := NewRequestLogger(s.config)
340
+
341
+ fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
342
+ apiKey := extractRealAPIKey(fullAPIKey)
343
+ channelID := extractChannelID(fullAPIKey)
344
+
345
+ logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
346
+ logger.Log("Extracted channel ID: %s", channelID)
347
+ logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))
348
+
349
+ targetChannel, ok := s.config.Channels[channelID]
350
+ if !ok {
351
+ http.Error(w, "Invalid channel", http.StatusBadRequest)
352
+ return
353
+ }
354
+
355
+ if r.URL.Path == "/v1/models" {
356
+ if r.Method != http.MethodGet {
357
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
358
+ return
359
+ }
360
+ req := &ChatCompletionRequest{APIKey: apiKey}
361
+ s.forwardModelsRequest(w, r.Context(), req, targetChannel)
362
+ return
363
+ }
364
+
365
+ if r.Method != http.MethodPost {
366
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
367
+ return
368
+ }
369
+
370
+ body, err := io.ReadAll(r.Body)
371
+ if err != nil {
372
+ logger.Log("Error reading request body: %v", err)
373
+ http.Error(w, "Failed to read request", http.StatusBadRequest)
374
+ return
375
+ }
376
+ r.Body.Close()
377
+ r.Body = io.NopCloser(bytes.NewBuffer(body))
378
+
379
+ if s.config.Global.Log.Debug.PrintRequest {
380
+ logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
381
+ }
382
+
383
+ var req ChatCompletionRequest
384
+ if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
385
+ http.Error(w, "Invalid request body", http.StatusBadRequest)
386
+ return
387
+ }
388
+ req.APIKey = apiKey
389
+
390
+ thinkingService := s.getWeightedRandomThinkingService()
391
+ logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))
392
+
393
+ if req.Stream {
394
+ handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
395
+ if err != nil {
396
+ http.Error(w, "Streaming not supported", http.StatusInternalServerError)
397
+ return
398
+ }
399
+ if err := handler.HandleRequest(r.Context(), &req); err != nil {
400
+ logger.Log("Stream handler error: %v", err)
401
+ }
402
+ } else {
403
+ thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
404
+ if err != nil {
405
+ logger.Log("Error processing thinking content: %v", err)
406
+ http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
407
+ return
408
+ }
409
+ enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
410
+ s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
411
+ }
412
+ }
413
+
414
+ func (s *Server) getWeightedRandomThinkingService() ThinkingService {
415
+ thinkingServices := s.config.ThinkingServices
416
+ if len(thinkingServices) == 0 {
417
+ return ThinkingService{}
418
+ }
419
+ totalWeight := 0
420
+ for _, svc := range thinkingServices {
421
+ totalWeight += svc.Weight
422
+ }
423
+ if totalWeight <= 0 {
424
+ log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
425
+ return thinkingServices[0]
426
+ }
427
+ randMu.Lock()
428
+ randNum := randGen.Intn(totalWeight)
429
+ randMu.Unlock()
430
+ currentSum := 0
431
+ for _, svc := range thinkingServices {
432
+ currentSum += svc.Weight
433
+ if randNum < currentSum {
434
+ return svc
435
+ }
436
+ }
437
+ return thinkingServices[0]
438
+ }
439
+
440
+ // ---------------------- 非流式处理思考服务 ----------------------
441
+
442
+ type ThinkingResponse struct {
443
+ Content string
444
+ ReasoningContent string
445
+ }
446
+
447
+ func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
448
+ logger := NewRequestLogger(s.config)
449
+ log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)
450
+
451
+ thinkingReq := *req
452
+ thinkingReq.Model = svc.Model
453
+ thinkingReq.APIKey = svc.APIKey
454
+
455
+ var systemPrompt string
456
+ if svc.Mode == "full" {
457
+ systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
458
+ } else {
459
+ systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
460
+ }
461
+ thinkingReq.Messages = append([]ChatCompletionMessage{
462
+ {Role: "system", Content: systemPrompt},
463
+ }, thinkingReq.Messages...)
464
+
465
+ temp := 0.7
466
+ if svc.Temperature != nil {
467
+ temp = *svc.Temperature
468
+ }
469
+ payload := map[string]interface{}{
470
+ "model": svc.Model,
471
+ "messages": thinkingReq.Messages,
472
+ "stream": false,
473
+ "temperature": temp,
474
+ }
475
+ if isValidReasoningEffort(svc.ReasoningEffort) {
476
+ payload["reasoning_effort"] = svc.ReasoningEffort
477
+ }
478
+ if isValidReasoningFormat(svc.ReasoningFormat) {
479
+ payload["reasoning_format"] = svc.ReasoningFormat
480
+ }
481
+
482
+ if s.config.Global.Log.Debug.PrintRequest {
483
+ logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
484
+ }
485
+
486
+ jsonData, err := json.Marshal(payload)
487
+ if err != nil {
488
+ return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
489
+ }
490
+ client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
491
+ if err != nil {
492
+ return nil, fmt.Errorf("failed to create HTTP client: %v", err)
493
+ }
494
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
495
+ if err != nil {
496
+ return nil, fmt.Errorf("failed to create request: %v", err)
497
+ }
498
+ httpReq.Header.Set("Content-Type", "application/json")
499
+ httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)
500
+
501
+ resp, err := client.Do(httpReq)
502
+ if err != nil {
503
+ return nil, fmt.Errorf("failed to send thinking request: %v", err)
504
+ }
505
+ defer resp.Body.Close()
506
+
507
+ respBody, err := io.ReadAll(resp.Body)
508
+ if err != nil {
509
+ return nil, fmt.Errorf("failed to read response body: %v", err)
510
+ }
511
+ if s.config.Global.Log.Debug.PrintResponse {
512
+ logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
513
+ }
514
+ if resp.StatusCode != http.StatusOK {
515
+ return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
516
+ }
517
+
518
+ var thinkingResp ChatCompletionResponse
519
+ if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
520
+ return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
521
+ }
522
+ if len(thinkingResp.Choices) == 0 {
523
+ return nil, fmt.Errorf("thinking service returned no choices")
524
+ }
525
+
526
+ result := &ThinkingResponse{}
527
+ choice := thinkingResp.Choices[0]
528
+
529
+ if svc.Mode == "full" {
530
+ result.ReasoningContent = choice.Message.Content
531
+ result.Content = "Based on the above detailed analysis."
532
+ } else {
533
+ if choice.Message.ReasoningContent != nil {
534
+ switch v := choice.Message.ReasoningContent.(type) {
535
+ case string:
536
+ result.ReasoningContent = v
537
+ case map[string]interface{}:
538
+ if j, err := json.Marshal(v); err == nil {
539
+ result.ReasoningContent = string(j)
540
+ }
541
+ }
542
+ }
543
+ if result.ReasoningContent == "" {
544
+ result.ReasoningContent = choice.Message.Content
545
+ }
546
+ result.Content = "Based on the above reasoning."
547
+ }
548
+ return result, nil
549
+ }
550
+
551
+ func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
552
+ newReq := *originalReq
553
+ var systemPrompt string
554
+ if svc.Mode == "full" {
555
+ systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
556
+ %s
557
+
558
+ Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
559
+ } else {
560
+ systemPrompt = fmt.Sprintf(`Previous thinking process:
561
+ %s
562
+ Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
563
+ }
564
+ newReq.Messages = append([]ChatCompletionMessage{
565
+ {Role: "system", Content: systemPrompt},
566
+ }, newReq.Messages...)
567
+ return &newReq
568
+ }
569
+
570
+ func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
571
+ logger := NewRequestLogger(s.config)
572
+ if s.config.Global.Log.Debug.PrintRequest {
573
+ logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
574
+ }
575
+ jsonData, err := json.Marshal(req)
576
+ if err != nil {
577
+ http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
578
+ return
579
+ }
580
+
581
+ client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
582
+ if err != nil {
583
+ http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
584
+ return
585
+ }
586
+
587
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
588
+ if err != nil {
589
+ http.Error(w, "Failed to create request", http.StatusInternalServerError)
590
+ return
591
+ }
592
+ httpReq.Header.Set("Content-Type", "application/json")
593
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
594
+
595
+ resp, err := client.Do(httpReq)
596
+ if err != nil {
597
+ http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
598
+ return
599
+ }
600
+ defer resp.Body.Close()
601
+
602
+ respBody, err := io.ReadAll(resp.Body)
603
+ if err != nil {
604
+ http.Error(w, "Failed to read response", http.StatusInternalServerError)
605
+ return
606
+ }
607
+ if s.config.Global.Log.Debug.PrintResponse {
608
+ logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
609
+ }
610
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
611
+ http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
612
+ return
613
+ }
614
+
615
+ for k, vals := range resp.Header {
616
+ for _, v := range vals {
617
+ w.Header().Add(k, v)
618
+ }
619
+ }
620
+ w.WriteHeader(resp.StatusCode)
621
+ w.Write(respBody)
622
+ }
623
+
624
+ func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
625
+ logger := NewRequestLogger(s.config)
626
+ if s.config.Global.Log.Debug.PrintRequest {
627
+ logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
628
+ }
629
+ fullChatURL := targetChannel.GetFullURL()
630
+ parsedURL, err := url.Parse(fullChatURL)
631
+ if err != nil {
632
+ http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
633
+ return
634
+ }
635
+ baseURL := parsedURL.Scheme + "://" + parsedURL.Host
636
+ modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"
637
+
638
+ client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
639
+ if err != nil {
640
+ http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
641
+ return
642
+ }
643
+
644
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
645
+ if err != nil {
646
+ http.Error(w, "Failed to create request", http.StatusInternalServerError)
647
+ return
648
+ }
649
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
650
+
651
+ resp, err := client.Do(httpReq)
652
+ if err != nil {
653
+ http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
654
+ return
655
+ }
656
+ defer resp.Body.Close()
657
+
658
+ respBody, err := io.ReadAll(resp.Body)
659
+ if err != nil {
660
+ http.Error(w, "Failed to read response", http.StatusInternalServerError)
661
+ return
662
+ }
663
+ if s.config.Global.Log.Debug.PrintResponse {
664
+ logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
665
+ }
666
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
667
+ http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
668
+ return
669
+ }
670
+
671
+ for k, vals := range resp.Header {
672
+ for _, v := range vals {
673
+ w.Header().Add(k, v)
674
+ }
675
+ }
676
+ w.WriteHeader(resp.StatusCode)
677
+ w.Write(respBody)
678
+ }
679
+
680
+ // ---------------------- 流式处理 ----------------------
681
+
682
+ // collectedReasoningBuffer 用于收集思考服务返回的 reasoning_content(标准模式只收集 reasoning_content;full 模式收集全部)
683
+ type collectedReasoningBuffer struct {
684
+ builder strings.Builder
685
+ mode string
686
+ }
687
+
688
+ func (cb *collectedReasoningBuffer) append(text string) {
689
+ cb.builder.WriteString(text)
690
+ }
691
+
692
+ func (cb *collectedReasoningBuffer) get() string {
693
+ return cb.builder.String()
694
+ }
695
+
696
+ type StreamHandler struct {
697
+ thinkingService ThinkingService
698
+ targetChannel Channel
699
+ writer http.ResponseWriter
700
+ flusher http.Flusher
701
+ config *Config
702
+ }
703
+
704
+ func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
705
+ flusher, ok := w.(http.Flusher)
706
+ if !ok {
707
+ return nil, fmt.Errorf("streaming not supported by this writer")
708
+ }
709
+ return &StreamHandler{
710
+ thinkingService: thinkingService,
711
+ targetChannel: targetChannel,
712
+ writer: w,
713
+ flusher: flusher,
714
+ config: config,
715
+ }, nil
716
+ }
717
+
718
+ // HandleRequest 分两阶段:先流式接收思考服务 SSE 并转发给客户端(只转发包含 reasoning_content 的 chunk,标准模式遇到纯 content 则中断);
719
+ // 然后使用收集到的 reasoning_content 构造最终请求,发起目标 Channel 的 SSE 并转发给客户端。
720
+ func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
721
+ h.writer.Header().Set("Content-Type", "text/event-stream")
722
+ h.writer.Header().Set("Cache-Control", "no-cache")
723
+ h.writer.Header().Set("Connection", "keep-alive")
724
+
725
+ logger := NewRequestLogger(h.config)
726
+ reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}
727
+
728
+ // 阶段一:流式接收思考服务 SSE
729
+ if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
730
+ return err
731
+ }
732
+
733
+ // 阶段二:构造最终请求,并流式转发目标 Channel SSE
734
+ finalReq := h.prepareFinalRequest(req, reasonBuf.get())
735
+ if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
736
+ return err
737
+ }
738
+ return nil
739
+ }
740
+
741
+ // streamThinkingService 连接思考服务 SSE,原样转发包含 reasoning_content 的 SSE chunk给客户端,同时收集 reasoning_content。
742
+ // 在标准模式下,一旦遇到只返回 non-empty 的 content(且 reasoning_content 为空),则立即中断,不转发该 chunk。
743
+ func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
744
+ thinkingReq := *req
745
+ thinkingReq.Model = h.thinkingService.Model
746
+ thinkingReq.APIKey = h.thinkingService.APIKey
747
+
748
+ var systemPrompt string
749
+ if h.thinkingService.Mode == "full" {
750
+ systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
751
+ } else {
752
+ systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
753
+ }
754
+ messages := append([]ChatCompletionMessage{
755
+ {Role: "system", Content: systemPrompt},
756
+ }, thinkingReq.Messages...)
757
+
758
+ temp := 0.7
759
+ if h.thinkingService.Temperature != nil {
760
+ temp = *h.thinkingService.Temperature
761
+ }
762
+ bodyMap := map[string]interface{}{
763
+ "model": thinkingReq.Model,
764
+ "messages": messages,
765
+ "temperature": temp,
766
+ "stream": true,
767
+ }
768
+ if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
769
+ bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
770
+ }
771
+ if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
772
+ bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
773
+ }
774
+
775
+ jsonData, err := json.Marshal(bodyMap)
776
+ if err != nil {
777
+ return err
778
+ }
779
+
780
+ client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
781
+ if err != nil {
782
+ return err
783
+ }
784
+
785
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
786
+ if err != nil {
787
+ return err
788
+ }
789
+ httpReq.Header.Set("Content-Type", "application/json")
790
+ httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)
791
+
792
+ log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
793
+ resp, err := client.Do(httpReq)
794
+ if err != nil {
795
+ return err
796
+ }
797
+ defer resp.Body.Close()
798
+
799
+ if resp.StatusCode != http.StatusOK {
800
+ b, _ := io.ReadAll(resp.Body)
801
+ return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
802
+ }
803
+
804
+ reader := bufio.NewReader(resp.Body)
805
+ var lastLine string
806
+ // 标识是否需中断思考 SSE(标准模式下)
807
+ forceStop := false
808
+
809
+ for {
810
+ select {
811
+ case <-ctx.Done():
812
+ return ctx.Err()
813
+ default:
814
+ }
815
+ line, err := reader.ReadString('\n')
816
+ if err != nil {
817
+ if err == io.EOF {
818
+ break
819
+ }
820
+ return err
821
+ }
822
+ line = strings.TrimSpace(line)
823
+ if line == "" || line == lastLine {
824
+ continue
825
+ }
826
+ lastLine = line
827
+
828
+ if !strings.HasPrefix(line, "data: ") {
829
+ continue
830
+ }
831
+ dataPart := strings.TrimPrefix(line, "data: ")
832
+ if dataPart == "[DONE]" {
833
+ // 思考服务结束:直接中断(不发送特殊标记)
834
+ break
835
+ }
836
+
837
+ var chunk struct {
838
+ Choices []struct {
839
+ Delta struct {
840
+ Content string `json:"content,omitempty"`
841
+ ReasoningContent string `json:"reasoning_content,omitempty"`
842
+ } `json:"delta"`
843
+ FinishReason *string `json:"finish_reason,omitempty"`
844
+ } `json:"choices"`
845
+ }
846
+ if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
847
+ // 如果解析失败,直接原样转发
848
+ h.writer.Write([]byte("data: " + dataPart + "\n\n"))
849
+ h.flusher.Flush()
850
+ continue
851
+ }
852
+
853
+ if len(chunk.Choices) > 0 {
854
+ c := chunk.Choices[0]
855
+ if h.config.Global.Log.Debug.PrintResponse {
856
+ logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
857
+ }
858
+
859
+ if h.thinkingService.Mode == "full" {
860
+ // full 模式:收集所有内容
861
+ if c.Delta.ReasoningContent != "" {
862
+ reasonBuf.append(c.Delta.ReasoningContent)
863
+ }
864
+ if c.Delta.Content != "" {
865
+ reasonBuf.append(c.Delta.Content)
866
+ }
867
+ // 原样转发整 chunk
868
+ forwardLine := "data: " + dataPart + "\n\n"
869
+ h.writer.Write([]byte(forwardLine))
870
+ h.flusher.Flush()
871
+ } else {
872
+ // standard 模式:只收集 reasoning_content
873
+ if c.Delta.ReasoningContent != "" {
874
+ reasonBuf.append(c.Delta.ReasoningContent)
875
+ // 转发该 chunk
876
+ forwardLine := "data: " + dataPart + "\n\n"
877
+ h.writer.Write([]byte(forwardLine))
878
+ h.flusher.Flush()
879
+ }
880
+ // 如果遇到 chunk 中只有 content(且 reasoning_content 为空),认为思考链结束,不转发该 chunk
881
+ if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
882
+ forceStop = true
883
+ }
884
+ }
885
+
886
+ // 如果 finishReason 非空,也认为结束
887
+ if c.FinishReason != nil && *c.FinishReason != "" {
888
+ forceStop = true
889
+ }
890
+ }
891
+
892
+ if forceStop {
893
+ break
894
+ }
895
+ }
896
+ // 读空剩余数据
897
+ io.Copy(io.Discard, reader)
898
+ return nil
899
+ }
900
+
901
+ func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
902
+ req := *originalReq
903
+ var systemPrompt string
904
+ if h.thinkingService.Mode == "full" {
905
+ systemPrompt = fmt.Sprintf(
906
+ `Consider the following detailed analysis (not shown to user):
907
+ %s
908
+
909
+ Provide a clear, concise response that incorporates insights from this analysis.`,
910
+ reasoningCollected,
911
+ )
912
+ } else {
913
+ systemPrompt = fmt.Sprintf(
914
+ `Previous thinking process:
915
+ %s
916
+ Please consider the above thinking process in your response.`,
917
+ reasoningCollected,
918
+ )
919
+ }
920
+ req.Messages = append([]ChatCompletionMessage{
921
+ {Role: "system", Content: systemPrompt},
922
+ }, req.Messages...)
923
+ return &req
924
+ }
925
+
926
+ func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
927
+ if h.config.Global.Log.Debug.PrintRequest {
928
+ logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
929
+ }
930
+
931
+ jsonData, err := json.Marshal(req)
932
+ if err != nil {
933
+ return err
934
+ }
935
+ client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
936
+ if err != nil {
937
+ return err
938
+ }
939
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
940
+ if err != nil {
941
+ return err
942
+ }
943
+ httpReq.Header.Set("Content-Type", "application/json")
944
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
945
+
946
+ resp, err := client.Do(httpReq)
947
+ if err != nil {
948
+ return err
949
+ }
950
+ defer resp.Body.Close()
951
+
952
+ if resp.StatusCode != http.StatusOK {
953
+ b, _ := io.ReadAll(resp.Body)
954
+ return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
955
+ }
956
+
957
+ reader := bufio.NewReader(resp.Body)
958
+ var lastLine string
959
+
960
+ for {
961
+ select {
962
+ case <-ctx.Done():
963
+ return ctx.Err()
964
+ default:
965
+ }
966
+ line, err := reader.ReadString('\n')
967
+ if err != nil {
968
+ if err == io.EOF {
969
+ break
970
+ }
971
+ return err
972
+ }
973
+ line = strings.TrimSpace(line)
974
+ if line == "" || line == lastLine {
975
+ continue
976
+ }
977
+ lastLine = line
978
+
979
+ if !strings.HasPrefix(line, "data: ") {
980
+ continue
981
+ }
982
+ data := strings.TrimPrefix(line, "data: ")
983
+ if data == "[DONE]" {
984
+ doneLine := "data: [DONE]\n\n"
985
+ h.writer.Write([]byte(doneLine))
986
+ h.flusher.Flush()
987
+ break
988
+ }
989
+ forwardLine := "data: " + data + "\n\n"
990
+ h.writer.Write([]byte(forwardLine))
991
+ h.flusher.Flush()
992
+
993
+ if h.config.Global.Log.Debug.PrintResponse {
994
+ logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
995
+ }
996
+ }
997
+ io.Copy(io.Discard, reader)
998
+ return nil
999
+ }
1000
+
1001
+ // ---------------------- HTTP Client 工具 ----------------------
1002
+
1003
+ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
1004
+ transport := &http.Transport{
1005
+ DialContext: (&net.Dialer{
1006
+ Timeout: 30 * time.Second,
1007
+ KeepAlive: 30 * time.Second,
1008
+ }).DialContext,
1009
+ ForceAttemptHTTP2: true,
1010
+ MaxIdleConns: 100,
1011
+ IdleConnTimeout: 90 * time.Second,
1012
+ TLSHandshakeTimeout: 10 * time.Second,
1013
+ ExpectContinueTimeout: 1 * time.Second,
1014
+ }
1015
+ if proxyURL != "" {
1016
+ parsedURL, err := url.Parse(proxyURL)
1017
+ if err != nil {
1018
+ return nil, fmt.Errorf("invalid proxy URL: %v", err)
1019
+ }
1020
+ switch parsedURL.Scheme {
1021
+ case "http", "https":
1022
+ transport.Proxy = http.ProxyURL(parsedURL)
1023
+ case "socks5":
1024
+ dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
1025
+ if err != nil {
1026
+ return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
1027
+ }
1028
+ transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1029
+ return dialer.Dial(network, addr)
1030
+ }
1031
+ default:
1032
+ return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
1033
+ }
1034
+ }
1035
+ return &http.Client{
1036
+ Transport: transport,
1037
+ Timeout: timeout,
1038
+ }, nil
1039
+ }
1040
+
1041
+ func maskSensitiveHeaders(headers http.Header) http.Header {
1042
+ masked := make(http.Header)
1043
+ for k, vals := range headers {
1044
+ if strings.ToLower(k) == "authorization" {
1045
+ masked[k] = []string{"Bearer ****"}
1046
+ } else {
1047
+ masked[k] = vals
1048
+ }
1049
+ }
1050
+ return masked
1051
+ }
1052
+
1053
+ func isValidReasoningEffort(effort string) bool {
1054
+ switch strings.ToLower(effort) {
1055
+ case "low", "medium", "high":
1056
+ return true
1057
+ }
1058
+ return false
1059
+ }
1060
+
1061
+ func isValidReasoningFormat(format string) bool {
1062
+ switch strings.ToLower(format) {
1063
+ case "parsed", "raw", "hidden":
1064
+ return true
1065
+ }
1066
+ return false
1067
+ }
1068
+
1069
+ // ---------------------- 配置加载与验证 ----------------------
1070
+
1071
+ func loadConfig() (*Config, error) {
1072
+ var configFile string
1073
+ flag.StringVar(&configFile, "config", "", "path to config file")
1074
+ flag.Parse()
1075
+ viper.SetConfigType("yaml")
1076
+
1077
+ if configFile != "" {
1078
+ viper.SetConfigFile(configFile)
1079
+ } else {
1080
+ ex, err := os.Executable()
1081
+ if err != nil {
1082
+ return nil, err
1083
+ }
1084
+ exePath := filepath.Dir(ex)
1085
+ defaultPaths := []string{
1086
+ filepath.Join(exePath, "config.yaml"),
1087
+ filepath.Join(exePath, "conf", "config.yaml"),
1088
+ "./config.yaml",
1089
+ "./conf/config.yaml",
1090
+ }
1091
+ if os.PathSeparator == '\\' {
1092
+ programData := os.Getenv("PROGRAMDATA")
1093
+ if programData != "" {
1094
+ defaultPaths = append(defaultPaths, filepath.Join(programData, "DeepAI", "config.yaml"))
1095
+ }
1096
+ } else {
1097
+ defaultPaths = append(defaultPaths, "/etc/deepai/config.yaml")
1098
+ }
1099
+ for _, p := range defaultPaths {
1100
+ viper.AddConfigPath(filepath.Dir(p))
1101
+ if strings.Contains(p, ".yaml") {
1102
+ viper.SetConfigName(strings.TrimSuffix(filepath.Base(p), ".yaml"))
1103
+ }
1104
+ }
1105
+ }
1106
+
1107
+ if err := viper.ReadInConfig(); err != nil {
1108
+ return nil, fmt.Errorf("failed to read config: %v", err)
1109
+ }
1110
+
1111
+ var cfg Config
1112
+ if err := viper.Unmarshal(&cfg); err != nil {
1113
+ return nil, fmt.Errorf("unmarshal config error: %v", err)
1114
+ }
1115
+ if err := validateConfig(&cfg); err != nil {
1116
+ return nil, fmt.Errorf("config validation error: %v", err)
1117
+ }
1118
+ return &cfg, nil
1119
+ }
1120
+
1121
+ func validateConfig(config *Config) error {
1122
+ if len(config.ThinkingServices) == 0 {
1123
+ return fmt.Errorf("no thinking services configured")
1124
+ }
1125
+ if len(config.Channels) == 0 {
1126
+ return fmt.Errorf("no channels configured")
1127
+ }
1128
+ for i, svc := range config.ThinkingServices {
1129
+ if svc.BaseURL == "" {
1130
+ return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
1131
+ }
1132
+ if svc.APIKey == "" {
1133
+ return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
1134
+ }
1135
+ if svc.Timeout <= 0 {
1136
+ return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
1137
+ }
1138
+ if svc.Model == "" {
1139
+ return fmt.Errorf("thinking service %s has empty model", svc.Name)
1140
+ }
1141
+ if svc.Mode == "" {
1142
+ config.ThinkingServices[i].Mode = "standard"
1143
+ } else if svc.Mode != "standard" && svc.Mode != "full" {
1144
+ return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
1145
+ }
1146
+ }
1147
+ return nil
1148
+ }
1149
+
1150
+ func main() {
1151
+ log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
1152
+
1153
+ cfg, err := loadConfig()
1154
+ if err != nil {
1155
+ log.Fatalf("Failed to load config: %v", err)
1156
+ }
1157
+ log.Printf("Using config file: %s", viper.ConfigFileUsed())
1158
+
1159
+ server := NewServer(cfg)
1160
+
1161
+ done := make(chan os.Signal, 1)
1162
+ signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
1163
+
1164
+ go func() {
1165
+ if err := server.Start(); err != nil && err != http.ErrServerClosed {
1166
+ log.Fatalf("start server error: %v", err)
1167
+ }
1168
+ }()
1169
+ log.Printf("Server started successfully")
1170
+
1171
+ <-done
1172
+ log.Print("Server stopping...")
1173
+
1174
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1175
+ defer cancel()
1176
+
1177
+ if err := server.Shutdown(ctx); err != nil {
1178
+ log.Printf("Server forced to shutdown: %v", err)
1179
+ }
1180
+ log.Print("Server stopped")
1181
+ }