nbugs committed on
Commit
ae0f936
·
verified ·
1 Parent(s): 96c8fad

Upload main.go

Browse files
Files changed (1) hide show
  1. main.go +1271 -0
main.go ADDED
@@ -0,0 +1,1271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "bytes"
6
+ "context"
7
+ "encoding/json"
8
+ "flag"
9
+ "fmt"
10
+ "io"
11
+ "log"
12
+ "math/rand"
13
+ "net"
14
+ "net/http"
15
+ "net/url"
16
+ "os"
17
+ "os/signal"
18
+ "path/filepath"
19
+ "regexp"
20
+ "strings"
21
+ "sync"
22
+ "syscall"
23
+ "time"
24
+
25
+ "github.com/google/uuid"
26
+ "github.com/spf13/viper"
27
+ "golang.org/x/net/proxy"
28
+ )
29
+
30
+ // ---------------------- 配置结构 ----------------------
31
+
32
+ type Config struct {
33
+ ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
34
+ Channels map[string]Channel `mapstructure:"channels"`
35
+ Global GlobalConfig `mapstructure:"global"`
36
+ }
37
+
38
// ThinkingService describes one upstream service that is asked to produce
// reasoning content before the request is forwarded to a channel.
type ThinkingService struct {
	ID   int    `mapstructure:"id"`
	Name string `mapstructure:"name"`

	// Model is sent as the "model" field of the thinking request.
	Model string `mapstructure:"model"`

	// BaseURL and APIPath are concatenated to form the endpoint URL.
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`

	// APIKey is sent as the bearer token to the thinking service.
	APIKey string `mapstructure:"api_key"`

	// Timeout is the HTTP client timeout in seconds.
	Timeout int `mapstructure:"timeout"`
	Retry   int `mapstructure:"retry"`

	// Weight is the relative weight used by weighted random selection.
	Weight int `mapstructure:"weight"`

	// Proxy is handed to createHTTPClient when building the HTTP client.
	Proxy string `mapstructure:"proxy"`

	// Mode is either "standard" or "full".
	Mode string `mapstructure:"mode"`

	// ReasoningEffort and ReasoningFormat are added to the request payload
	// when they pass isValidReasoningEffort / isValidReasoningFormat.
	ReasoningEffort string `mapstructure:"reasoning_effort"`
	ReasoningFormat string `mapstructure:"reasoning_format"`

	// Temperature overrides the default sampling temperature (0.7) when set.
	Temperature *float64 `mapstructure:"temperature"`

	// ForceStopDeepThinking configures whether, in standard mode, streaming
	// stops immediately once content is encountered.
	ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"`
}
55
+
56
+ // ... existing code ...
57
+
58
+ // 新增函数:从环境变量加载配置
59
+ func loadConfigFromEnv() *Config {
60
+ cfg := &Config{
61
+ ThinkingServices: []ThinkingService{
62
+ {
63
+ ID: 1,
64
+ Name: os.Getenv("THINKING_SERVICE_NAME"),
65
+ Mode: getEnvOrDefault("THINKING_SERVICE_MODE", "standard"),
66
+ Model: os.Getenv("THINKING_SERVICE_MODEL"),
67
+ BaseURL: os.Getenv("THINKING_SERVICE_BASE_URL"),
68
+ APIPath: "/v1/chat/completions",
69
+ APIKey: os.Getenv("THINKING_SERVICE_API_KEY"),
70
+ Timeout: 600,
71
+ Weight: 100,
72
+ Proxy: "",
73
+ ReasoningEffort: "high",
74
+ ReasoningFormat: "parsed",
75
+ Temperature: getFloatPtr(0.8),
76
+ ForceStopDeepThinking: false,
77
+ },
78
+ },
79
+ Channels: map[string]Channel{
80
+ "1": {
81
+ Name: os.Getenv("CHANNEL_NAME"),
82
+ BaseURL: os.Getenv("CHANNEL_BASE_URL"),
83
+ APIPath: "/v1/chat/completions",
84
+ Timeout: 600,
85
+ Proxy: "",
86
+ },
87
+ },
88
+ Global: GlobalConfig{
89
+ Log: LogConfig{
90
+ Level: "error",
91
+ Format: "json",
92
+ Output: "console",
93
+ Debug: DebugConfig{
94
+ Enabled: true,
95
+ PrintRequest: true,
96
+ PrintResponse: true,
97
+ MaxContentLength: 1000,
98
+ },
99
+ },
100
+ Server: ServerConfig{
101
+ Port: 8888,
102
+ Host: "0.0.0.0",
103
+ ReadTimeout: 600,
104
+ WriteTimeout: 600,
105
+ IdleTimeout: 600,
106
+ },
107
+ },
108
+ }
109
+ return cfg
110
+ }
111
+
112
// getEnvOrDefault returns the value of the environment variable key, or
// defaultValue when the variable is unset or set to the empty string.
func getEnvOrDefault(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
119
+
120
// getFloatPtr returns a pointer to a fresh float64 holding value.
func getFloatPtr(value float64) *float64 {
	ptr := new(float64)
	*ptr = value
	return ptr
}
124
+
125
+ // loadConfig returns the configuration built from environment variables when USE_ENV_CONFIG is "true" (after validateConfig accepts it); otherwise it continues with the flag/viper based file loading, whose remainder is elided in this view.
126
+ func loadConfig() (*Config, error) {
127
+ // Environment-variable configuration takes priority.
128
+ if os.Getenv("USE_ENV_CONFIG") == "true" {
129
+ cfg := loadConfigFromEnv()
130
+ if err := validateConfig(cfg); err != nil {
131
+ return nil, fmt.Errorf("环境变量配置验证错误: %v", err)
132
+ }
133
+ return cfg, nil
134
+ }
135
+
136
+ // Environment variables are not in use: fall back to the original config-file loading logic.
137
+ var configFile string
138
+ flag.StringVar(&configFile, "config", "", "path to config file")
139
+ flag.Parse()
140
+ viper.SetConfigType("yaml")
141
+
142
+ // ... existing code ...
143
+ }
144
+
145
+
146
+ func (s *ThinkingService) GetFullURL() string {
147
+ return s.BaseURL + s.APIPath
148
+ }
149
+
150
// Channel describes a forwarding target for the enhanced chat request.
type Channel struct {
	Name string `mapstructure:"name"`

	// BaseURL and APIPath are concatenated to form the endpoint URL.
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`

	// Timeout is the HTTP client timeout in seconds.
	Timeout int `mapstructure:"timeout"`

	// Proxy is handed to createHTTPClient when building the HTTP client.
	Proxy string `mapstructure:"proxy"`
}
157
+
158
+ func (c *Channel) GetFullURL() string {
159
+ return c.BaseURL + c.APIPath
160
+ }
161
+
162
+ type LogConfig struct {
163
+ Level string `mapstructure:"level"`
164
+ Format string `mapstructure:"format"`
165
+ Output string `mapstructure:"output"`
166
+ FilePath string `mapstructure:"file_path"`
167
+ Debug DebugConfig `mapstructure:"debug"`
168
+ }
169
+
170
// DebugConfig controls content logging of requests and responses.
type DebugConfig struct {
	// Enabled gates RequestLogger.LogContent; nothing is logged when false.
	Enabled bool `mapstructure:"enabled"`

	// PrintRequest and PrintResponse select which payloads are logged.
	PrintRequest  bool `mapstructure:"print_request"`
	PrintResponse bool `mapstructure:"print_response"`

	// MaxContentLength is the length passed to truncateContent for logged payloads.
	MaxContentLength int `mapstructure:"max_content_length"`
}
176
+
177
// ProxyConfig holds the global proxy settings.
type ProxyConfig struct {
	Enabled       bool   `mapstructure:"enabled"`
	Default       string `mapstructure:"default"`
	AllowInsecure bool   `mapstructure:"allow_insecure"`
}
182
+
183
+ type GlobalConfig struct {
184
+ MaxRetries int `mapstructure:"max_retries"`
185
+ DefaultTimeout int `mapstructure:"default_timeout"`
186
+ ErrorCodes struct {
187
+ RetryOn []int `mapstructure:"retry_on"`
188
+ } `mapstructure:"error_codes"`
189
+ Log LogConfig `mapstructure:"log"`
190
+ Server ServerConfig `mapstructure:"server"`
191
+ Proxy ProxyConfig `mapstructure:"proxy"`
192
+ ConfigPaths []string `mapstructure:"config_paths"`
193
+ Thinking ThinkingConfig `mapstructure:"thinking"`
194
+ }
195
+
196
// ServerConfig holds the listen address and timeouts of the HTTP server.
type ServerConfig struct {
	// Port and Host form the listen address "host:port".
	Port int    `mapstructure:"port"`
	Host string `mapstructure:"host"`

	// ReadTimeout, WriteTimeout and IdleTimeout are expressed in seconds.
	ReadTimeout  int `mapstructure:"read_timeout"`
	WriteTimeout int `mapstructure:"write_timeout"`
	IdleTimeout  int `mapstructure:"idle_timeout"`
}
203
+
204
// ThinkingConfig holds the settings found under the "global.thinking" key.
type ThinkingConfig struct {
	Enabled          bool `mapstructure:"enabled"`
	AddToAllRequests bool `mapstructure:"add_to_all_requests"`
	Timeout          int  `mapstructure:"timeout"`
}
209
+
210
+ // ---------------------- API 相关结构 ----------------------
211
+
212
+ type ChatCompletionRequest struct {
213
+ Model string `json:"model"`
214
+ Messages []ChatCompletionMessage `json:"messages"`
215
+ Temperature float64 `json:"temperature,omitempty"`
216
+ MaxTokens int `json:"max_tokens,omitempty"`
217
+ Stream bool `json:"stream,omitempty"`
218
+ APIKey string `json:"-"` // 内部传递,不序列化
219
+ }
220
+
221
// ChatCompletionMessage is a single chat message.
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`

	// ReasoningContent is left untyped; the non-streaming path handles it
	// as either a string or a JSON object.
	ReasoningContent interface{} `json:"reasoning_content,omitempty"`
}
226
+
227
+ type ChatCompletionResponse struct {
228
+ ID string `json:"id"`
229
+ Object string `json:"object"`
230
+ Created int64 `json:"created"`
231
+ Model string `json:"model"`
232
+ Choices []Choice `json:"choices"`
233
+ Usage Usage `json:"usage"`
234
+ }
235
+
236
+ type Choice struct {
237
+ Index int `json:"index"`
238
+ Message ChatCompletionMessage `json:"message"`
239
+ FinishReason string `json:"finish_reason"`
240
+ }
241
+
242
// Usage carries the token counts of a chat completion response.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
247
+
248
+ // ---------------------- 日志工具 ----------------------
249
+
250
+ type RequestLogger struct {
251
+ RequestID string
252
+ Model string
253
+ StartTime time.Time
254
+ logs []string
255
+ config *Config
256
+ }
257
+
258
+ func NewRequestLogger(config *Config) *RequestLogger {
259
+ return &RequestLogger{
260
+ RequestID: uuid.New().String(),
261
+ StartTime: time.Now(),
262
+ logs: make([]string, 0),
263
+ config: config,
264
+ }
265
+ }
266
+
267
// replaceEnvVars expands every ${NAME} placeholder in value with the content
// of the environment variable NAME.
//
// A variable that is unset or empty expands to the empty string. Text after
// an unterminated "${" is kept as is. The input is scanned once from left to
// right and substituted values are never rescanned, so a variable whose value
// itself contains "${...}" is inserted literally and cannot cause endless
// re-expansion.
func replaceEnvVars(value string) string {
	// Fast path: nothing to expand.
	if !strings.Contains(value, "${") {
		return value
	}

	var out strings.Builder
	rest := value
	for {
		start := strings.Index(rest, "${")
		if start == -1 {
			break
		}
		end := strings.Index(rest[start:], "}")
		if end == -1 {
			// Unterminated placeholder: leave the remainder untouched.
			break
		}
		end += start

		out.WriteString(rest[:start])
		out.WriteString(os.Getenv(rest[start+2 : end]))
		// Continue after the closing brace so the inserted value is not rescanned.
		rest = rest[end+1:]
	}
	out.WriteString(rest)
	return out.String()
}
304
+
305
+ // 处理配置结构体中的环境变量
306
+ func ProcessConfig(config *Config) {
307
+ // 处理 ThinkingServices
308
+ for i := range config.ThinkingServices {
309
+ config.ThinkingServices[i].Name = replaceEnvVars(config.ThinkingServices[i].Name)
310
+ config.ThinkingServices[i].Model = replaceEnvVars(config.ThinkingServices[i].Model)
311
+ config.ThinkingServices[i].BaseURL = replaceEnvVars(config.ThinkingServices[i].BaseURL)
312
+ config.ThinkingServices[i].APIKey = replaceEnvVars(config.ThinkingServices[i].APIKey)
313
+ config.ThinkingServices[i].Proxy = replaceEnvVars(config.ThinkingServices[i].Proxy)
314
+ }
315
+
316
+ // 处理 Channels
317
+ for key, channel := range config.Channels {
318
+ channel.Name = replaceEnvVars(channel.Name)
319
+ channel.BaseURL = replaceEnvVars(channel.BaseURL)
320
+ channel.Proxy = replaceEnvVars(channel.Proxy)
321
+ config.Channels[key] = channel
322
+ }
323
+
324
+ // 如果 Global 配置中也有环境变量,可以在这里添加处理逻辑
325
+ }
326
+
327
+ func (l *RequestLogger) Log(format string, args ...interface{}) {
328
+ msg := fmt.Sprintf(format, args...)
329
+ l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
330
+ log.Printf("[RequestID: %s] %s", l.RequestID, msg)
331
+ }
332
+
333
+ func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
334
+ if !l.config.Global.Log.Debug.Enabled {
335
+ return
336
+ }
337
+ sanitizedContent := sanitizeJSON(content)
338
+ truncatedContent := truncateContent(sanitizedContent, maxLength)
339
+ l.Log("%s Content:\n%s", contentType, truncatedContent)
340
+ }
341
+
342
// truncateContent limits content to at most maxLength bytes, appending
// "... (truncated)" when anything was cut off.
//
// A negative maxLength is treated as 0 instead of panicking. The cut point is
// moved back to the start of a UTF-8 sequence so a multi-byte character is
// never split in half.
func truncateContent(content string, maxLength int) string {
	if maxLength < 0 {
		maxLength = 0
	}
	if len(content) <= maxLength {
		return content
	}

	cut := maxLength
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back past them.
	for cut > 0 && content[cut]&0xC0 == 0x80 {
		cut--
	}
	return content[:cut] + "... (truncated)"
}
348
+
349
// Patterns used by sanitizeJSON, compiled once at package initialisation.
var (
	// sanitizeAPIKeyRe matches an "api_key" member in plain JSON.
	sanitizeAPIKeyRe = regexp.MustCompile(`"api_key":\s*"[^"]*"`)
	// sanitizeEscapedAPIKeyRe matches the same member when the JSON text is
	// itself embedded in a JSON string, i.e. every quote appears as \".
	sanitizeEscapedAPIKeyRe = regexp.MustCompile(`\\"api_key\\":\s*\\"[^"\\]*\\"`)
)

// sanitizeJSON marshals data to JSON and masks the value of every "api_key"
// member with "****".
//
// Callers also pass raw request and response bodies as strings; after
// marshalling, the quotes in such a body are escaped, so the escaped form is
// masked too. If data cannot be marshalled, the fixed text
// "Failed to marshal JSON" is returned.
func sanitizeJSON(data interface{}) string {
	encoded, err := json.Marshal(data)
	if err != nil {
		return "Failed to marshal JSON"
	}
	content := sanitizeAPIKeyRe.ReplaceAllLiteralString(string(encoded), `"api_key":"****"`)
	return sanitizeEscapedAPIKeyRe.ReplaceAllLiteralString(content, `\"api_key\":\"****\"`)
}
359
+
360
// extractRealAPIKey strips the routing prefix from a composite key of the form
// "deep-<channel>-<key>" or "openai-<channel>-<key>" and returns <key>, which
// may itself contain dashes. Any other input is returned unchanged.
func extractRealAPIKey(fullKey string) string {
	segments := strings.SplitN(fullKey, "-", 3)
	if len(segments) < 3 {
		return fullKey
	}
	switch segments[0] {
	case "deep", "openai":
		return segments[2]
	}
	return fullKey
}
367
+
368
// extractChannelID returns the channel segment of a composite key of the form
// "deep-<channel>-..." or "openai-<channel>-...". Keys without such a prefix
// map to the default channel "1".
func extractChannelID(fullKey string) string {
	const defaultChannel = "1"

	segments := strings.SplitN(fullKey, "-", 3)
	if len(segments) < 2 {
		return defaultChannel
	}
	switch segments[0] {
	case "deep", "openai":
		return segments[1]
	default:
		return defaultChannel
	}
}
375
+
376
// logAPIKey returns a form of key that is safe to log: the first and last four
// bytes joined by "...". Keys of eight bytes or fewer are hidden entirely as
// "****".
func logAPIKey(key string) string {
	const visible = 4
	if n := len(key); n > 2*visible {
		return key[:visible] + "..." + key[n-visible:]
	}
	return "****"
}
382
+
383
+ // ---------------------- Server 结构 ----------------------
384
+
385
+ type Server struct {
386
+ config *Config
387
+ srv *http.Server
388
+ }
389
+
390
// randMu serializes access to randGen.
var randMu sync.Mutex

// randGen is the random source for weighted service selection, seeded with the
// start-up time in nanoseconds.
var randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
394
+
395
+ func NewServer(config *Config) *Server {
396
+ return &Server{
397
+ config: config,
398
+ }
399
+ }
400
+
401
+ func (s *Server) Start() error {
402
+ mux := http.NewServeMux()
403
+ mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
404
+ mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
405
+ mux.HandleFunc("/health", s.handleHealth)
406
+
407
+ s.srv = &http.Server{
408
+ Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
409
+ Handler: mux,
410
+ ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
411
+ WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
412
+ IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
413
+ }
414
+
415
+ log.Printf("Server starting on %s\n", s.srv.Addr)
416
+ return s.srv.ListenAndServe()
417
+ }
418
+
419
+ func (s *Server) Shutdown(ctx context.Context) error {
420
+ return s.srv.Shutdown(ctx)
421
+ }
422
+
423
+ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
424
+ w.WriteHeader(http.StatusOK)
425
+ json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
426
+ }
427
+
428
+ func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
429
+ logger := NewRequestLogger(s.config)
430
+
431
+ fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
432
+ apiKey := extractRealAPIKey(fullAPIKey)
433
+ channelID := extractChannelID(fullAPIKey)
434
+
435
+ logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
436
+ logger.Log("Extracted channel ID: %s", channelID)
437
+ logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))
438
+
439
+ targetChannel, ok := s.config.Channels[channelID]
440
+ if !ok {
441
+ http.Error(w, "Invalid channel", http.StatusBadRequest)
442
+ return
443
+ }
444
+
445
+ if r.URL.Path == "/v1/models" {
446
+ if r.Method != http.MethodGet {
447
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
448
+ return
449
+ }
450
+ req := &ChatCompletionRequest{APIKey: apiKey}
451
+ s.forwardModelsRequest(w, r.Context(), req, targetChannel)
452
+ return
453
+ }
454
+
455
+ if r.Method != http.MethodPost {
456
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
457
+ return
458
+ }
459
+
460
+ body, err := io.ReadAll(r.Body)
461
+ if err != nil {
462
+ logger.Log("Error reading request body: %v", err)
463
+ http.Error(w, "Failed to read request", http.StatusBadRequest)
464
+ return
465
+ }
466
+ r.Body.Close()
467
+ r.Body = io.NopCloser(bytes.NewBuffer(body))
468
+
469
+ if s.config.Global.Log.Debug.PrintRequest {
470
+ logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
471
+ }
472
+
473
+ var req ChatCompletionRequest
474
+ if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
475
+ http.Error(w, "Invalid request body", http.StatusBadRequest)
476
+ return
477
+ }
478
+ req.APIKey = apiKey
479
+
480
+ thinkingService := s.getWeightedRandomThinkingService()
481
+ logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))
482
+
483
+ if req.Stream {
484
+ handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
485
+ if err != nil {
486
+ http.Error(w, "Streaming not supported", http.StatusInternalServerError)
487
+ return
488
+ }
489
+ if err := handler.HandleRequest(r.Context(), &req); err != nil {
490
+ logger.Log("Stream handler error: %v", err)
491
+ }
492
+ } else {
493
+ thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
494
+ if err != nil {
495
+ logger.Log("Error processing thinking content: %v", err)
496
+ http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
497
+ return
498
+ }
499
+ enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
500
+ s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
501
+ }
502
+ }
503
+
504
+ func (s *Server) getWeightedRandomThinkingService() ThinkingService {
505
+ thinkingServices := s.config.ThinkingServices
506
+ if len(thinkingServices) == 0 {
507
+ return ThinkingService{}
508
+ }
509
+ totalWeight := 0
510
+ for _, svc := range thinkingServices {
511
+ totalWeight += svc.Weight
512
+ }
513
+ if totalWeight <= 0 {
514
+ log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
515
+ return thinkingServices[0]
516
+ }
517
+ randMu.Lock()
518
+ randNum := randGen.Intn(totalWeight)
519
+ randMu.Unlock()
520
+ currentSum := 0
521
+ for _, svc := range thinkingServices {
522
+ currentSum += svc.Weight
523
+ if randNum < currentSum {
524
+ return svc
525
+ }
526
+ }
527
+ return thinkingServices[0]
528
+ }
529
+
530
+ // ---------------------- 非流式处理思考服务 ----------------------
531
+
532
// ThinkingResponse is the result of a non-streaming thinking-service call.
type ThinkingResponse struct {
	// Content is a fixed summary sentence set by processThinkingContent.
	Content string
	// ReasoningContent is the reasoning text taken from the service response.
	ReasoningContent string
}
536
+
537
+ func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
538
+ logger := NewRequestLogger(s.config)
539
+ log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)
540
+
541
+ thinkingReq := *req
542
+ thinkingReq.Model = svc.Model
543
+ thinkingReq.APIKey = svc.APIKey
544
+
545
+ var systemPrompt string
546
+ if svc.Mode == "full" {
547
+ systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
548
+ } else {
549
+ systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
550
+ }
551
+ thinkingReq.Messages = append([]ChatCompletionMessage{
552
+ {Role: "system", Content: systemPrompt},
553
+ }, thinkingReq.Messages...)
554
+
555
+ temp := 0.7
556
+ if svc.Temperature != nil {
557
+ temp = *svc.Temperature
558
+ }
559
+ payload := map[string]interface{}{
560
+ "model": svc.Model,
561
+ "messages": thinkingReq.Messages,
562
+ "stream": false,
563
+ "temperature": temp,
564
+ }
565
+ if isValidReasoningEffort(svc.ReasoningEffort) {
566
+ payload["reasoning_effort"] = svc.ReasoningEffort
567
+ }
568
+ if isValidReasoningFormat(svc.ReasoningFormat) {
569
+ payload["reasoning_format"] = svc.ReasoningFormat
570
+ }
571
+
572
+ if s.config.Global.Log.Debug.PrintRequest {
573
+ logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
574
+ }
575
+
576
+ jsonData, err := json.Marshal(payload)
577
+ if err != nil {
578
+ return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
579
+ }
580
+ client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
581
+ if err != nil {
582
+ return nil, fmt.Errorf("failed to create HTTP client: %v", err)
583
+ }
584
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
585
+ if err != nil {
586
+ return nil, fmt.Errorf("failed to create request: %v", err)
587
+ }
588
+ httpReq.Header.Set("Content-Type", "application/json")
589
+ httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)
590
+
591
+ resp, err := client.Do(httpReq)
592
+ if err != nil {
593
+ return nil, fmt.Errorf("failed to send thinking request: %v", err)
594
+ }
595
+ defer resp.Body.Close()
596
+
597
+ respBody, err := io.ReadAll(resp.Body)
598
+ if err != nil {
599
+ return nil, fmt.Errorf("failed to read response body: %v", err)
600
+ }
601
+ if s.config.Global.Log.Debug.PrintResponse {
602
+ logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
603
+ }
604
+ if resp.StatusCode != http.StatusOK {
605
+ return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
606
+ }
607
+
608
+ var thinkingResp ChatCompletionResponse
609
+ if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
610
+ return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
611
+ }
612
+ if len(thinkingResp.Choices) == 0 {
613
+ return nil, fmt.Errorf("thinking service returned no choices")
614
+ }
615
+
616
+ result := &ThinkingResponse{}
617
+ choice := thinkingResp.Choices[0]
618
+
619
+ if svc.Mode == "full" {
620
+ result.ReasoningContent = choice.Message.Content
621
+ result.Content = "Based on the above detailed analysis."
622
+ } else {
623
+ if choice.Message.ReasoningContent != nil {
624
+ switch v := choice.Message.ReasoningContent.(type) {
625
+ case string:
626
+ result.ReasoningContent = v
627
+ case map[string]interface{}:
628
+ if j, err := json.Marshal(v); err == nil {
629
+ result.ReasoningContent = string(j)
630
+ }
631
+ }
632
+ }
633
+ if result.ReasoningContent == "" {
634
+ result.ReasoningContent = choice.Message.Content
635
+ }
636
+ result.Content = "Based on the above reasoning."
637
+ }
638
+ return result, nil
639
+ }
640
+
641
+ func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
642
+ newReq := *originalReq
643
+ var systemPrompt string
644
+ if svc.Mode == "full" {
645
+ systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
646
+ %s
647
+
648
+ Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
649
+ } else {
650
+ systemPrompt = fmt.Sprintf(`Previous thinking process:
651
+ %s
652
+ Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
653
+ }
654
+ newReq.Messages = append([]ChatCompletionMessage{
655
+ {Role: "system", Content: systemPrompt},
656
+ }, newReq.Messages...)
657
+ return &newReq
658
+ }
659
+
660
+ func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
661
+ logger := NewRequestLogger(s.config)
662
+ if s.config.Global.Log.Debug.PrintRequest {
663
+ logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
664
+ }
665
+ jsonData, err := json.Marshal(req)
666
+ if err != nil {
667
+ http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
668
+ return
669
+ }
670
+
671
+ client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
672
+ if err != nil {
673
+ http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
674
+ return
675
+ }
676
+
677
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
678
+ if err != nil {
679
+ http.Error(w, "Failed to create request", http.StatusInternalServerError)
680
+ return
681
+ }
682
+ httpReq.Header.Set("Content-Type", "application/json")
683
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
684
+
685
+ resp, err := client.Do(httpReq)
686
+ if err != nil {
687
+ http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
688
+ return
689
+ }
690
+ defer resp.Body.Close()
691
+
692
+ respBody, err := io.ReadAll(resp.Body)
693
+ if err != nil {
694
+ http.Error(w, "Failed to read response", http.StatusInternalServerError)
695
+ return
696
+ }
697
+ if s.config.Global.Log.Debug.PrintResponse {
698
+ logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
699
+ }
700
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
701
+ http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
702
+ return
703
+ }
704
+
705
+ for k, vals := range resp.Header {
706
+ for _, v := range vals {
707
+ w.Header().Add(k, v)
708
+ }
709
+ }
710
+ w.WriteHeader(resp.StatusCode)
711
+ w.Write(respBody)
712
+ }
713
+
714
+ func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
715
+ logger := NewRequestLogger(s.config)
716
+ if s.config.Global.Log.Debug.PrintRequest {
717
+ logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
718
+ }
719
+ fullChatURL := targetChannel.GetFullURL()
720
+ parsedURL, err := url.Parse(fullChatURL)
721
+ if err != nil {
722
+ http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
723
+ return
724
+ }
725
+ baseURL := parsedURL.Scheme + "://" + parsedURL.Host
726
+ modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"
727
+
728
+ client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
729
+ if err != nil {
730
+ http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
731
+ return
732
+ }
733
+
734
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
735
+ if err != nil {
736
+ http.Error(w, "Failed to create request", http.StatusInternalServerError)
737
+ return
738
+ }
739
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
740
+
741
+ resp, err := client.Do(httpReq)
742
+ if err != nil {
743
+ http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
744
+ return
745
+ }
746
+ defer resp.Body.Close()
747
+
748
+ respBody, err := io.ReadAll(resp.Body)
749
+ if err != nil {
750
+ http.Error(w, "Failed to read response", http.StatusInternalServerError)
751
+ return
752
+ }
753
+ if s.config.Global.Log.Debug.PrintResponse {
754
+ logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
755
+ }
756
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
757
+ http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
758
+ return
759
+ }
760
+
761
+ for k, vals := range resp.Header {
762
+ for _, v := range vals {
763
+ w.Header().Add(k, v)
764
+ }
765
+ }
766
+ w.WriteHeader(resp.StatusCode)
767
+ w.Write(respBody)
768
+ }
769
+
770
+ // ---------------------- 流式处理 ----------------------
771
+
772
// collectedReasoningBuffer accumulates the reasoning text received from the
// thinking service (standard mode collects only reasoning_content; full mode
// collects everything).
type collectedReasoningBuffer struct {
	// builder holds the text collected so far.
	builder strings.Builder
	// mode records the thinking service mode this buffer was created for.
	mode string
}

// append adds text to the end of the collected reasoning.
func (cb *collectedReasoningBuffer) append(text string) {
	_, _ = cb.builder.WriteString(text) // strings.Builder writes never fail
}

// get returns everything collected so far.
func (cb *collectedReasoningBuffer) get() string {
	collected := cb.builder.String()
	return collected
}
785
+
786
+ type StreamHandler struct {
787
+ thinkingService ThinkingService
788
+ targetChannel Channel
789
+ writer http.ResponseWriter
790
+ flusher http.Flusher
791
+ config *Config
792
+ }
793
+
794
+ func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
795
+ flusher, ok := w.(http.Flusher)
796
+ if !ok {
797
+ return nil, fmt.Errorf("streaming not supported by this writer")
798
+ }
799
+ return &StreamHandler{
800
+ thinkingService: thinkingService,
801
+ targetChannel: targetChannel,
802
+ writer: w,
803
+ flusher: flusher,
804
+ config: config,
805
+ }, nil
806
+ }
807
+
808
+ // HandleRequest 分两阶段:先流式接收思考服务 SSE 并转发给客户端(只转发包含 reasoning_content 的 chunk,标准模式遇到纯 content 则中断);
809
+ // 然后使用收集到的 reasoning_content 构造最终请求,发起目标 Channel 的 SSE 并转发给客户端。
810
+ func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
811
+ h.writer.Header().Set("Content-Type", "text/event-stream")
812
+ h.writer.Header().Set("Cache-Control", "no-cache")
813
+ h.writer.Header().Set("Connection", "keep-alive")
814
+
815
+ logger := NewRequestLogger(h.config)
816
+ reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}
817
+
818
+ // 阶段一:流式接收思考服务 SSE
819
+ if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
820
+ return err
821
+ }
822
+
823
+ // 阶段二:构造最终请求,并流式转发目标 Channel SSE
824
+ finalReq := h.prepareFinalRequest(req, reasonBuf.get())
825
+ if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
826
+ return err
827
+ }
828
+ return nil
829
+ }
830
+
831
+ // streamThinkingService 连接思考服务 SSE,原样转发包含 reasoning_content 的 SSE chunk给客户端,同时收集 reasoning_content。
832
+ // 在标准模式下,一旦遇到只返回 non-empty 的 content(且 reasoning_content 为空),则立即中断,不转发该 chunk。
833
+ func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
834
+ thinkingReq := *req
835
+ thinkingReq.Model = h.thinkingService.Model
836
+ thinkingReq.APIKey = h.thinkingService.APIKey
837
+
838
+ var systemPrompt string
839
+ if h.thinkingService.Mode == "full" {
840
+ systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
841
+ } else {
842
+ systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
843
+ }
844
+ messages := append([]ChatCompletionMessage{
845
+ {Role: "system", Content: systemPrompt},
846
+ }, thinkingReq.Messages...)
847
+
848
+ temp := 0.7
849
+ if h.thinkingService.Temperature != nil {
850
+ temp = *h.thinkingService.Temperature
851
+ }
852
+ bodyMap := map[string]interface{}{
853
+ "model": thinkingReq.Model,
854
+ "messages": messages,
855
+ "temperature": temp,
856
+ "stream": true,
857
+ }
858
+ if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
859
+ bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
860
+ }
861
+ if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
862
+ bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
863
+ }
864
+
865
+ jsonData, err := json.Marshal(bodyMap)
866
+ if err != nil {
867
+ return err
868
+ }
869
+
870
+ client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
871
+ if err != nil {
872
+ return err
873
+ }
874
+
875
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
876
+ if err != nil {
877
+ return err
878
+ }
879
+ httpReq.Header.Set("Content-Type", "application/json")
880
+ httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)
881
+
882
+ log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
883
+ resp, err := client.Do(httpReq)
884
+ if err != nil {
885
+ return err
886
+ }
887
+ defer resp.Body.Close()
888
+
889
+ if resp.StatusCode != http.StatusOK {
890
+ b, _ := io.ReadAll(resp.Body)
891
+ return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
892
+ }
893
+
894
+ reader := bufio.NewReader(resp.Body)
895
+ var lastLine string
896
+ // 标识是否需中断思考 SSE(标准模式下)
897
+ forceStop := false
898
+
899
+ for {
900
+ select {
901
+ case <-ctx.Done():
902
+ return ctx.Err()
903
+ default:
904
+ }
905
+ line, err := reader.ReadString('\n')
906
+ if err != nil {
907
+ if err == io.EOF {
908
+ break
909
+ }
910
+ return err
911
+ }
912
+ line = strings.TrimSpace(line)
913
+ if line == "" || line == lastLine {
914
+ continue
915
+ }
916
+ lastLine = line
917
+
918
+ if !strings.HasPrefix(line, "data: ") {
919
+ continue
920
+ }
921
+ dataPart := strings.TrimPrefix(line, "data: ")
922
+ if dataPart == "[DONE]" {
923
+ // 思考服务结束:直接中断(不发送特殊标记)
924
+ break
925
+ }
926
+
927
+ var chunk struct {
928
+ Choices []struct {
929
+ Delta struct {
930
+ Content string `json:"content,omitempty"`
931
+ ReasoningContent string `json:"reasoning_content,omitempty"`
932
+ } `json:"delta"`
933
+ FinishReason *string `json:"finish_reason,omitempty"`
934
+ } `json:"choices"`
935
+ }
936
+ if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
937
+ // 如果解析失败,直接原样转发
938
+ h.writer.Write([]byte("data: " + dataPart + "\n\n"))
939
+ h.flusher.Flush()
940
+ continue
941
+ }
942
+
943
+ if len(chunk.Choices) > 0 {
944
+ c := chunk.Choices[0]
945
+ if h.config.Global.Log.Debug.PrintResponse {
946
+ logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
947
+ }
948
+
949
+ if h.thinkingService.Mode == "full" {
950
+ // full 模式:收集所有内容
951
+ if c.Delta.ReasoningContent != "" {
952
+ reasonBuf.append(c.Delta.ReasoningContent)
953
+ }
954
+ if c.Delta.Content != "" {
955
+ reasonBuf.append(c.Delta.Content)
956
+ }
957
+ // 原样转发整 chunk
958
+ forwardLine := "data: " + dataPart + "\n\n"
959
+ h.writer.Write([]byte(forwardLine))
960
+ h.flusher.Flush()
961
+ } else {
962
+ // standard 模式:只收集 reasoning_content
963
+ if c.Delta.ReasoningContent != "" {
964
+ reasonBuf.append(c.Delta.ReasoningContent)
965
+ // 转发该 chunk
966
+ forwardLine := "data: " + dataPart + "\n\n"
967
+ h.writer.Write([]byte(forwardLine))
968
+ h.flusher.Flush()
969
+ }
970
+ // 如果遇到 chunk 中只有 content(且 reasoning_content 为空),认为思考链结束,不转发该 chunk
971
+ if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
972
+ forceStop = true
973
+ }
974
+ }
975
+
976
+ // 如果 finishReason 非空,也认为结束
977
+ if c.FinishReason != nil && *c.FinishReason != "" {
978
+ forceStop = true
979
+ }
980
+ }
981
+
982
+ if forceStop {
983
+ break
984
+ }
985
+ }
986
+ // 读空剩余数据
987
+ io.Copy(io.Discard, reader)
988
+ return nil
989
+ }
990
+
991
+ func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
992
+ req := *originalReq
993
+ var systemPrompt string
994
+ if h.thinkingService.Mode == "full" {
995
+ systemPrompt = fmt.Sprintf(
996
+ `Consider the following detailed analysis (not shown to user):
997
+ %s
998
+
999
+ Provide a clear, concise response that incorporates insights from this analysis.`,
1000
+ reasoningCollected,
1001
+ )
1002
+ } else {
1003
+ systemPrompt = fmt.Sprintf(
1004
+ `Previous thinking process:
1005
+ %s
1006
+ Please consider the above thinking process in your response.`,
1007
+ reasoningCollected,
1008
+ )
1009
+ }
1010
+ req.Messages = append([]ChatCompletionMessage{
1011
+ {Role: "system", Content: systemPrompt},
1012
+ }, req.Messages...)
1013
+ return &req
1014
+ }
1015
+
1016
+ func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
1017
+ if h.config.Global.Log.Debug.PrintRequest {
1018
+ logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
1019
+ }
1020
+
1021
+ jsonData, err := json.Marshal(req)
1022
+ if err != nil {
1023
+ return err
1024
+ }
1025
+ client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
1026
+ if err != nil {
1027
+ return err
1028
+ }
1029
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
1030
+ if err != nil {
1031
+ return err
1032
+ }
1033
+ httpReq.Header.Set("Content-Type", "application/json")
1034
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
1035
+
1036
+ resp, err := client.Do(httpReq)
1037
+ if err != nil {
1038
+ return err
1039
+ }
1040
+ defer resp.Body.Close()
1041
+
1042
+ if resp.StatusCode != http.StatusOK {
1043
+ b, _ := io.ReadAll(resp.Body)
1044
+ return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
1045
+ }
1046
+
1047
+ reader := bufio.NewReader(resp.Body)
1048
+ var lastLine string
1049
+
1050
+ for {
1051
+ select {
1052
+ case <-ctx.Done():
1053
+ return ctx.Err()
1054
+ default:
1055
+ }
1056
+ line, err := reader.ReadString('\n')
1057
+ if err != nil {
1058
+ if err == io.EOF {
1059
+ break
1060
+ }
1061
+ return err
1062
+ }
1063
+ line = strings.TrimSpace(line)
1064
+ if line == "" || line == lastLine {
1065
+ continue
1066
+ }
1067
+ lastLine = line
1068
+
1069
+ if !strings.HasPrefix(line, "data: ") {
1070
+ continue
1071
+ }
1072
+ data := strings.TrimPrefix(line, "data: ")
1073
+ if data == "[DONE]" {
1074
+ doneLine := "data: [DONE]\n\n"
1075
+ h.writer.Write([]byte(doneLine))
1076
+ h.flusher.Flush()
1077
+ break
1078
+ }
1079
+ forwardLine := "data: " + data + "\n\n"
1080
+ h.writer.Write([]byte(forwardLine))
1081
+ h.flusher.Flush()
1082
+
1083
+ if h.config.Global.Log.Debug.PrintResponse {
1084
+ logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
1085
+ }
1086
+ }
1087
+ io.Copy(io.Discard, reader)
1088
+ return nil
1089
+ }
1090
+
1091
+ // ---------------------- HTTP Client 工具 ----------------------
1092
+
1093
+ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
1094
+ transport := &http.Transport{
1095
+ DialContext: (&net.Dialer{
1096
+ Timeout: 30 * time.Second,
1097
+ KeepAlive: 30 * time.Second,
1098
+ }).DialContext,
1099
+ ForceAttemptHTTP2: true,
1100
+ MaxIdleConns: 100,
1101
+ IdleConnTimeout: 90 * time.Second,
1102
+ TLSHandshakeTimeout: 10 * time.Second,
1103
+ ExpectContinueTimeout: 1 * time.Second,
1104
+ }
1105
+ if proxyURL != "" {
1106
+ parsedURL, err := url.Parse(proxyURL)
1107
+ if err != nil {
1108
+ return nil, fmt.Errorf("invalid proxy URL: %v", err)
1109
+ }
1110
+ switch parsedURL.Scheme {
1111
+ case "http", "https":
1112
+ transport.Proxy = http.ProxyURL(parsedURL)
1113
+ case "socks5":
1114
+ dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
1115
+ if err != nil {
1116
+ return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
1117
+ }
1118
+ transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1119
+ return dialer.Dial(network, addr)
1120
+ }
1121
+ default:
1122
+ return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
1123
+ }
1124
+ }
1125
+ return &http.Client{
1126
+ Transport: transport,
1127
+ Timeout: timeout,
1128
+ }, nil
1129
+ }
1130
+
1131
// maskSensitiveHeaders returns a new header map intended for logging. Any key
// that equals "authorization" after lower-casing has its values replaced by the
// single placeholder "Bearer ****"; every other key maps to the same value
// slice as in headers (the slices are shared, not copied).
func maskSensitiveHeaders(headers http.Header) http.Header {
	safe := make(http.Header)
	for name := range headers {
		if strings.ToLower(name) != "authorization" {
			safe[name] = headers[name]
			continue
		}
		safe[name] = []string{"Bearer ****"}
	}
	return safe
}
1142
+
1143
// isValidReasoningEffort reports whether effort, compared case-insensitively
// via lower-casing, is one of "low", "medium" or "high".
func isValidReasoningEffort(effort string) bool {
	normalized := strings.ToLower(effort)
	return normalized == "low" || normalized == "medium" || normalized == "high"
}
1150
+
1151
// isValidReasoningFormat reports whether format, compared case-insensitively
// via lower-casing, is one of "parsed", "raw" or "hidden".
func isValidReasoningFormat(format string) bool {
	normalized := strings.ToLower(format)
	return normalized == "parsed" || normalized == "raw" || normalized == "hidden"
}
1158
+
1159
+ // ---------------------- 配置加载与验证 ----------------------
1160
+
1161
+ func loadConfig() (*Config, error) {
1162
+ var configFile string
1163
+ flag.StringVar(&configFile, "config", "", "path to config file")
1164
+ flag.Parse()
1165
+ viper.SetConfigType("yaml")
1166
+
1167
+ if configFile != "" {
1168
+ viper.SetConfigFile(configFile)
1169
+ } else {
1170
+ ex, err := os.Executable()
1171
+ if err != nil {
1172
+ return nil, err
1173
+ }
1174
+ exePath := filepath.Dir(ex)
1175
+ defaultPaths := []string{
1176
+ filepath.Join(exePath, "config.yaml"),
1177
+ filepath.Join(exePath, "conf", "config.yaml"),
1178
+ "./config.yaml",
1179
+ "./conf/config.yaml",
1180
+ }
1181
+ if os.PathSeparator == '\\' {
1182
+ programData := os.Getenv("PROGRAMDATA")
1183
+ if programData != "" {
1184
+ defaultPaths = append(defaultPaths, filepath.Join(programData, "DeepAI", "config.yaml"))
1185
+ }
1186
+ } else {
1187
+ defaultPaths = append(defaultPaths, "/etc/deepai/config.yaml")
1188
+ }
1189
+ for _, p := range defaultPaths {
1190
+ viper.AddConfigPath(filepath.Dir(p))
1191
+ if strings.Contains(p, ".yaml") {
1192
+ viper.SetConfigName(strings.TrimSuffix(filepath.Base(p), ".yaml"))
1193
+ }
1194
+ }
1195
+ }
1196
+
1197
+ if err := viper.ReadInConfig(); err != nil {
1198
+ return nil, fmt.Errorf("failed to read config: %v", err)
1199
+ }
1200
+
1201
+ var cfg Config
1202
+ if err := viper.Unmarshal(&cfg); err != nil {
1203
+ return nil, fmt.Errorf("unmarshal config error: %v", err)
1204
+ }
1205
+ if err := validateConfig(&cfg); err != nil {
1206
+ return nil, fmt.Errorf("config validation error: %v", err)
1207
+ }
1208
+ return &cfg, nil
1209
+ }
1210
+
1211
+ func validateConfig(config *Config) error {
1212
+ if len(config.ThinkingServices) == 0 {
1213
+ return fmt.Errorf("no thinking services configured")
1214
+ }
1215
+ if len(config.Channels) == 0 {
1216
+ return fmt.Errorf("no channels configured")
1217
+ }
1218
+ for i, svc := range config.ThinkingServices {
1219
+ if svc.BaseURL == "" {
1220
+ return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
1221
+ }
1222
+ if svc.APIKey == "" {
1223
+ return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
1224
+ }
1225
+ if svc.Timeout <= 0 {
1226
+ return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
1227
+ }
1228
+ if svc.Model == "" {
1229
+ return fmt.Errorf("thinking service %s has empty model", svc.Name)
1230
+ }
1231
+ if svc.Mode == "" {
1232
+ config.ThinkingServices[i].Mode = "standard"
1233
+ } else if svc.Mode != "standard" && svc.Mode != "full" {
1234
+ return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
1235
+ }
1236
+ }
1237
+ return nil
1238
+ }
1239
+
1240
+ func main() {
1241
+ log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
1242
+
1243
+ cfg, err := loadConfig()
1244
+ if err != nil {
1245
+ log.Fatalf("Failed to load config: %v", err)
1246
+ }
1247
+ log.Printf("Using config file: %s", viper.ConfigFileUsed())
1248
+
1249
+ server := NewServer(cfg)
1250
+
1251
+ done := make(chan os.Signal, 1)
1252
+ signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
1253
+
1254
+ go func() {
1255
+ if err := server.Start(); err != nil && err != http.ErrServerClosed {
1256
+ log.Fatalf("start server error: %v", err)
1257
+ }
1258
+ }()
1259
+ log.Printf("Server started successfully")
1260
+
1261
+ <-done
1262
+ log.Print("Server stopping...")
1263
+
1264
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1265
+ defer cancel()
1266
+
1267
+ if err := server.Shutdown(ctx); err != nil {
1268
+ log.Printf("Server forced to shutdown: %v", err)
1269
+ }
1270
+ log.Print("Server stopped")
1271
+ }