nbugs committed on
Commit 783912e · verified · 1 Parent(s): 08d5b68

Delete main.go

Files changed (1)
  1. main.go +0 -1337
main.go DELETED
@@ -1,1337 +0,0 @@
1
- package main
2
-
3
- import (
4
- "bufio"
5
- "bytes"
6
- "context"
7
- "encoding/json"
8
- "flag"
9
- "fmt"
10
- "io"
11
- "log"
12
- "math/rand"
13
- "net"
14
- "net/http"
15
- "net/url"
16
- "os"
17
- "os/signal"
18
- "path/filepath"
19
- "regexp"
20
- "strings"
21
- "sync"
22
- "syscall"
23
- "time"
24
-
25
- "github.com/google/uuid"
26
- "github.com/spf13/viper"
27
- "golang.org/x/net/proxy"
28
- )
29
-
30
- // ---------------------- Configuration structures ----------------------
31
-
32
- type Config struct {
33
- ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
34
- Channels map[string]Channel `mapstructure:"channels"`
35
- Global GlobalConfig `mapstructure:"global"`
36
- }
37
-
38
- type ThinkingService struct {
39
- ID int `mapstructure:"id"`
40
- Name string `mapstructure:"name"`
41
- Model string `mapstructure:"model"`
42
- BaseURL string `mapstructure:"base_url"`
43
- APIPath string `mapstructure:"api_path"`
44
- APIKey string `mapstructure:"api_key"`
45
- Timeout int `mapstructure:"timeout"`
46
- Retry int `mapstructure:"retry"`
47
- Weight int `mapstructure:"weight"`
48
- Proxy string `mapstructure:"proxy"`
49
- Mode string `mapstructure:"mode"` // "standard" or "full"
50
- ReasoningEffort string `mapstructure:"reasoning_effort"`
51
- ReasoningFormat string `mapstructure:"reasoning_format"`
52
- Temperature *float64 `mapstructure:"temperature"`
53
- ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"` // config option: in standard mode, whether to stop immediately once plain content arrives
54
- }
55
-
56
- // ... existing code ...
57
-
58
- // New function: load configuration from environment variables
59
- func loadConfigFromEnv() *Config {
60
- cfg := &Config{
61
- ThinkingServices: []ThinkingService{
62
- {
63
- ID: 1,
64
- Name: os.Getenv("THINKING_SERVICE_NAME"),
65
- Mode: getEnvOrDefault("THINKING_SERVICE_MODE", "standard"),
66
- Model: os.Getenv("THINKING_SERVICE_MODEL"),
67
- BaseURL: os.Getenv("THINKING_SERVICE_BASE_URL"),
68
- APIPath: "/v1/chat/completions",
69
- APIKey: os.Getenv("THINKING_SERVICE_API_KEY"),
70
- Timeout: 600,
71
- Weight: 100,
72
- Proxy: "",
73
- ReasoningEffort: "high",
74
- ReasoningFormat: "parsed",
75
- Temperature: getFloatPtr(0.8),
76
- ForceStopDeepThinking: false,
77
- },
78
- },
79
- Channels: map[string]Channel{
80
- "1": {
81
- Name: os.Getenv("CHANNEL_NAME"),
82
- BaseURL: os.Getenv("CHANNEL_BASE_URL"),
83
- APIPath: "/v1/chat/completions",
84
- Timeout: 600,
85
- Proxy: "",
86
- },
87
- },
88
- Global: GlobalConfig{
89
- Log: LogConfig{
90
- Level: "error",
91
- Format: "json",
92
- Output: "console",
93
- Debug: DebugConfig{
94
- Enabled: true,
95
- PrintRequest: true,
96
- PrintResponse: true,
97
- MaxContentLength: 1000,
98
- },
99
- },
100
- Server: ServerConfig{
101
- Port: 8888,
102
- Host: "0.0.0.0",
103
- ReadTimeout: 600,
104
- WriteTimeout: 600,
105
- IdleTimeout: 600,
106
- },
107
- },
108
- }
109
- return cfg
110
- }
111
-
112
- // New helper: read an environment variable, returning a default when it is unset
113
- func getEnvOrDefault(key, defaultValue string) string {
114
- if value := os.Getenv(key); value != "" {
115
- return value
116
- }
117
- return defaultValue
118
- }
119
-
120
- // New helper: create a float64 pointer
121
- func getFloatPtr(value float64) *float64 {
122
- return &value
123
- }
124
-
125
- // Updated loadConfig function
126
- func loadConfig() (*Config, error) {
127
- // Prefer environment variables
128
- if os.Getenv("USE_ENV_CONFIG") == "true" {
129
- cfg := loadConfigFromEnv()
130
- if err := validateConfig(cfg); err != nil {
131
- return nil, fmt.Errorf("environment variable config validation error: %v", err)
132
- }
133
- return cfg, nil
134
- }
135
-
136
- // Otherwise fall back to the original config-file loading logic
137
- var configFile string
138
- flag.StringVar(&configFile, "config", "", "path to config file")
139
- flag.Parse()
140
- viper.SetConfigType("yaml")
141
-
142
- // ... existing code ...
143
- }
144
- // ... existing code ...
145
-
146
- // The first loadConfig function, renamed to loadConfigFromFile
147
- func loadConfigFromFile() (*Config, error) {
148
- var configFile string
149
- flag.StringVar(&configFile, "config", "", "path to config file")
150
- flag.Parse()
151
- viper.SetConfigType("yaml")
152
-
153
- if configFile != "" {
154
- viper.SetConfigFile(configFile)
155
- } else {
156
- ex, err := os.Executable()
157
- if err != nil {
158
- return nil, err
159
- }
160
- exePath := filepath.Dir(ex)
161
- defaultPaths := []string{
162
- filepath.Join(exePath, "config.yaml"),
163
- filepath.Join(exePath, "conf", "config.yaml"),
164
- "./config.yaml",
165
- "./conf/config.yaml",
166
- }
167
- if os.PathSeparator == '\\' {
168
- programData := os.Getenv("PROGRAMDATA")
169
- if programData != "" {
170
- defaultPaths = append(defaultPaths, filepath.Join(programData, "DeepAI", "config.yaml"))
171
- }
172
- } else {
173
- defaultPaths = append(defaultPaths, "/etc/deepai/config.yaml")
174
- }
175
- for _, p := range defaultPaths {
176
- viper.AddConfigPath(filepath.Dir(p))
177
- if strings.Contains(p, ".yaml") {
178
- viper.SetConfigName(strings.TrimSuffix(filepath.Base(p), ".yaml"))
179
- }
180
- }
181
- }
182
-
183
- if err := viper.ReadInConfig(); err != nil {
184
- return nil, fmt.Errorf("failed to read config: %v", err)
185
- }
186
-
187
- var cfg Config
188
- if err := viper.Unmarshal(&cfg); err != nil {
189
- return nil, fmt.Errorf("unmarshal config error: %v", err)
190
- }
191
- if err := validateConfig(&cfg); err != nil {
192
- return nil, fmt.Errorf("config validation error: %v", err)
193
- }
194
- return &cfg, nil
195
- }
196
-
197
- // Updated loadConfig function
198
- func loadConfig() (*Config, error) {
199
- // Prefer environment variables
200
- if os.Getenv("USE_ENV_CONFIG") == "true" {
201
- cfg := loadConfigFromEnv()
202
- if err := validateConfig(cfg); err != nil {
203
- return nil, fmt.Errorf("environment variable config validation error: %v", err)
204
- }
205
- return cfg, nil
206
- }
207
-
208
- // Otherwise use the config-file loading logic
209
- return loadConfigFromFile()
210
- }
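As a usage sketch (assuming the same package as the functions above; every value shown is an illustrative placeholder), the environment-variable path can be exercised like this:

    // Sketch only: these variable names come from loadConfigFromEnv above.
    os.Setenv("USE_ENV_CONFIG", "true")
    os.Setenv("THINKING_SERVICE_NAME", "thinking-1")     // hypothetical
    os.Setenv("THINKING_SERVICE_MODE", "standard")       // "standard" or "full"
    os.Setenv("THINKING_SERVICE_MODEL", "example-model") // hypothetical
    os.Setenv("THINKING_SERVICE_BASE_URL", "https://thinking.example.com")
    os.Setenv("THINKING_SERVICE_API_KEY", "sk-placeholder")
    os.Setenv("CHANNEL_NAME", "primary")                 // hypothetical
    os.Setenv("CHANNEL_BASE_URL", "https://upstream.example.com")
    if cfg, err := loadConfig(); err != nil {
        log.Fatalf("env config rejected: %v", err)
    } else {
        log.Printf("loaded %d thinking services", len(cfg.ThinkingServices))
    }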
211
-
212
- func (s *ThinkingService) GetFullURL() string {
213
- return s.BaseURL + s.APIPath
214
- }
215
-
216
- type Channel struct {
217
- Name string `mapstructure:"name"`
218
- BaseURL string `mapstructure:"base_url"`
219
- APIPath string `mapstructure:"api_path"`
220
- Timeout int `mapstructure:"timeout"`
221
- Proxy string `mapstructure:"proxy"`
222
- }
223
-
224
- func (c *Channel) GetFullURL() string {
225
- return c.BaseURL + c.APIPath
226
- }
227
-
228
- type LogConfig struct {
229
- Level string `mapstructure:"level"`
230
- Format string `mapstructure:"format"`
231
- Output string `mapstructure:"output"`
232
- FilePath string `mapstructure:"file_path"`
233
- Debug DebugConfig `mapstructure:"debug"`
234
- }
235
-
236
- type DebugConfig struct {
237
- Enabled bool `mapstructure:"enabled"`
238
- PrintRequest bool `mapstructure:"print_request"`
239
- PrintResponse bool `mapstructure:"print_response"`
240
- MaxContentLength int `mapstructure:"max_content_length"`
241
- }
242
-
243
- type ProxyConfig struct {
244
- Enabled bool `mapstructure:"enabled"`
245
- Default string `mapstructure:"default"`
246
- AllowInsecure bool `mapstructure:"allow_insecure"`
247
- }
248
-
249
- type GlobalConfig struct {
250
- MaxRetries int `mapstructure:"max_retries"`
251
- DefaultTimeout int `mapstructure:"default_timeout"`
252
- ErrorCodes struct {
253
- RetryOn []int `mapstructure:"retry_on"`
254
- } `mapstructure:"error_codes"`
255
- Log LogConfig `mapstructure:"log"`
256
- Server ServerConfig `mapstructure:"server"`
257
- Proxy ProxyConfig `mapstructure:"proxy"`
258
- ConfigPaths []string `mapstructure:"config_paths"`
259
- Thinking ThinkingConfig `mapstructure:"thinking"`
260
- }
261
-
262
- type ServerConfig struct {
263
- Port int `mapstructure:"port"`
264
- Host string `mapstructure:"host"`
265
- ReadTimeout int `mapstructure:"read_timeout"`
266
- WriteTimeout int `mapstructure:"write_timeout"`
267
- IdleTimeout int `mapstructure:"idle_timeout"`
268
- }
269
-
270
- type ThinkingConfig struct {
271
- Enabled bool `mapstructure:"enabled"`
272
- AddToAllRequests bool `mapstructure:"add_to_all_requests"`
273
- Timeout int `mapstructure:"timeout"`
274
- }
275
-
276
- // ---------------------- API-related structures ----------------------
277
-
278
- type ChatCompletionRequest struct {
279
- Model string `json:"model"`
280
- Messages []ChatCompletionMessage `json:"messages"`
281
- Temperature float64 `json:"temperature,omitempty"`
282
- MaxTokens int `json:"max_tokens,omitempty"`
283
- Stream bool `json:"stream,omitempty"`
284
- APIKey string `json:"-"` // passed internally, never serialized
285
- }
286
-
287
- type ChatCompletionMessage struct {
288
- Role string `json:"role"`
289
- Content string `json:"content"`
290
- ReasoningContent interface{} `json:"reasoning_content,omitempty"`
291
- }
292
-
293
- type ChatCompletionResponse struct {
294
- ID string `json:"id"`
295
- Object string `json:"object"`
296
- Created int64 `json:"created"`
297
- Model string `json:"model"`
298
- Choices []Choice `json:"choices"`
299
- Usage Usage `json:"usage"`
300
- }
301
-
302
- type Choice struct {
303
- Index int `json:"index"`
304
- Message ChatCompletionMessage `json:"message"`
305
- FinishReason string `json:"finish_reason"`
306
- }
307
-
308
- type Usage struct {
309
- PromptTokens int `json:"prompt_tokens"`
310
- CompletionTokens int `json:"completion_tokens"`
311
- TotalTokens int `json:"total_tokens"`
312
- }
313
-
314
- // ---------------------- Logging utilities ----------------------
315
-
316
- type RequestLogger struct {
317
- RequestID string
318
- Model string
319
- StartTime time.Time
320
- logs []string
321
- config *Config
322
- }
323
-
324
- func NewRequestLogger(config *Config) *RequestLogger {
325
- return &RequestLogger{
326
- RequestID: uuid.New().String(),
327
- StartTime: time.Now(),
328
- logs: make([]string, 0),
329
- config: config,
330
- }
331
- }
332
-
333
- // New function: substitute environment variables
334
- func replaceEnvVars(value string) string {
335
- // Check whether the value contains an environment variable placeholder ${...}
336
- if !strings.Contains(value, "${") {
337
- return value
338
- }
339
-
340
- // Find and replace every ${...}-style environment variable reference
341
- result := value
342
- for {
343
- start := strings.Index(result, "${")
344
- if start == -1 {
345
- break
346
- }
347
-
348
- end := strings.Index(result[start:], "}")
349
- if end == -1 {
350
- break
351
- }
352
- end += start
353
-
354
- // Extract the environment variable name
355
- envName := result[start+2:end]
356
-
357
- // Look up the environment variable; an unset variable yields an empty string
358
- envValue := os.Getenv(envName)
359
- if envValue == "" {
360
- // Optional: log a warning or substitute a default value here
361
- // As written, the placeholder is simply replaced with an empty string in this case
362
- }
363
-
364
- // Perform the substitution
365
- result = result[:start] + envValue + result[end+1:]
366
- }
367
-
368
- return result
369
- }
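A small usage sketch of the substitution above (the variable name already appears elsewhere in this file; its value here is a placeholder):

    os.Setenv("THINKING_SERVICE_API_KEY", "sk-placeholder")
    key := replaceEnvVars("Bearer ${THINKING_SERVICE_API_KEY}")
    // key == "Bearer sk-placeholder"; an unset variable is replaced
    // with the empty string by the loop above.
    _ = key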
370
-
371
- // Resolve environment variables referenced in the config struct
372
- func ProcessConfig(config *Config) {
373
- // Process ThinkingServices
374
- for i := range config.ThinkingServices {
375
- config.ThinkingServices[i].Name = replaceEnvVars(config.ThinkingServices[i].Name)
376
- config.ThinkingServices[i].Model = replaceEnvVars(config.ThinkingServices[i].Model)
377
- config.ThinkingServices[i].BaseURL = replaceEnvVars(config.ThinkingServices[i].BaseURL)
378
- config.ThinkingServices[i].APIKey = replaceEnvVars(config.ThinkingServices[i].APIKey)
379
- config.ThinkingServices[i].Proxy = replaceEnvVars(config.ThinkingServices[i].Proxy)
380
- }
381
-
382
- // Process Channels
383
- for key, channel := range config.Channels {
384
- channel.Name = replaceEnvVars(channel.Name)
385
- channel.BaseURL = replaceEnvVars(channel.BaseURL)
386
- channel.Proxy = replaceEnvVars(channel.Proxy)
387
- config.Channels[key] = channel
388
- }
389
-
390
- // If the Global config also contains environment variables, handle them here
391
- }
392
-
393
- func (l *RequestLogger) Log(format string, args ...interface{}) {
394
- msg := fmt.Sprintf(format, args...)
395
- l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
396
- log.Printf("[RequestID: %s] %s", l.RequestID, msg)
397
- }
398
-
399
- func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
400
- if !l.config.Global.Log.Debug.Enabled {
401
- return
402
- }
403
- sanitizedContent := sanitizeJSON(content)
404
- truncatedContent := truncateContent(sanitizedContent, maxLength)
405
- l.Log("%s Content:\n%s", contentType, truncatedContent)
406
- }
407
-
408
- func truncateContent(content string, maxLength int) string {
409
- if len(content) <= maxLength {
410
- return content
411
- }
412
- return content[:maxLength] + "... (truncated)"
413
- }
414
-
415
- func sanitizeJSON(data interface{}) string {
416
- sanitized, err := json.Marshal(data)
417
- if err != nil {
418
- return "Failed to marshal JSON"
419
- }
420
- content := string(sanitized)
421
- sensitivePattern := `"api_key":\s*"[^"]*"`
422
- content = regexp.MustCompile(sensitivePattern).ReplaceAllString(content, `"api_key":"****"`)
423
- return content
424
- }
425
-
426
- func extractRealAPIKey(fullKey string) string {
427
- parts := strings.Split(fullKey, "-")
428
- if len(parts) >= 3 && (parts[0] == "deep" || parts[0] == "openai") {
429
- return strings.Join(parts[2:], "-")
430
- }
431
- return fullKey
432
- }
433
-
434
- func extractChannelID(fullKey string) string {
435
- parts := strings.Split(fullKey, "-")
436
- if len(parts) >= 2 && (parts[0] == "deep" || parts[0] == "openai") {
437
- return parts[1]
438
- }
439
- return "1" // 默认渠道
440
- }
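Taken together, the two helpers above imply a composite key of the form <prefix>-<channelID>-<realKey>, where the prefix is "deep" or "openai". A sketch with a made-up key:

    full := "deep-2-sk-abc123"          // hypothetical composite key
    channelID := extractChannelID(full) // "2": selects config.Channels["2"]
    realKey := extractRealAPIKey(full)  // "sk-abc123": forwarded upstream
    // Keys without a recognized prefix fall back to channel "1" and are
    // forwarded unchanged.
    _, _ = channelID, realKey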
441
-
442
- func logAPIKey(key string) string {
443
- if len(key) <= 8 {
444
- return "****"
445
- }
446
- return key[:4] + "..." + key[len(key)-4:]
447
- }
448
-
449
- // ---------------------- Server structure ----------------------
450
-
451
- type Server struct {
452
- config *Config
453
- srv *http.Server
454
- }
455
-
456
- var (
457
- randMu sync.Mutex
458
- randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
459
- )
460
-
461
- func NewServer(config *Config) *Server {
462
- return &Server{
463
- config: config,
464
- }
465
- }
466
-
467
- func (s *Server) Start() error {
468
- mux := http.NewServeMux()
469
- mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
470
- mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
471
- mux.HandleFunc("/health", s.handleHealth)
472
-
473
- s.srv = &http.Server{
474
- Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
475
- Handler: mux,
476
- ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
477
- WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
478
- IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
479
- }
480
-
481
- log.Printf("Server starting on %s\n", s.srv.Addr)
482
- return s.srv.ListenAndServe()
483
- }
484
-
485
- func (s *Server) Shutdown(ctx context.Context) error {
486
- return s.srv.Shutdown(ctx)
487
- }
488
-
489
- func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
490
- w.WriteHeader(http.StatusOK)
491
- json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
492
- }
493
-
494
- func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
495
- logger := NewRequestLogger(s.config)
496
-
497
- fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
498
- apiKey := extractRealAPIKey(fullAPIKey)
499
- channelID := extractChannelID(fullAPIKey)
500
-
501
- logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
502
- logger.Log("Extracted channel ID: %s", channelID)
503
- logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))
504
-
505
- targetChannel, ok := s.config.Channels[channelID]
506
- if !ok {
507
- http.Error(w, "Invalid channel", http.StatusBadRequest)
508
- return
509
- }
510
-
511
- if r.URL.Path == "/v1/models" {
512
- if r.Method != http.MethodGet {
513
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
514
- return
515
- }
516
- req := &ChatCompletionRequest{APIKey: apiKey}
517
- s.forwardModelsRequest(w, r.Context(), req, targetChannel)
518
- return
519
- }
520
-
521
- if r.Method != http.MethodPost {
522
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
523
- return
524
- }
525
-
526
- body, err := io.ReadAll(r.Body)
527
- if err != nil {
528
- logger.Log("Error reading request body: %v", err)
529
- http.Error(w, "Failed to read request", http.StatusBadRequest)
530
- return
531
- }
532
- r.Body.Close()
533
- r.Body = io.NopCloser(bytes.NewBuffer(body))
534
-
535
- if s.config.Global.Log.Debug.PrintRequest {
536
- logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
537
- }
538
-
539
- var req ChatCompletionRequest
540
- if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
541
- http.Error(w, "Invalid request body", http.StatusBadRequest)
542
- return
543
- }
544
- req.APIKey = apiKey
545
-
546
- thinkingService := s.getWeightedRandomThinkingService()
547
- logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))
548
-
549
- if req.Stream {
550
- handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
551
- if err != nil {
552
- http.Error(w, "Streaming not supported", http.StatusInternalServerError)
553
- return
554
- }
555
- if err := handler.HandleRequest(r.Context(), &req); err != nil {
556
- logger.Log("Stream handler error: %v", err)
557
- }
558
- } else {
559
- thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
560
- if err != nil {
561
- logger.Log("Error processing thinking content: %v", err)
562
- http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
563
- return
564
- }
565
- enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
566
- s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
567
- }
568
- }
569
-
570
- func (s *Server) getWeightedRandomThinkingService() ThinkingService {
571
- thinkingServices := s.config.ThinkingServices
572
- if len(thinkingServices) == 0 {
573
- return ThinkingService{}
574
- }
575
- totalWeight := 0
576
- for _, svc := range thinkingServices {
577
- totalWeight += svc.Weight
578
- }
579
- if totalWeight <= 0 {
580
- log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
581
- return thinkingServices[0]
582
- }
583
- randMu.Lock()
584
- randNum := randGen.Intn(totalWeight)
585
- randMu.Unlock()
586
- currentSum := 0
587
- for _, svc := range thinkingServices {
588
- currentSum += svc.Weight
589
- if randNum < currentSum {
590
- return svc
591
- }
592
- }
593
- return thinkingServices[0]
594
- }
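A worked example of the cumulative-weight draw above, assuming two hypothetical services with weights 70 and 30:

    // totalWeight = 100, so randNum is drawn uniformly from [0, 100).
    // After the first service, currentSum = 70: randNum 0..69 selects it (p = 0.7).
    // After the second, currentSum = 100: randNum 70..99 selects it (p = 0.3).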
595
-
596
- // ---------------------- Non-streaming thinking-service handling ----------------------
597
-
598
- type ThinkingResponse struct {
599
- Content string
600
- ReasoningContent string
601
- }
602
-
603
- func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
604
- logger := NewRequestLogger(s.config)
605
- log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)
606
-
607
- thinkingReq := *req
608
- thinkingReq.Model = svc.Model
609
- thinkingReq.APIKey = svc.APIKey
610
-
611
- var systemPrompt string
612
- if svc.Mode == "full" {
613
- systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
614
- } else {
615
- systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
616
- }
617
- thinkingReq.Messages = append([]ChatCompletionMessage{
618
- {Role: "system", Content: systemPrompt},
619
- }, thinkingReq.Messages...)
620
-
621
- temp := 0.7
622
- if svc.Temperature != nil {
623
- temp = *svc.Temperature
624
- }
625
- payload := map[string]interface{}{
626
- "model": svc.Model,
627
- "messages": thinkingReq.Messages,
628
- "stream": false,
629
- "temperature": temp,
630
- }
631
- if isValidReasoningEffort(svc.ReasoningEffort) {
632
- payload["reasoning_effort"] = svc.ReasoningEffort
633
- }
634
- if isValidReasoningFormat(svc.ReasoningFormat) {
635
- payload["reasoning_format"] = svc.ReasoningFormat
636
- }
637
-
638
- if s.config.Global.Log.Debug.PrintRequest {
639
- logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
640
- }
641
-
642
- jsonData, err := json.Marshal(payload)
643
- if err != nil {
644
- return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
645
- }
646
- client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
647
- if err != nil {
648
- return nil, fmt.Errorf("failed to create HTTP client: %v", err)
649
- }
650
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
651
- if err != nil {
652
- return nil, fmt.Errorf("failed to create request: %v", err)
653
- }
654
- httpReq.Header.Set("Content-Type", "application/json")
655
- httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)
656
-
657
- resp, err := client.Do(httpReq)
658
- if err != nil {
659
- return nil, fmt.Errorf("failed to send thinking request: %v", err)
660
- }
661
- defer resp.Body.Close()
662
-
663
- respBody, err := io.ReadAll(resp.Body)
664
- if err != nil {
665
- return nil, fmt.Errorf("failed to read response body: %v", err)
666
- }
667
- if s.config.Global.Log.Debug.PrintResponse {
668
- logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
669
- }
670
- if resp.StatusCode != http.StatusOK {
671
- return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
672
- }
673
-
674
- var thinkingResp ChatCompletionResponse
675
- if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
676
- return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
677
- }
678
- if len(thinkingResp.Choices) == 0 {
679
- return nil, fmt.Errorf("thinking service returned no choices")
680
- }
681
-
682
- result := &ThinkingResponse{}
683
- choice := thinkingResp.Choices[0]
684
-
685
- if svc.Mode == "full" {
686
- result.ReasoningContent = choice.Message.Content
687
- result.Content = "Based on the above detailed analysis."
688
- } else {
689
- if choice.Message.ReasoningContent != nil {
690
- switch v := choice.Message.ReasoningContent.(type) {
691
- case string:
692
- result.ReasoningContent = v
693
- case map[string]interface{}:
694
- if j, err := json.Marshal(v); err == nil {
695
- result.ReasoningContent = string(j)
696
- }
697
- }
698
- }
699
- if result.ReasoningContent == "" {
700
- result.ReasoningContent = choice.Message.Content
701
- }
702
- result.Content = "Based on the above reasoning."
703
- }
704
- return result, nil
705
- }
706
-
707
- func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
708
- newReq := *originalReq
709
- var systemPrompt string
710
- if svc.Mode == "full" {
711
- systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
712
- %s
713
-
714
- Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
715
- } else {
716
- systemPrompt = fmt.Sprintf(`Previous thinking process:
717
- %s
718
- Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
719
- }
720
- newReq.Messages = append([]ChatCompletionMessage{
721
- {Role: "system", Content: systemPrompt},
722
- }, newReq.Messages...)
723
- return &newReq
724
- }
725
-
726
- func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
727
- logger := NewRequestLogger(s.config)
728
- if s.config.Global.Log.Debug.PrintRequest {
729
- logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
730
- }
731
- jsonData, err := json.Marshal(req)
732
- if err != nil {
733
- http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
734
- return
735
- }
736
-
737
- client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
738
- if err != nil {
739
- http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
740
- return
741
- }
742
-
743
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
744
- if err != nil {
745
- http.Error(w, "Failed to create request", http.StatusInternalServerError)
746
- return
747
- }
748
- httpReq.Header.Set("Content-Type", "application/json")
749
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
750
-
751
- resp, err := client.Do(httpReq)
752
- if err != nil {
753
- http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
754
- return
755
- }
756
- defer resp.Body.Close()
757
-
758
- respBody, err := io.ReadAll(resp.Body)
759
- if err != nil {
760
- http.Error(w, "Failed to read response", http.StatusInternalServerError)
761
- return
762
- }
763
- if s.config.Global.Log.Debug.PrintResponse {
764
- logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
765
- }
766
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
767
- http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
768
- return
769
- }
770
-
771
- for k, vals := range resp.Header {
772
- for _, v := range vals {
773
- w.Header().Add(k, v)
774
- }
775
- }
776
- w.WriteHeader(resp.StatusCode)
777
- w.Write(respBody)
778
- }
779
-
780
- func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
781
- logger := NewRequestLogger(s.config)
782
- if s.config.Global.Log.Debug.PrintRequest {
783
- logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
784
- }
785
- fullChatURL := targetChannel.GetFullURL()
786
- parsedURL, err := url.Parse(fullChatURL)
787
- if err != nil {
788
- http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
789
- return
790
- }
791
- baseURL := parsedURL.Scheme + "://" + parsedURL.Host
792
- modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"
793
-
794
- client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
795
- if err != nil {
796
- http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
797
- return
798
- }
799
-
800
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
801
- if err != nil {
802
- http.Error(w, "Failed to create request", http.StatusInternalServerError)
803
- return
804
- }
805
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
806
-
807
- resp, err := client.Do(httpReq)
808
- if err != nil {
809
- http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
810
- return
811
- }
812
- defer resp.Body.Close()
813
-
814
- respBody, err := io.ReadAll(resp.Body)
815
- if err != nil {
816
- http.Error(w, "Failed to read response", http.StatusInternalServerError)
817
- return
818
- }
819
- if s.config.Global.Log.Debug.PrintResponse {
820
- logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
821
- }
822
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
823
- http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
824
- return
825
- }
826
-
827
- for k, vals := range resp.Header {
828
- for _, v := range vals {
829
- w.Header().Add(k, v)
830
- }
831
- }
832
- w.WriteHeader(resp.StatusCode)
833
- w.Write(respBody)
834
- }
835
-
836
- // ---------------------- Streaming handling ----------------------
837
-
838
- // collectedReasoningBuffer accumulates the reasoning_content returned by the thinking service (standard mode collects only reasoning_content; full mode collects everything)
839
- type collectedReasoningBuffer struct {
840
- builder strings.Builder
841
- mode string
842
- }
843
-
844
- func (cb *collectedReasoningBuffer) append(text string) {
845
- cb.builder.WriteString(text)
846
- }
847
-
848
- func (cb *collectedReasoningBuffer) get() string {
849
- return cb.builder.String()
850
- }
851
-
852
- type StreamHandler struct {
853
- thinkingService ThinkingService
854
- targetChannel Channel
855
- writer http.ResponseWriter
856
- flusher http.Flusher
857
- config *Config
858
- }
859
-
860
- func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
861
- flusher, ok := w.(http.Flusher)
862
- if !ok {
863
- return nil, fmt.Errorf("streaming not supported by this writer")
864
- }
865
- return &StreamHandler{
866
- thinkingService: thinkingService,
867
- targetChannel: targetChannel,
868
- writer: w,
869
- flusher: flusher,
870
- config: config,
871
- }, nil
872
- }
873
-
874
- // HandleRequest runs in two phases: first it streams the thinking service's SSE and forwards it to the client (only chunks containing reasoning_content are forwarded; in standard mode a plain content chunk interrupts the stream);
875
- // then it builds the final request from the collected reasoning_content, opens the target channel's SSE, and forwards that to the client.
876
- func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
877
- h.writer.Header().Set("Content-Type", "text/event-stream")
878
- h.writer.Header().Set("Cache-Control", "no-cache")
879
- h.writer.Header().Set("Connection", "keep-alive")
880
-
881
- logger := NewRequestLogger(h.config)
882
- reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}
883
-
884
- // Phase one: stream the thinking service SSE
885
- if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
886
- return err
887
- }
888
-
889
- // Phase two: build the final request and stream the target channel SSE
890
- finalReq := h.prepareFinalRequest(req, reasonBuf.get())
891
- if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
892
- return err
893
- }
894
- return nil
895
- }
896
-
897
- // streamThinkingService connects to the thinking service SSE, forwards chunks containing reasoning_content to the client unchanged, and collects the reasoning_content as it goes.
898
- // In standard mode, as soon as a chunk carries only non-empty content (with empty reasoning_content), the stream is interrupted immediately and that chunk is not forwarded.
899
- func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
900
- thinkingReq := *req
901
- thinkingReq.Model = h.thinkingService.Model
902
- thinkingReq.APIKey = h.thinkingService.APIKey
903
-
904
- var systemPrompt string
905
- if h.thinkingService.Mode == "full" {
906
- systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
907
- } else {
908
- systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
909
- }
910
- messages := append([]ChatCompletionMessage{
911
- {Role: "system", Content: systemPrompt},
912
- }, thinkingReq.Messages...)
913
-
914
- temp := 0.7
915
- if h.thinkingService.Temperature != nil {
916
- temp = *h.thinkingService.Temperature
917
- }
918
- bodyMap := map[string]interface{}{
919
- "model": thinkingReq.Model,
920
- "messages": messages,
921
- "temperature": temp,
922
- "stream": true,
923
- }
924
- if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
925
- bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
926
- }
927
- if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
928
- bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
929
- }
930
-
931
- jsonData, err := json.Marshal(bodyMap)
932
- if err != nil {
933
- return err
934
- }
935
-
936
- client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
937
- if err != nil {
938
- return err
939
- }
940
-
941
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
942
- if err != nil {
943
- return err
944
- }
945
- httpReq.Header.Set("Content-Type", "application/json")
946
- httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)
947
-
948
- log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
949
- resp, err := client.Do(httpReq)
950
- if err != nil {
951
- return err
952
- }
953
- defer resp.Body.Close()
954
-
955
- if resp.StatusCode != http.StatusOK {
956
- b, _ := io.ReadAll(resp.Body)
957
- return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
958
- }
959
-
960
- reader := bufio.NewReader(resp.Body)
961
- var lastLine string
962
- // Whether the thinking SSE should be interrupted (standard mode)
963
- forceStop := false
964
-
965
- for {
966
- select {
967
- case <-ctx.Done():
968
- return ctx.Err()
969
- default:
970
- }
971
- line, err := reader.ReadString('\n')
972
- if err != nil {
973
- if err == io.EOF {
974
- break
975
- }
976
- return err
977
- }
978
- line = strings.TrimSpace(line)
979
- if line == "" || line == lastLine {
980
- continue
981
- }
982
- lastLine = line
983
-
984
- if !strings.HasPrefix(line, "data: ") {
985
- continue
986
- }
987
- dataPart := strings.TrimPrefix(line, "data: ")
988
- if dataPart == "[DONE]" {
989
- // Thinking service finished: break out directly (no special marker is forwarded)
990
- break
991
- }
992
-
993
- var chunk struct {
994
- Choices []struct {
995
- Delta struct {
996
- Content string `json:"content,omitempty"`
997
- ReasoningContent string `json:"reasoning_content,omitempty"`
998
- } `json:"delta"`
999
- FinishReason *string `json:"finish_reason,omitempty"`
1000
- } `json:"choices"`
1001
- }
1002
- if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
1003
- // If parsing fails, forward the raw line unchanged
1004
- h.writer.Write([]byte("data: " + dataPart + "\n\n"))
1005
- h.flusher.Flush()
1006
- continue
1007
- }
1008
-
1009
- if len(chunk.Choices) > 0 {
1010
- c := chunk.Choices[0]
1011
- if h.config.Global.Log.Debug.PrintResponse {
1012
- logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
1013
- }
1014
-
1015
- if h.thinkingService.Mode == "full" {
1016
- // full mode: collect everything
1017
- if c.Delta.ReasoningContent != "" {
1018
- reasonBuf.append(c.Delta.ReasoningContent)
1019
- }
1020
- if c.Delta.Content != "" {
1021
- reasonBuf.append(c.Delta.Content)
1022
- }
1023
- // Forward the whole chunk unchanged
1024
- forwardLine := "data: " + dataPart + "\n\n"
1025
- h.writer.Write([]byte(forwardLine))
1026
- h.flusher.Flush()
1027
- } else {
1028
- // standard mode: collect only reasoning_content
1029
- if c.Delta.ReasoningContent != "" {
1030
- reasonBuf.append(c.Delta.ReasoningContent)
1031
- // Forward this chunk
1032
- forwardLine := "data: " + dataPart + "\n\n"
1033
- h.writer.Write([]byte(forwardLine))
1034
- h.flusher.Flush()
1035
- }
1036
- // A chunk that carries only content (with empty reasoning_content) marks the end of the reasoning chain; do not forward it
1037
- if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
1038
- forceStop = true
1039
- }
1040
- }
1041
-
1042
- // A non-empty finish_reason also ends the stream
1043
- if c.FinishReason != nil && *c.FinishReason != "" {
1044
- forceStop = true
1045
- }
1046
- }
1047
-
1048
- if forceStop {
1049
- break
1050
- }
1051
- }
1052
- // Drain any remaining data
1053
- io.Copy(io.Discard, reader)
1054
- return nil
1055
- }
1056
-
1057
- func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
1058
- req := *originalReq
1059
- var systemPrompt string
1060
- if h.thinkingService.Mode == "full" {
1061
- systemPrompt = fmt.Sprintf(
1062
- `Consider the following detailed analysis (not shown to user):
1063
- %s
1064
-
1065
- Provide a clear, concise response that incorporates insights from this analysis.`,
1066
- reasoningCollected,
1067
- )
1068
- } else {
1069
- systemPrompt = fmt.Sprintf(
1070
- `Previous thinking process:
1071
- %s
1072
- Please consider the above thinking process in your response.`,
1073
- reasoningCollected,
1074
- )
1075
- }
1076
- req.Messages = append([]ChatCompletionMessage{
1077
- {Role: "system", Content: systemPrompt},
1078
- }, req.Messages...)
1079
- return &req
1080
- }
1081
-
1082
- func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
1083
- if h.config.Global.Log.Debug.PrintRequest {
1084
- logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
1085
- }
1086
-
1087
- jsonData, err := json.Marshal(req)
1088
- if err != nil {
1089
- return err
1090
- }
1091
- client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
1092
- if err != nil {
1093
- return err
1094
- }
1095
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
1096
- if err != nil {
1097
- return err
1098
- }
1099
- httpReq.Header.Set("Content-Type", "application/json")
1100
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
1101
-
1102
- resp, err := client.Do(httpReq)
1103
- if err != nil {
1104
- return err
1105
- }
1106
- defer resp.Body.Close()
1107
-
1108
- if resp.StatusCode != http.StatusOK {
1109
- b, _ := io.ReadAll(resp.Body)
1110
- return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
1111
- }
1112
-
1113
- reader := bufio.NewReader(resp.Body)
1114
- var lastLine string
1115
-
1116
- for {
1117
- select {
1118
- case <-ctx.Done():
1119
- return ctx.Err()
1120
- default:
1121
- }
1122
- line, err := reader.ReadString('\n')
1123
- if err != nil {
1124
- if err == io.EOF {
1125
- break
1126
- }
1127
- return err
1128
- }
1129
- line = strings.TrimSpace(line)
1130
- if line == "" || line == lastLine {
1131
- continue
1132
- }
1133
- lastLine = line
1134
-
1135
- if !strings.HasPrefix(line, "data: ") {
1136
- continue
1137
- }
1138
- data := strings.TrimPrefix(line, "data: ")
1139
- if data == "[DONE]" {
1140
- doneLine := "data: [DONE]\n\n"
1141
- h.writer.Write([]byte(doneLine))
1142
- h.flusher.Flush()
1143
- break
1144
- }
1145
- forwardLine := "data: " + data + "\n\n"
1146
- h.writer.Write([]byte(forwardLine))
1147
- h.flusher.Flush()
1148
-
1149
- if h.config.Global.Log.Debug.PrintResponse {
1150
- logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
1151
- }
1152
- }
1153
- io.Copy(io.Discard, reader)
1154
- return nil
1155
- }
1156
-
1157
- // ---------------------- HTTP client utilities ----------------------
1158
-
1159
- func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
1160
- transport := &http.Transport{
1161
- DialContext: (&net.Dialer{
1162
- Timeout: 30 * time.Second,
1163
- KeepAlive: 30 * time.Second,
1164
- }).DialContext,
1165
- ForceAttemptHTTP2: true,
1166
- MaxIdleConns: 100,
1167
- IdleConnTimeout: 90 * time.Second,
1168
- TLSHandshakeTimeout: 10 * time.Second,
1169
- ExpectContinueTimeout: 1 * time.Second,
1170
- }
1171
- if proxyURL != "" {
1172
- parsedURL, err := url.Parse(proxyURL)
1173
- if err != nil {
1174
- return nil, fmt.Errorf("invalid proxy URL: %v", err)
1175
- }
1176
- switch parsedURL.Scheme {
1177
- case "http", "https":
1178
- transport.Proxy = http.ProxyURL(parsedURL)
1179
- case "socks5":
1180
- dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
1181
- if err != nil {
1182
- return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
1183
- }
1184
- transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1185
- return dialer.Dial(network, addr)
1186
- }
1187
- default:
1188
- return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
1189
- }
1190
- }
1191
- return &http.Client{
1192
- Transport: transport,
1193
- Timeout: timeout,
1194
- }, nil
1195
- }
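A usage sketch for the helper above; the proxy address is a placeholder, and only http, https, and socks5 schemes are accepted:

    // Direct connection: an empty proxy string leaves transport.Proxy unset.
    direct, err := createHTTPClient("", 30*time.Second)
    if err != nil {
        log.Fatalf("client error: %v", err)
    }
    // Via a hypothetical local SOCKS5 proxy.
    proxied, err := createHTTPClient("socks5://127.0.0.1:1080", 30*time.Second)
    if err != nil {
        log.Fatalf("proxy client error: %v", err)
    }
    _, _ = direct, proxied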
1196
-
1197
- func maskSensitiveHeaders(headers http.Header) http.Header {
1198
- masked := make(http.Header)
1199
- for k, vals := range headers {
1200
- if strings.ToLower(k) == "authorization" {
1201
- masked[k] = []string{"Bearer ****"}
1202
- } else {
1203
- masked[k] = vals
1204
- }
1205
- }
1206
- return masked
1207
- }
1208
-
1209
- func isValidReasoningEffort(effort string) bool {
1210
- switch strings.ToLower(effort) {
1211
- case "low", "medium", "high":
1212
- return true
1213
- }
1214
- return false
1215
- }
1216
-
1217
- func isValidReasoningFormat(format string) bool {
1218
- switch strings.ToLower(format) {
1219
- case "parsed", "raw", "hidden":
1220
- return true
1221
- }
1222
- return false
1223
- }
1224
-
1225
- // ---------------------- Config loading and validation ----------------------
1226
-
1227
- func loadConfig() (*Config, error) {
1228
- var configFile string
1229
- flag.StringVar(&configFile, "config", "", "path to config file")
1230
- flag.Parse()
1231
- viper.SetConfigType("yaml")
1232
-
1233
- if configFile != "" {
1234
- viper.SetConfigFile(configFile)
1235
- } else {
1236
- ex, err := os.Executable()
1237
- if err != nil {
1238
- return nil, err
1239
- }
1240
- exePath := filepath.Dir(ex)
1241
- defaultPaths := []string{
1242
- filepath.Join(exePath, "config.yaml"),
1243
- filepath.Join(exePath, "conf", "config.yaml"),
1244
- "./config.yaml",
1245
- "./conf/config.yaml",
1246
- }
1247
- if os.PathSeparator == '\\' {
1248
- programData := os.Getenv("PROGRAMDATA")
1249
- if programData != "" {
1250
- defaultPaths = append(defaultPaths, filepath.Join(programData, "DeepAI", "config.yaml"))
1251
- }
1252
- } else {
1253
- defaultPaths = append(defaultPaths, "/etc/deepai/config.yaml")
1254
- }
1255
- for _, p := range defaultPaths {
1256
- viper.AddConfigPath(filepath.Dir(p))
1257
- if strings.Contains(p, ".yaml") {
1258
- viper.SetConfigName(strings.TrimSuffix(filepath.Base(p), ".yaml"))
1259
- }
1260
- }
1261
- }
1262
-
1263
- if err := viper.ReadInConfig(); err != nil {
1264
- return nil, fmt.Errorf("failed to read config: %v", err)
1265
- }
1266
-
1267
- var cfg Config
1268
- if err := viper.Unmarshal(&cfg); err != nil {
1269
- return nil, fmt.Errorf("unmarshal config error: %v", err)
1270
- }
1271
- if err := validateConfig(&cfg); err != nil {
1272
- return nil, fmt.Errorf("config validation error: %v", err)
1273
- }
1274
- return &cfg, nil
1275
- }
1276
-
1277
- func validateConfig(config *Config) error {
1278
- if len(config.ThinkingServices) == 0 {
1279
- return fmt.Errorf("no thinking services configured")
1280
- }
1281
- if len(config.Channels) == 0 {
1282
- return fmt.Errorf("no channels configured")
1283
- }
1284
- for i, svc := range config.ThinkingServices {
1285
- if svc.BaseURL == "" {
1286
- return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
1287
- }
1288
- if svc.APIKey == "" {
1289
- return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
1290
- }
1291
- if svc.Timeout <= 0 {
1292
- return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
1293
- }
1294
- if svc.Model == "" {
1295
- return fmt.Errorf("thinking service %s has empty model", svc.Name)
1296
- }
1297
- if svc.Mode == "" {
1298
- config.ThinkingServices[i].Mode = "standard"
1299
- } else if svc.Mode != "standard" && svc.Mode != "full" {
1300
- return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
1301
- }
1302
- }
1303
- return nil
1304
- }
1305
-
1306
- func main() {
1307
- log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
1308
-
1309
- cfg, err := loadConfig()
1310
- if err != nil {
1311
- log.Fatalf("Failed to load config: %v", err)
1312
- }
1313
- log.Printf("Using config file: %s", viper.ConfigFileUsed())
1314
-
1315
- server := NewServer(cfg)
1316
-
1317
- done := make(chan os.Signal, 1)
1318
- signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
1319
-
1320
- go func() {
1321
- if err := server.Start(); err != nil && err != http.ErrServerClosed {
1322
- log.Fatalf("start server error: %v", err)
1323
- }
1324
- }()
1325
- log.Printf("Server started successfully")
1326
-
1327
- <-done
1328
- log.Print("Server stopping...")
1329
-
1330
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1331
- defer cancel()
1332
-
1333
- if err := server.Shutdown(ctx); err != nil {
1334
- log.Printf("Server forced to shutdown: %v", err)
1335
- }
1336
- log.Print("Server stopped")
1337
- }