nbugs committed on
Commit 08d5b68 · verified · 1 parent: 3741176

Upload main.go

Files changed (1)
  1. main.go +1337 -0
main.go ADDED
@@ -0,0 +1,1337 @@
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "bytes"
6
+ "context"
7
+ "encoding/json"
8
+ "flag"
9
+ "fmt"
10
+ "io"
11
+ "log"
12
+ "math/rand"
13
+ "net"
14
+ "net/http"
15
+ "net/url"
16
+ "os"
17
+ "os/signal"
18
+ "path/filepath"
19
+ "regexp"
20
+ "strings"
21
+ "sync"
22
+ "syscall"
23
+ "time"
24
+
25
+ "github.com/google/uuid"
26
+ "github.com/spf13/viper"
27
+ "golang.org/x/net/proxy"
28
+ )
29
+
30
+ // ---------------------- Configuration structures ----------------------
31
+
32
+ type Config struct {
33
+ ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
34
+ Channels map[string]Channel `mapstructure:"channels"`
35
+ Global GlobalConfig `mapstructure:"global"`
36
+ }
37
+
38
+ type ThinkingService struct {
39
+ ID int `mapstructure:"id"`
40
+ Name string `mapstructure:"name"`
41
+ Model string `mapstructure:"model"`
42
+ BaseURL string `mapstructure:"base_url"`
43
+ APIPath string `mapstructure:"api_path"`
44
+ APIKey string `mapstructure:"api_key"`
45
+ Timeout int `mapstructure:"timeout"`
46
+ Retry int `mapstructure:"retry"`
47
+ Weight int `mapstructure:"weight"`
48
+ Proxy string `mapstructure:"proxy"`
49
+ Mode string `mapstructure:"mode"` // "standard" or "full"
50
+ ReasoningEffort string `mapstructure:"reasoning_effort"`
51
+ ReasoningFormat string `mapstructure:"reasoning_format"`
52
+ Temperature *float64 `mapstructure:"temperature"`
53
+ ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"` // whether standard mode stops the thinking stream as soon as plain content appears
54
+ }
55
+
58
+ // loadConfigFromEnv builds a Config from environment variables.
59
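+ // It reads THINKING_SERVICE_NAME, THINKING_SERVICE_MODE, THINKING_SERVICE_MODEL,
+ // THINKING_SERVICE_BASE_URL and THINKING_SERVICE_API_KEY for the thinking service,
+ // and CHANNEL_NAME and CHANNEL_BASE_URL for the single default channel; all other
+ // fields use the hard-coded defaults below.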
+ func loadConfigFromEnv() *Config {
60
+ cfg := &Config{
61
+ ThinkingServices: []ThinkingService{
62
+ {
63
+ ID: 1,
64
+ Name: os.Getenv("THINKING_SERVICE_NAME"),
65
+ Mode: getEnvOrDefault("THINKING_SERVICE_MODE", "standard"),
66
+ Model: os.Getenv("THINKING_SERVICE_MODEL"),
67
+ BaseURL: os.Getenv("THINKING_SERVICE_BASE_URL"),
68
+ APIPath: "/v1/chat/completions",
69
+ APIKey: os.Getenv("THINKING_SERVICE_API_KEY"),
70
+ Timeout: 600,
71
+ Weight: 100,
72
+ Proxy: "",
73
+ ReasoningEffort: "high",
74
+ ReasoningFormat: "parsed",
75
+ Temperature: getFloatPtr(0.8),
76
+ ForceStopDeepThinking: false,
77
+ },
78
+ },
79
+ Channels: map[string]Channel{
80
+ "1": {
81
+ Name: os.Getenv("CHANNEL_NAME"),
82
+ BaseURL: os.Getenv("CHANNEL_BASE_URL"),
83
+ APIPath: "/v1/chat/completions",
84
+ Timeout: 600,
85
+ Proxy: "",
86
+ },
87
+ },
88
+ Global: GlobalConfig{
89
+ Log: LogConfig{
90
+ Level: "error",
91
+ Format: "json",
92
+ Output: "console",
93
+ Debug: DebugConfig{
94
+ Enabled: true,
95
+ PrintRequest: true,
96
+ PrintResponse: true,
97
+ MaxContentLength: 1000,
98
+ },
99
+ },
100
+ Server: ServerConfig{
101
+ Port: 8888,
102
+ Host: "0.0.0.0",
103
+ ReadTimeout: 600,
104
+ WriteTimeout: 600,
105
+ IdleTimeout: 600,
106
+ },
107
+ },
108
+ }
109
+ return cfg
110
+ }
111
+
112
+ // getEnvOrDefault returns the value of the environment variable key, or defaultValue if it is unset or empty.
113
+ func getEnvOrDefault(key, defaultValue string) string {
114
+ if value := os.Getenv(key); value != "" {
115
+ return value
116
+ }
117
+ return defaultValue
118
+ }
119
+
120
+ // getFloatPtr returns a pointer to the given float64 value.
121
+ func getFloatPtr(value float64) *float64 {
122
+ return &value
123
+ }
124
+
146
+ // loadConfigFromFile loads configuration from a YAML file discovered via viper.
147
+ func loadConfigFromFile() (*Config, error) {
148
+ var configFile string
149
+ flag.StringVar(&configFile, "config", "", "path to config file")
150
+ flag.Parse()
151
+ viper.SetConfigType("yaml")
152
+
153
+ if configFile != "" {
154
+ viper.SetConfigFile(configFile)
155
+ } else {
156
+ ex, err := os.Executable()
157
+ if err != nil {
158
+ return nil, err
159
+ }
160
+ exePath := filepath.Dir(ex)
161
+ defaultPaths := []string{
162
+ filepath.Join(exePath, "config.yaml"),
163
+ filepath.Join(exePath, "conf", "config.yaml"),
164
+ "./config.yaml",
165
+ "./conf/config.yaml",
166
+ }
167
+ if os.PathSeparator == '\\' {
168
+ programData := os.Getenv("PROGRAMDATA")
169
+ if programData != "" {
170
+ defaultPaths = append(defaultPaths, filepath.Join(programData, "DeepAI", "config.yaml"))
171
+ }
172
+ } else {
173
+ defaultPaths = append(defaultPaths, "/etc/deepai/config.yaml")
174
+ }
175
+ for _, p := range defaultPaths {
176
+ viper.AddConfigPath(filepath.Dir(p))
177
+ if strings.Contains(p, ".yaml") {
178
+ viper.SetConfigName(strings.TrimSuffix(filepath.Base(p), ".yaml"))
179
+ }
180
+ }
181
+ }
182
+
183
+ if err := viper.ReadInConfig(); err != nil {
184
+ return nil, fmt.Errorf("failed to read config: %v", err)
185
+ }
186
+
187
+ var cfg Config
188
+ if err := viper.Unmarshal(&cfg); err != nil {
189
+ return nil, fmt.Errorf("unmarshal config error: %v", err)
190
+ }
191
+ if err := validateConfig(&cfg); err != nil {
192
+ return nil, fmt.Errorf("config validation error: %v", err)
193
+ }
194
+ return &cfg, nil
195
+ }
196
+
197
+ // loadConfig prefers environment-variable configuration and falls back to a config file.
198
+ func loadConfig() (*Config, error) {
199
+ // Prefer environment variables when USE_ENV_CONFIG=true.
200
+ if os.Getenv("USE_ENV_CONFIG") == "true" {
201
+ cfg := loadConfigFromEnv()
202
+ if err := validateConfig(cfg); err != nil {
203
+ return nil, fmt.Errorf("环境变量配置验证错误: %v", err)
204
+ }
205
+ return cfg, nil
206
+ }
207
+
208
+ // Otherwise fall back to the file-based loader.
209
+ return loadConfigFromFile()
210
+ }
211
+
212
+ func (s *ThinkingService) GetFullURL() string {
213
+ return s.BaseURL + s.APIPath
214
+ }
215
+
216
+ type Channel struct {
217
+ Name string `mapstructure:"name"`
218
+ BaseURL string `mapstructure:"base_url"`
219
+ APIPath string `mapstructure:"api_path"`
220
+ Timeout int `mapstructure:"timeout"`
221
+ Proxy string `mapstructure:"proxy"`
222
+ }
223
+
224
+ func (c *Channel) GetFullURL() string {
225
+ return c.BaseURL + c.APIPath
226
+ }
227
+
228
+ type LogConfig struct {
229
+ Level string `mapstructure:"level"`
230
+ Format string `mapstructure:"format"`
231
+ Output string `mapstructure:"output"`
232
+ FilePath string `mapstructure:"file_path"`
233
+ Debug DebugConfig `mapstructure:"debug"`
234
+ }
235
+
236
+ type DebugConfig struct {
237
+ Enabled bool `mapstructure:"enabled"`
238
+ PrintRequest bool `mapstructure:"print_request"`
239
+ PrintResponse bool `mapstructure:"print_response"`
240
+ MaxContentLength int `mapstructure:"max_content_length"`
241
+ }
242
+
243
+ type ProxyConfig struct {
244
+ Enabled bool `mapstructure:"enabled"`
245
+ Default string `mapstructure:"default"`
246
+ AllowInsecure bool `mapstructure:"allow_insecure"`
247
+ }
248
+
249
+ type GlobalConfig struct {
250
+ MaxRetries int `mapstructure:"max_retries"`
251
+ DefaultTimeout int `mapstructure:"default_timeout"`
252
+ ErrorCodes struct {
253
+ RetryOn []int `mapstructure:"retry_on"`
254
+ } `mapstructure:"error_codes"`
255
+ Log LogConfig `mapstructure:"log"`
256
+ Server ServerConfig `mapstructure:"server"`
257
+ Proxy ProxyConfig `mapstructure:"proxy"`
258
+ ConfigPaths []string `mapstructure:"config_paths"`
259
+ Thinking ThinkingConfig `mapstructure:"thinking"`
260
+ }
261
+
262
+ type ServerConfig struct {
263
+ Port int `mapstructure:"port"`
264
+ Host string `mapstructure:"host"`
265
+ ReadTimeout int `mapstructure:"read_timeout"`
266
+ WriteTimeout int `mapstructure:"write_timeout"`
267
+ IdleTimeout int `mapstructure:"idle_timeout"`
268
+ }
269
+
270
+ type ThinkingConfig struct {
271
+ Enabled bool `mapstructure:"enabled"`
272
+ AddToAllRequests bool `mapstructure:"add_to_all_requests"`
273
+ Timeout int `mapstructure:"timeout"`
274
+ }
275
+
276
+ // ---------------------- API request/response structures ----------------------
277
+
278
+ type ChatCompletionRequest struct {
279
+ Model string `json:"model"`
280
+ Messages []ChatCompletionMessage `json:"messages"`
281
+ Temperature float64 `json:"temperature,omitempty"`
282
+ MaxTokens int `json:"max_tokens,omitempty"`
283
+ Stream bool `json:"stream,omitempty"`
284
+ APIKey string `json:"-"` // passed internally, never serialized
285
+ }
286
+
287
+ type ChatCompletionMessage struct {
288
+ Role string `json:"role"`
289
+ Content string `json:"content"`
290
+ ReasoningContent interface{} `json:"reasoning_content,omitempty"`
291
+ }
292
+
293
+ type ChatCompletionResponse struct {
294
+ ID string `json:"id"`
295
+ Object string `json:"object"`
296
+ Created int64 `json:"created"`
297
+ Model string `json:"model"`
298
+ Choices []Choice `json:"choices"`
299
+ Usage Usage `json:"usage"`
300
+ }
301
+
302
+ type Choice struct {
303
+ Index int `json:"index"`
304
+ Message ChatCompletionMessage `json:"message"`
305
+ FinishReason string `json:"finish_reason"`
306
+ }
307
+
308
+ type Usage struct {
309
+ PromptTokens int `json:"prompt_tokens"`
310
+ CompletionTokens int `json:"completion_tokens"`
311
+ TotalTokens int `json:"total_tokens"`
312
+ }
313
+
314
+ // ---------------------- Logging helpers ----------------------
315
+
316
+ type RequestLogger struct {
317
+ RequestID string
318
+ Model string
319
+ StartTime time.Time
320
+ logs []string
321
+ config *Config
322
+ }
323
+
324
+ func NewRequestLogger(config *Config) *RequestLogger {
325
+ return &RequestLogger{
326
+ RequestID: uuid.New().String(),
327
+ StartTime: time.Now(),
328
+ logs: make([]string, 0),
329
+ config: config,
330
+ }
331
+ }
332
+
333
+ // replaceEnvVars expands ${NAME} placeholders in a string using the process environment.
334
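+ // For example, with THINKING_HOST=api.example.com (an illustrative variable name,
+ // not one this program defines), "https://${THINKING_HOST}/v1" expands to
+ // "https://api.example.com/v1".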
+ func replaceEnvVars(value string) string {
335
+ // Fast path: nothing to do if the value contains no ${...} placeholder.
336
+ if !strings.Contains(value, "${") {
337
+ return value
338
+ }
339
+
340
+ // Find each ${...} placeholder and substitute it.
341
+ result := value
342
+ for {
343
+ start := strings.Index(result, "${")
344
+ if start == -1 {
345
+ break
346
+ }
347
+
348
+ end := strings.Index(result[start:], "}")
349
+ if end == -1 {
350
+ break
351
+ }
352
+ end += start
353
+
354
+ // Extract the environment variable name.
355
+ envName := result[start+2:end]
356
+
357
+ // Look up the variable's value; an unset variable yields an empty string.
358
+ envValue := os.Getenv(envName)
359
+ if envValue == "" {
360
+ // Optionally log a warning or fall back to a default here;
361
+ // as written, the placeholder is simply replaced with an empty string.
362
+ }
363
+
364
+ // Substitute the placeholder.
365
+ result = result[:start] + envValue + result[end+1:]
366
+ }
367
+
368
+ return result
369
+ }
370
+
371
+ // ProcessConfig expands environment-variable placeholders throughout the loaded config.
372
+ func ProcessConfig(config *Config) {
373
+ // Expand placeholders in ThinkingServices.
374
+ for i := range config.ThinkingServices {
375
+ config.ThinkingServices[i].Name = replaceEnvVars(config.ThinkingServices[i].Name)
376
+ config.ThinkingServices[i].Model = replaceEnvVars(config.ThinkingServices[i].Model)
377
+ config.ThinkingServices[i].BaseURL = replaceEnvVars(config.ThinkingServices[i].BaseURL)
378
+ config.ThinkingServices[i].APIKey = replaceEnvVars(config.ThinkingServices[i].APIKey)
379
+ config.ThinkingServices[i].Proxy = replaceEnvVars(config.ThinkingServices[i].Proxy)
380
+ }
381
+
382
+ // Expand placeholders in Channels.
383
+ for key, channel := range config.Channels {
384
+ channel.Name = replaceEnvVars(channel.Name)
385
+ channel.BaseURL = replaceEnvVars(channel.BaseURL)
386
+ channel.Proxy = replaceEnvVars(channel.Proxy)
387
+ config.Channels[key] = channel
388
+ }
389
+
390
+ // Add handling here if Global config values may also contain placeholders.
391
+ }
392
+
393
+ func (l *RequestLogger) Log(format string, args ...interface{}) {
394
+ msg := fmt.Sprintf(format, args...)
395
+ l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
396
+ log.Printf("[RequestID: %s] %s", l.RequestID, msg)
397
+ }
398
+
399
+ func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
400
+ if !l.config.Global.Log.Debug.Enabled {
401
+ return
402
+ }
403
+ sanitizedContent := sanitizeJSON(content)
404
+ truncatedContent := truncateContent(sanitizedContent, maxLength)
405
+ l.Log("%s Content:\n%s", contentType, truncatedContent)
406
+ }
407
+
408
+ func truncateContent(content string, maxLength int) string {
409
+ if len(content) <= maxLength {
410
+ return content
411
+ }
412
+ return content[:maxLength] + "... (truncated)"
413
+ }
414
+
415
+ func sanitizeJSON(data interface{}) string {
416
+ sanitized, err := json.Marshal(data)
417
+ if err != nil {
418
+ return "Failed to marshal JSON"
419
+ }
420
+ content := string(sanitized)
421
+ sensitivePattern := `"api_key":\s*"[^"]*"`
422
+ content = regexp.MustCompile(sensitivePattern).ReplaceAllString(content, `"api_key":"****"`)
423
+ return content
424
+ }
425
+
426
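+ // Clients authenticate with a composite key of the form "deep-<channelID>-<realKey>"
+ // or "openai-<channelID>-<realKey>"; the channel ID selects the target channel and the
+ // remainder is forwarded upstream. For instance, a hypothetical key "deep-1-sk-xxxx"
+ // would route to channel "1" with upstream key "sk-xxxx".
+ // extractRealAPIKey returns the upstream portion, or the key unchanged if it does not
+ // match this format.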
+ func extractRealAPIKey(fullKey string) string {
427
+ parts := strings.Split(fullKey, "-")
428
+ if len(parts) >= 3 && (parts[0] == "deep" || parts[0] == "openai") {
429
+ return strings.Join(parts[2:], "-")
430
+ }
431
+ return fullKey
432
+ }
433
+
434
+ func extractChannelID(fullKey string) string {
435
+ parts := strings.Split(fullKey, "-")
436
+ if len(parts) >= 2 && (parts[0] == "deep" || parts[0] == "openai") {
437
+ return parts[1]
438
+ }
439
+ return "1" // 默认渠道
440
+ }
441
+
442
+ func logAPIKey(key string) string {
443
+ if len(key) <= 8 {
444
+ return "****"
445
+ }
446
+ return key[:4] + "..." + key[len(key)-4:]
447
+ }
448
+
449
+ // ---------------------- Server ----------------------
450
+
451
+ type Server struct {
452
+ config *Config
453
+ srv *http.Server
454
+ }
455
+
456
+ var (
457
+ randMu sync.Mutex
458
+ randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
459
+ )
460
+
461
+ func NewServer(config *Config) *Server {
462
+ return &Server{
463
+ config: config,
464
+ }
465
+ }
466
+
467
+ func (s *Server) Start() error {
468
+ mux := http.NewServeMux()
469
+ mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
470
+ mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
471
+ mux.HandleFunc("/health", s.handleHealth)
472
+
473
+ s.srv = &http.Server{
474
+ Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
475
+ Handler: mux,
476
+ ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
477
+ WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
478
+ IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
479
+ }
480
+
481
+ log.Printf("Server starting on %s\n", s.srv.Addr)
482
+ return s.srv.ListenAndServe()
483
+ }
484
+
485
+ func (s *Server) Shutdown(ctx context.Context) error {
486
+ return s.srv.Shutdown(ctx)
487
+ }
488
+
489
+ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
490
+ w.WriteHeader(http.StatusOK)
491
+ json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
492
+ }
493
+
494
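+ // handleOpenAIRequests is the main entry point. It parses the composite API key to
+ // pick the target channel, then either hands streaming requests to a StreamHandler
+ // (thinking SSE first, then the target channel SSE) or, for non-streaming requests,
+ // fetches the thinking content and forwards an enhanced request to the target channel.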
+ func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
495
+ logger := NewRequestLogger(s.config)
496
+
497
+ fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
498
+ apiKey := extractRealAPIKey(fullAPIKey)
499
+ channelID := extractChannelID(fullAPIKey)
500
+
501
+ logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
502
+ logger.Log("Extracted channel ID: %s", channelID)
503
+ logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))
504
+
505
+ targetChannel, ok := s.config.Channels[channelID]
506
+ if !ok {
507
+ http.Error(w, "Invalid channel", http.StatusBadRequest)
508
+ return
509
+ }
510
+
511
+ if r.URL.Path == "/v1/models" {
512
+ if r.Method != http.MethodGet {
513
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
514
+ return
515
+ }
516
+ req := &ChatCompletionRequest{APIKey: apiKey}
517
+ s.forwardModelsRequest(w, r.Context(), req, targetChannel)
518
+ return
519
+ }
520
+
521
+ if r.Method != http.MethodPost {
522
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
523
+ return
524
+ }
525
+
526
+ body, err := io.ReadAll(r.Body)
527
+ if err != nil {
528
+ logger.Log("Error reading request body: %v", err)
529
+ http.Error(w, "Failed to read request", http.StatusBadRequest)
530
+ return
531
+ }
532
+ r.Body.Close()
533
+ r.Body = io.NopCloser(bytes.NewBuffer(body))
534
+
535
+ if s.config.Global.Log.Debug.PrintRequest {
536
+ logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
537
+ }
538
+
539
+ var req ChatCompletionRequest
540
+ if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
541
+ http.Error(w, "Invalid request body", http.StatusBadRequest)
542
+ return
543
+ }
544
+ req.APIKey = apiKey
545
+
546
+ thinkingService := s.getWeightedRandomThinkingService()
547
+ logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))
548
+
549
+ if req.Stream {
550
+ handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
551
+ if err != nil {
552
+ http.Error(w, "Streaming not supported", http.StatusInternalServerError)
553
+ return
554
+ }
555
+ if err := handler.HandleRequest(r.Context(), &req); err != nil {
556
+ logger.Log("Stream handler error: %v", err)
557
+ }
558
+ } else {
559
+ thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
560
+ if err != nil {
561
+ logger.Log("Error processing thinking content: %v", err)
562
+ http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
563
+ return
564
+ }
565
+ enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
566
+ s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
567
+ }
568
+ }
569
+
570
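+ // getWeightedRandomThinkingService picks a thinking service at random with probability
+ // proportional to its Weight, falling back to the first service when the total weight
+ // is not positive.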
+ func (s *Server) getWeightedRandomThinkingService() ThinkingService {
571
+ thinkingServices := s.config.ThinkingServices
572
+ if len(thinkingServices) == 0 {
573
+ return ThinkingService{}
574
+ }
575
+ totalWeight := 0
576
+ for _, svc := range thinkingServices {
577
+ totalWeight += svc.Weight
578
+ }
579
+ if totalWeight <= 0 {
580
+ log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
581
+ return thinkingServices[0]
582
+ }
583
+ randMu.Lock()
584
+ randNum := randGen.Intn(totalWeight)
585
+ randMu.Unlock()
586
+ currentSum := 0
587
+ for _, svc := range thinkingServices {
588
+ currentSum += svc.Weight
589
+ if randNum < currentSum {
590
+ return svc
591
+ }
592
+ }
593
+ return thinkingServices[0]
594
+ }
595
+
596
+ // ---------------------- Non-streaming thinking-service handling ----------------------
597
+
598
+ type ThinkingResponse struct {
599
+ Content string
600
+ ReasoningContent string
601
+ }
602
+
603
+ func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
604
+ logger := NewRequestLogger(s.config)
605
+ log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)
606
+
607
+ thinkingReq := *req
608
+ thinkingReq.Model = svc.Model
609
+ thinkingReq.APIKey = svc.APIKey
610
+
611
+ var systemPrompt string
612
+ if svc.Mode == "full" {
613
+ systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
614
+ } else {
615
+ systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
616
+ }
617
+ thinkingReq.Messages = append([]ChatCompletionMessage{
618
+ {Role: "system", Content: systemPrompt},
619
+ }, thinkingReq.Messages...)
620
+
621
+ temp := 0.7
622
+ if svc.Temperature != nil {
623
+ temp = *svc.Temperature
624
+ }
625
+ payload := map[string]interface{}{
626
+ "model": svc.Model,
627
+ "messages": thinkingReq.Messages,
628
+ "stream": false,
629
+ "temperature": temp,
630
+ }
631
+ if isValidReasoningEffort(svc.ReasoningEffort) {
632
+ payload["reasoning_effort"] = svc.ReasoningEffort
633
+ }
634
+ if isValidReasoningFormat(svc.ReasoningFormat) {
635
+ payload["reasoning_format"] = svc.ReasoningFormat
636
+ }
637
+
638
+ if s.config.Global.Log.Debug.PrintRequest {
639
+ logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
640
+ }
641
+
642
+ jsonData, err := json.Marshal(payload)
643
+ if err != nil {
644
+ return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
645
+ }
646
+ client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
647
+ if err != nil {
648
+ return nil, fmt.Errorf("failed to create HTTP client: %v", err)
649
+ }
650
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
651
+ if err != nil {
652
+ return nil, fmt.Errorf("failed to create request: %v", err)
653
+ }
654
+ httpReq.Header.Set("Content-Type", "application/json")
655
+ httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)
656
+
657
+ resp, err := client.Do(httpReq)
658
+ if err != nil {
659
+ return nil, fmt.Errorf("failed to send thinking request: %v", err)
660
+ }
661
+ defer resp.Body.Close()
662
+
663
+ respBody, err := io.ReadAll(resp.Body)
664
+ if err != nil {
665
+ return nil, fmt.Errorf("failed to read response body: %v", err)
666
+ }
667
+ if s.config.Global.Log.Debug.PrintResponse {
668
+ logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
669
+ }
670
+ if resp.StatusCode != http.StatusOK {
671
+ return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
672
+ }
673
+
674
+ var thinkingResp ChatCompletionResponse
675
+ if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
676
+ return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
677
+ }
678
+ if len(thinkingResp.Choices) == 0 {
679
+ return nil, fmt.Errorf("thinking service returned no choices")
680
+ }
681
+
682
+ result := &ThinkingResponse{}
683
+ choice := thinkingResp.Choices[0]
684
+
685
+ if svc.Mode == "full" {
686
+ result.ReasoningContent = choice.Message.Content
687
+ result.Content = "Based on the above detailed analysis."
688
+ } else {
689
+ if choice.Message.ReasoningContent != nil {
690
+ switch v := choice.Message.ReasoningContent.(type) {
691
+ case string:
692
+ result.ReasoningContent = v
693
+ case map[string]interface{}:
694
+ if j, err := json.Marshal(v); err == nil {
695
+ result.ReasoningContent = string(j)
696
+ }
697
+ }
698
+ }
699
+ if result.ReasoningContent == "" {
700
+ result.ReasoningContent = choice.Message.Content
701
+ }
702
+ result.Content = "Based on the above reasoning."
703
+ }
704
+ return result, nil
705
+ }
706
+
707
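+ // prepareEnhancedRequest prepends a system message carrying the reasoning obtained from
+ // the thinking service, so the target channel can take it into account when answering.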
+ func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
708
+ newReq := *originalReq
709
+ var systemPrompt string
710
+ if svc.Mode == "full" {
711
+ systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
712
+ %s
713
+
714
+ Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
715
+ } else {
716
+ systemPrompt = fmt.Sprintf(`Previous thinking process:
717
+ %s
718
+ Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
719
+ }
720
+ newReq.Messages = append([]ChatCompletionMessage{
721
+ {Role: "system", Content: systemPrompt},
722
+ }, newReq.Messages...)
723
+ return &newReq
724
+ }
725
+
726
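+ // forwardRequest sends the non-streaming request to the target channel and copies the
+ // upstream headers, status code, and body back to the client.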
+ func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
727
+ logger := NewRequestLogger(s.config)
728
+ if s.config.Global.Log.Debug.PrintRequest {
729
+ logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
730
+ }
731
+ jsonData, err := json.Marshal(req)
732
+ if err != nil {
733
+ http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
734
+ return
735
+ }
736
+
737
+ client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
738
+ if err != nil {
739
+ http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
740
+ return
741
+ }
742
+
743
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
744
+ if err != nil {
745
+ http.Error(w, "Failed to create request", http.StatusInternalServerError)
746
+ return
747
+ }
748
+ httpReq.Header.Set("Content-Type", "application/json")
749
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
750
+
751
+ resp, err := client.Do(httpReq)
752
+ if err != nil {
753
+ http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
754
+ return
755
+ }
756
+ defer resp.Body.Close()
757
+
758
+ respBody, err := io.ReadAll(resp.Body)
759
+ if err != nil {
760
+ http.Error(w, "Failed to read response", http.StatusInternalServerError)
761
+ return
762
+ }
763
+ if s.config.Global.Log.Debug.PrintResponse {
764
+ logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
765
+ }
766
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
767
+ http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
768
+ return
769
+ }
770
+
771
+ for k, vals := range resp.Header {
772
+ for _, v := range vals {
773
+ w.Header().Add(k, v)
774
+ }
775
+ }
776
+ w.WriteHeader(resp.StatusCode)
777
+ w.Write(respBody)
778
+ }
779
+
780
+ func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
781
+ logger := NewRequestLogger(s.config)
782
+ if s.config.Global.Log.Debug.PrintRequest {
783
+ logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
784
+ }
785
+ fullChatURL := targetChannel.GetFullURL()
786
+ parsedURL, err := url.Parse(fullChatURL)
787
+ if err != nil {
788
+ http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
789
+ return
790
+ }
791
+ baseURL := parsedURL.Scheme + "://" + parsedURL.Host
792
+ modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"
793
+
794
+ client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
795
+ if err != nil {
796
+ http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
797
+ return
798
+ }
799
+
800
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
801
+ if err != nil {
802
+ http.Error(w, "Failed to create request", http.StatusInternalServerError)
803
+ return
804
+ }
805
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
806
+
807
+ resp, err := client.Do(httpReq)
808
+ if err != nil {
809
+ http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
810
+ return
811
+ }
812
+ defer resp.Body.Close()
813
+
814
+ respBody, err := io.ReadAll(resp.Body)
815
+ if err != nil {
816
+ http.Error(w, "Failed to read response", http.StatusInternalServerError)
817
+ return
818
+ }
819
+ if s.config.Global.Log.Debug.PrintResponse {
820
+ logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
821
+ }
822
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
823
+ http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
824
+ return
825
+ }
826
+
827
+ for k, vals := range resp.Header {
828
+ for _, v := range vals {
829
+ w.Header().Add(k, v)
830
+ }
831
+ }
832
+ w.WriteHeader(resp.StatusCode)
833
+ w.Write(respBody)
834
+ }
835
+
836
+ // ---------------------- Streaming handling ----------------------
837
+
838
+ // collectedReasoningBuffer accumulates the reasoning_content returned by the thinking service (standard mode collects only reasoning_content; full mode collects everything).
839
+ type collectedReasoningBuffer struct {
840
+ builder strings.Builder
841
+ mode string
842
+ }
843
+
844
+ func (cb *collectedReasoningBuffer) append(text string) {
845
+ cb.builder.WriteString(text)
846
+ }
847
+
848
+ func (cb *collectedReasoningBuffer) get() string {
849
+ return cb.builder.String()
850
+ }
851
+
852
+ type StreamHandler struct {
853
+ thinkingService ThinkingService
854
+ targetChannel Channel
855
+ writer http.ResponseWriter
856
+ flusher http.Flusher
857
+ config *Config
858
+ }
859
+
860
+ func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
861
+ flusher, ok := w.(http.Flusher)
862
+ if !ok {
863
+ return nil, fmt.Errorf("streaming not supported by this writer")
864
+ }
865
+ return &StreamHandler{
866
+ thinkingService: thinkingService,
867
+ targetChannel: targetChannel,
868
+ writer: w,
869
+ flusher: flusher,
870
+ config: config,
871
+ }, nil
872
+ }
873
+
874
+ // HandleRequest runs in two phases: first it streams the thinking service's SSE and forwards it to the client (only chunks containing reasoning_content; in standard mode it stops once plain content appears);
875
+ // then it builds the final request from the collected reasoning_content and streams the target channel's SSE back to the client.
876
+ func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
877
+ h.writer.Header().Set("Content-Type", "text/event-stream")
878
+ h.writer.Header().Set("Cache-Control", "no-cache")
879
+ h.writer.Header().Set("Connection", "keep-alive")
880
+
881
+ logger := NewRequestLogger(h.config)
882
+ reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}
883
+
884
+ // Phase one: stream the thinking service SSE.
885
+ if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
886
+ return err
887
+ }
888
+
889
+ // Phase two: build the final request and stream the target channel SSE.
890
+ finalReq := h.prepareFinalRequest(req, reasonBuf.get())
891
+ if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
892
+ return err
893
+ }
894
+ return nil
895
+ }
896
+
897
+ // streamThinkingService connects to the thinking service SSE, forwards chunks containing reasoning_content to the client verbatim, and collects the reasoning_content as it goes.
898
+ // In standard mode, as soon as a chunk carries non-empty content with an empty reasoning_content, the stream is stopped and that chunk is not forwarded.
899
+ func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
900
+ thinkingReq := *req
901
+ thinkingReq.Model = h.thinkingService.Model
902
+ thinkingReq.APIKey = h.thinkingService.APIKey
903
+
904
+ var systemPrompt string
905
+ if h.thinkingService.Mode == "full" {
906
+ systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
907
+ } else {
908
+ systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
909
+ }
910
+ messages := append([]ChatCompletionMessage{
911
+ {Role: "system", Content: systemPrompt},
912
+ }, thinkingReq.Messages...)
913
+
914
+ temp := 0.7
915
+ if h.thinkingService.Temperature != nil {
916
+ temp = *h.thinkingService.Temperature
917
+ }
918
+ bodyMap := map[string]interface{}{
919
+ "model": thinkingReq.Model,
920
+ "messages": messages,
921
+ "temperature": temp,
922
+ "stream": true,
923
+ }
924
+ if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
925
+ bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
926
+ }
927
+ if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
928
+ bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
929
+ }
930
+
931
+ jsonData, err := json.Marshal(bodyMap)
932
+ if err != nil {
933
+ return err
934
+ }
935
+
936
+ client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
937
+ if err != nil {
938
+ return err
939
+ }
940
+
941
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
942
+ if err != nil {
943
+ return err
944
+ }
945
+ httpReq.Header.Set("Content-Type", "application/json")
946
+ httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)
947
+
948
+ log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
949
+ resp, err := client.Do(httpReq)
950
+ if err != nil {
951
+ return err
952
+ }
953
+ defer resp.Body.Close()
954
+
955
+ if resp.StatusCode != http.StatusOK {
956
+ b, _ := io.ReadAll(resp.Body)
957
+ return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
958
+ }
959
+
960
+ reader := bufio.NewReader(resp.Body)
961
+ var lastLine string
962
+ // Whether the thinking SSE should be stopped early (standard mode).
963
+ forceStop := false
964
+
965
+ for {
966
+ select {
967
+ case <-ctx.Done():
968
+ return ctx.Err()
969
+ default:
970
+ }
971
+ line, err := reader.ReadString('\n')
972
+ if err != nil {
973
+ if err == io.EOF {
974
+ break
975
+ }
976
+ return err
977
+ }
978
+ line = strings.TrimSpace(line)
979
+ if line == "" || line == lastLine {
980
+ continue
981
+ }
982
+ lastLine = line
983
+
984
+ if !strings.HasPrefix(line, "data: ") {
985
+ continue
986
+ }
987
+ dataPart := strings.TrimPrefix(line, "data: ")
988
+ if dataPart == "[DONE]" {
989
+ // Thinking service finished: stop here (no special marker is forwarded).
990
+ break
991
+ }
992
+
993
+ var chunk struct {
994
+ Choices []struct {
995
+ Delta struct {
996
+ Content string `json:"content,omitempty"`
997
+ ReasoningContent string `json:"reasoning_content,omitempty"`
998
+ } `json:"delta"`
999
+ FinishReason *string `json:"finish_reason,omitempty"`
1000
+ } `json:"choices"`
1001
+ }
1002
+ if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
1003
+ // If the chunk cannot be parsed, forward it verbatim.
1004
+ h.writer.Write([]byte("data: " + dataPart + "\n\n"))
1005
+ h.flusher.Flush()
1006
+ continue
1007
+ }
1008
+
1009
+ if len(chunk.Choices) > 0 {
1010
+ c := chunk.Choices[0]
1011
+ if h.config.Global.Log.Debug.PrintResponse {
1012
+ logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
1013
+ }
1014
+
1015
+ if h.thinkingService.Mode == "full" {
1016
+ // Full mode: collect everything.
1017
+ if c.Delta.ReasoningContent != "" {
1018
+ reasonBuf.append(c.Delta.ReasoningContent)
1019
+ }
1020
+ if c.Delta.Content != "" {
1021
+ reasonBuf.append(c.Delta.Content)
1022
+ }
1023
+ // Forward the whole chunk verbatim.
1024
+ forwardLine := "data: " + dataPart + "\n\n"
1025
+ h.writer.Write([]byte(forwardLine))
1026
+ h.flusher.Flush()
1027
+ } else {
1028
+ // Standard mode: collect only reasoning_content.
1029
+ if c.Delta.ReasoningContent != "" {
1030
+ reasonBuf.append(c.Delta.ReasoningContent)
1031
+ // Forward this chunk.
1032
+ forwardLine := "data: " + dataPart + "\n\n"
1033
+ h.writer.Write([]byte(forwardLine))
1034
+ h.flusher.Flush()
1035
+ }
1036
+ // A chunk with only content (and empty reasoning_content) marks the end of the thinking chain; do not forward it.
1037
+ if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
1038
+ forceStop = true
1039
+ }
1040
+ }
1041
+
1042
+ // A non-empty finish_reason also ends the thinking stream.
1043
+ if c.FinishReason != nil && *c.FinishReason != "" {
1044
+ forceStop = true
1045
+ }
1046
+ }
1047
+
1048
+ if forceStop {
1049
+ break
1050
+ }
1051
+ }
1052
+ // Drain any remaining data.
1053
+ io.Copy(io.Discard, reader)
1054
+ return nil
1055
+ }
1056
+
1057
+ func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
1058
+ req := *originalReq
1059
+ var systemPrompt string
1060
+ if h.thinkingService.Mode == "full" {
1061
+ systemPrompt = fmt.Sprintf(
1062
+ `Consider the following detailed analysis (not shown to user):
1063
+ %s
1064
+
1065
+ Provide a clear, concise response that incorporates insights from this analysis.`,
1066
+ reasoningCollected,
1067
+ )
1068
+ } else {
1069
+ systemPrompt = fmt.Sprintf(
1070
+ `Previous thinking process:
1071
+ %s
1072
+ Please consider the above thinking process in your response.`,
1073
+ reasoningCollected,
1074
+ )
1075
+ }
1076
+ req.Messages = append([]ChatCompletionMessage{
1077
+ {Role: "system", Content: systemPrompt},
1078
+ }, req.Messages...)
1079
+ return &req
1080
+ }
1081
+
1082
+ func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
1083
+ if h.config.Global.Log.Debug.PrintRequest {
1084
+ logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
1085
+ }
1086
+
1087
+ jsonData, err := json.Marshal(req)
1088
+ if err != nil {
1089
+ return err
1090
+ }
1091
+ client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
1092
+ if err != nil {
1093
+ return err
1094
+ }
1095
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
1096
+ if err != nil {
1097
+ return err
1098
+ }
1099
+ httpReq.Header.Set("Content-Type", "application/json")
1100
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
1101
+
1102
+ resp, err := client.Do(httpReq)
1103
+ if err != nil {
1104
+ return err
1105
+ }
1106
+ defer resp.Body.Close()
1107
+
1108
+ if resp.StatusCode != http.StatusOK {
1109
+ b, _ := io.ReadAll(resp.Body)
1110
+ return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
1111
+ }
1112
+
1113
+ reader := bufio.NewReader(resp.Body)
1114
+ var lastLine string
1115
+
1116
+ for {
1117
+ select {
1118
+ case <-ctx.Done():
1119
+ return ctx.Err()
1120
+ default:
1121
+ }
1122
+ line, err := reader.ReadString('\n')
1123
+ if err != nil {
1124
+ if err == io.EOF {
1125
+ break
1126
+ }
1127
+ return err
1128
+ }
1129
+ line = strings.TrimSpace(line)
1130
+ if line == "" || line == lastLine {
1131
+ continue
1132
+ }
1133
+ lastLine = line
1134
+
1135
+ if !strings.HasPrefix(line, "data: ") {
1136
+ continue
1137
+ }
1138
+ data := strings.TrimPrefix(line, "data: ")
1139
+ if data == "[DONE]" {
1140
+ doneLine := "data: [DONE]\n\n"
1141
+ h.writer.Write([]byte(doneLine))
1142
+ h.flusher.Flush()
1143
+ break
1144
+ }
1145
+ forwardLine := "data: " + data + "\n\n"
1146
+ h.writer.Write([]byte(forwardLine))
1147
+ h.flusher.Flush()
1148
+
1149
+ if h.config.Global.Log.Debug.PrintResponse {
1150
+ logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
1151
+ }
1152
+ }
1153
+ io.Copy(io.Discard, reader)
1154
+ return nil
1155
+ }
1156
+
1157
+ // ---------------------- HTTP client helpers ----------------------
1158
+
1159
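+ // createHTTPClient builds an *http.Client with common transport defaults; when proxyURL
+ // is non-empty it routes traffic through an http(s) proxy or a socks5 dialer.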
+ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
1160
+ transport := &http.Transport{
1161
+ DialContext: (&net.Dialer{
1162
+ Timeout: 30 * time.Second,
1163
+ KeepAlive: 30 * time.Second,
1164
+ }).DialContext,
1165
+ ForceAttemptHTTP2: true,
1166
+ MaxIdleConns: 100,
1167
+ IdleConnTimeout: 90 * time.Second,
1168
+ TLSHandshakeTimeout: 10 * time.Second,
1169
+ ExpectContinueTimeout: 1 * time.Second,
1170
+ }
1171
+ if proxyURL != "" {
1172
+ parsedURL, err := url.Parse(proxyURL)
1173
+ if err != nil {
1174
+ return nil, fmt.Errorf("invalid proxy URL: %v", err)
1175
+ }
1176
+ switch parsedURL.Scheme {
1177
+ case "http", "https":
1178
+ transport.Proxy = http.ProxyURL(parsedURL)
1179
+ case "socks5":
1180
+ dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
1181
+ if err != nil {
1182
+ return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
1183
+ }
1184
+ transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1185
+ return dialer.Dial(network, addr)
1186
+ }
1187
+ default:
1188
+ return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
1189
+ }
1190
+ }
1191
+ return &http.Client{
1192
+ Transport: transport,
1193
+ Timeout: timeout,
1194
+ }, nil
1195
+ }
1196
+
1197
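+ // maskSensitiveHeaders returns a copy of the headers with Authorization values redacted,
+ // for safe logging.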
+ func maskSensitiveHeaders(headers http.Header) http.Header {
1198
+ masked := make(http.Header)
1199
+ for k, vals := range headers {
1200
+ if strings.ToLower(k) == "authorization" {
1201
+ masked[k] = []string{"Bearer ****"}
1202
+ } else {
1203
+ masked[k] = vals
1204
+ }
1205
+ }
1206
+ return masked
1207
+ }
1208
+
1209
+ func isValidReasoningEffort(effort string) bool {
1210
+ switch strings.ToLower(effort) {
1211
+ case "low", "medium", "high":
1212
+ return true
1213
+ }
1214
+ return false
1215
+ }
1216
+
1217
+ func isValidReasoningFormat(format string) bool {
1218
+ switch strings.ToLower(format) {
1219
+ case "parsed", "raw", "hidden":
1220
+ return true
1221
+ }
1222
+ return false
1223
+ }
1224
+
1225
+ // ---------------------- Config validation ----------------------
1226
+
1276
+
1277
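+ // validateConfig ensures at least one thinking service and one channel are configured,
+ // and that every thinking service has a base URL, API key, model, a positive timeout,
+ // and a known mode ("standard" or "full"; empty defaults to "standard").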
+ func validateConfig(config *Config) error {
1278
+ if len(config.ThinkingServices) == 0 {
1279
+ return fmt.Errorf("no thinking services configured")
1280
+ }
1281
+ if len(config.Channels) == 0 {
1282
+ return fmt.Errorf("no channels configured")
1283
+ }
1284
+ for i, svc := range config.ThinkingServices {
1285
+ if svc.BaseURL == "" {
1286
+ return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
1287
+ }
1288
+ if svc.APIKey == "" {
1289
+ return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
1290
+ }
1291
+ if svc.Timeout <= 0 {
1292
+ return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
1293
+ }
1294
+ if svc.Model == "" {
1295
+ return fmt.Errorf("thinking service %s has empty model", svc.Name)
1296
+ }
1297
+ if svc.Mode == "" {
1298
+ config.ThinkingServices[i].Mode = "standard"
1299
+ } else if svc.Mode != "standard" && svc.Mode != "full" {
1300
+ return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
1301
+ }
1302
+ }
1303
+ return nil
1304
+ }
1305
+
1306
+ func main() {
1307
+ log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
1308
+
1309
+ cfg, err := loadConfig()
1310
+ if err != nil {
1311
+ log.Fatalf("Failed to load config: %v", err)
1312
+ }
1313
+ log.Printf("Using config file: %s", viper.ConfigFileUsed())
1314
+
1315
+ server := NewServer(cfg)
1316
+
1317
+ done := make(chan os.Signal, 1)
1318
+ signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
1319
+
1320
+ go func() {
1321
+ if err := server.Start(); err != nil && err != http.ErrServerClosed {
1322
+ log.Fatalf("start server error: %v", err)
1323
+ }
1324
+ }()
1325
+ log.Printf("Server started successfully")
1326
+
1327
+ <-done
1328
+ log.Print("Server stopping...")
1329
+
1330
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1331
+ defer cancel()
1332
+
1333
+ if err := server.Shutdown(ctx); err != nil {
1334
+ log.Printf("Server forced to shutdown: %v", err)
1335
+ }
1336
+ log.Print("Server stopped")
1337
+ }