nbugs commited on
Commit
96c8fad
·
verified ·
1 Parent(s): 59b3d84

Delete main.go

Browse files
Files changed (1) hide show
  1. main.go +0 -1181
main.go DELETED
@@ -1,1181 +0,0 @@
1
- package main
2
-
3
- import (
4
- "bufio"
5
- "bytes"
6
- "context"
7
- "encoding/json"
8
- "flag"
9
- "fmt"
10
- "io"
11
- "log"
12
- "math/rand"
13
- "net"
14
- "net/http"
15
- "net/url"
16
- "os"
17
- "os/signal"
18
- "path/filepath"
19
- "regexp"
20
- "strings"
21
- "sync"
22
- "syscall"
23
- "time"
24
-
25
- "github.com/google/uuid"
26
- "github.com/spf13/viper"
27
- "golang.org/x/net/proxy"
28
- )
29
-
30
- // ---------------------- 配置结构 ----------------------
31
-
32
- type Config struct {
33
- ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
34
- Channels map[string]Channel `mapstructure:"channels"`
35
- Global GlobalConfig `mapstructure:"global"`
36
- }
37
-
38
- type ThinkingService struct {
39
- ID int `mapstructure:"id"`
40
- Name string `mapstructure:"name"`
41
- Model string `mapstructure:"model"`
42
- BaseURL string `mapstructure:"base_url"`
43
- APIPath string `mapstructure:"api_path"`
44
- APIKey string `mapstructure:"api_key"`
45
- Timeout int `mapstructure:"timeout"`
46
- Retry int `mapstructure:"retry"`
47
- Weight int `mapstructure:"weight"`
48
- Proxy string `mapstructure:"proxy"`
49
- Mode string `mapstructure:"mode"` // "standard" 或 "full"
50
- ReasoningEffort string `mapstructure:"reasoning_effort"`
51
- ReasoningFormat string `mapstructure:"reasoning_format"`
52
- Temperature *float64 `mapstructure:"temperature"`
53
- ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"` // 配置项:标准模式下遇到 content 时是否立即停止
54
- }
55
-
56
- func (s *ThinkingService) GetFullURL() string {
57
- return s.BaseURL + s.APIPath
58
- }
59
-
60
- type Channel struct {
61
- Name string `mapstructure:"name"`
62
- BaseURL string `mapstructure:"base_url"`
63
- APIPath string `mapstructure:"api_path"`
64
- Timeout int `mapstructure:"timeout"`
65
- Proxy string `mapstructure:"proxy"`
66
- }
67
-
68
- func (c *Channel) GetFullURL() string {
69
- return c.BaseURL + c.APIPath
70
- }
71
-
72
- type LogConfig struct {
73
- Level string `mapstructure:"level"`
74
- Format string `mapstructure:"format"`
75
- Output string `mapstructure:"output"`
76
- FilePath string `mapstructure:"file_path"`
77
- Debug DebugConfig `mapstructure:"debug"`
78
- }
79
-
80
- type DebugConfig struct {
81
- Enabled bool `mapstructure:"enabled"`
82
- PrintRequest bool `mapstructure:"print_request"`
83
- PrintResponse bool `mapstructure:"print_response"`
84
- MaxContentLength int `mapstructure:"max_content_length"`
85
- }
86
-
87
- type ProxyConfig struct {
88
- Enabled bool `mapstructure:"enabled"`
89
- Default string `mapstructure:"default"`
90
- AllowInsecure bool `mapstructure:"allow_insecure"`
91
- }
92
-
93
- type GlobalConfig struct {
94
- MaxRetries int `mapstructure:"max_retries"`
95
- DefaultTimeout int `mapstructure:"default_timeout"`
96
- ErrorCodes struct {
97
- RetryOn []int `mapstructure:"retry_on"`
98
- } `mapstructure:"error_codes"`
99
- Log LogConfig `mapstructure:"log"`
100
- Server ServerConfig `mapstructure:"server"`
101
- Proxy ProxyConfig `mapstructure:"proxy"`
102
- ConfigPaths []string `mapstructure:"config_paths"`
103
- Thinking ThinkingConfig `mapstructure:"thinking"`
104
- }
105
-
106
- type ServerConfig struct {
107
- Port int `mapstructure:"port"`
108
- Host string `mapstructure:"host"`
109
- ReadTimeout int `mapstructure:"read_timeout"`
110
- WriteTimeout int `mapstructure:"write_timeout"`
111
- IdleTimeout int `mapstructure:"idle_timeout"`
112
- }
113
-
114
- type ThinkingConfig struct {
115
- Enabled bool `mapstructure:"enabled"`
116
- AddToAllRequests bool `mapstructure:"add_to_all_requests"`
117
- Timeout int `mapstructure:"timeout"`
118
- }
119
-
120
- // ---------------------- API 相关结构 ----------------------
121
-
122
- type ChatCompletionRequest struct {
123
- Model string `json:"model"`
124
- Messages []ChatCompletionMessage `json:"messages"`
125
- Temperature float64 `json:"temperature,omitempty"`
126
- MaxTokens int `json:"max_tokens,omitempty"`
127
- Stream bool `json:"stream,omitempty"`
128
- APIKey string `json:"-"` // 内部传递,不序列化
129
- }
130
-
131
- type ChatCompletionMessage struct {
132
- Role string `json:"role"`
133
- Content string `json:"content"`
134
- ReasoningContent interface{} `json:"reasoning_content,omitempty"`
135
- }
136
-
137
- type ChatCompletionResponse struct {
138
- ID string `json:"id"`
139
- Object string `json:"object"`
140
- Created int64 `json:"created"`
141
- Model string `json:"model"`
142
- Choices []Choice `json:"choices"`
143
- Usage Usage `json:"usage"`
144
- }
145
-
146
- type Choice struct {
147
- Index int `json:"index"`
148
- Message ChatCompletionMessage `json:"message"`
149
- FinishReason string `json:"finish_reason"`
150
- }
151
-
152
- type Usage struct {
153
- PromptTokens int `json:"prompt_tokens"`
154
- CompletionTokens int `json:"completion_tokens"`
155
- TotalTokens int `json:"total_tokens"`
156
- }
157
-
158
- // ---------------------- 日志工具 ----------------------
159
-
160
- type RequestLogger struct {
161
- RequestID string
162
- Model string
163
- StartTime time.Time
164
- logs []string
165
- config *Config
166
- }
167
-
168
- func NewRequestLogger(config *Config) *RequestLogger {
169
- return &RequestLogger{
170
- RequestID: uuid.New().String(),
171
- StartTime: time.Now(),
172
- logs: make([]string, 0),
173
- config: config,
174
- }
175
- }
176
-
177
- // 新增函数:替换环境变量
178
- func replaceEnvVars(value string) string {
179
- // 检查是否包含环境变量占位符 ${...}
180
- if !strings.Contains(value, "${") {
181
- return value
182
- }
183
-
184
- // 查找所有 ${...} 格式的环境变量并替换
185
- result := value
186
- for {
187
- start := strings.Index(result, "${")
188
- if start == -1 {
189
- break
190
- }
191
-
192
- end := strings.Index(result[start:], "}")
193
- if end == -1 {
194
- break
195
- }
196
- end += start
197
-
198
- // 提取环境变量名称
199
- envName := result[start+2:end]
200
-
201
- // 获取环境变量值,如果不存在则保持原样
202
- envValue := os.Getenv(envName)
203
- if envValue == "" {
204
- // 可选:记录警告或使用默认值
205
- // 这里我们保持原样,也可以设置为空字符串或其他默认值
206
- }
207
-
208
- // 替换环境变量
209
- result = result[:start] + envValue + result[end+1:]
210
- }
211
-
212
- return result
213
- }
214
-
215
- // 处理配置结构体中的环境变量
216
- func ProcessConfig(config *Config) {
217
- // 处理 ThinkingServices
218
- for i := range config.ThinkingServices {
219
- config.ThinkingServices[i].Name = replaceEnvVars(config.ThinkingServices[i].Name)
220
- config.ThinkingServices[i].Model = replaceEnvVars(config.ThinkingServices[i].Model)
221
- config.ThinkingServices[i].BaseURL = replaceEnvVars(config.ThinkingServices[i].BaseURL)
222
- config.ThinkingServices[i].APIKey = replaceEnvVars(config.ThinkingServices[i].APIKey)
223
- config.ThinkingServices[i].Proxy = replaceEnvVars(config.ThinkingServices[i].Proxy)
224
- }
225
-
226
- // 处理 Channels
227
- for key, channel := range config.Channels {
228
- channel.Name = replaceEnvVars(channel.Name)
229
- channel.BaseURL = replaceEnvVars(channel.BaseURL)
230
- channel.Proxy = replaceEnvVars(channel.Proxy)
231
- config.Channels[key] = channel
232
- }
233
-
234
- // 如果 Global 配置中也有环境变量,可以在这里添加处理逻辑
235
- }
236
-
237
- func (l *RequestLogger) Log(format string, args ...interface{}) {
238
- msg := fmt.Sprintf(format, args...)
239
- l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
240
- log.Printf("[RequestID: %s] %s", l.RequestID, msg)
241
- }
242
-
243
- func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
244
- if !l.config.Global.Log.Debug.Enabled {
245
- return
246
- }
247
- sanitizedContent := sanitizeJSON(content)
248
- truncatedContent := truncateContent(sanitizedContent, maxLength)
249
- l.Log("%s Content:\n%s", contentType, truncatedContent)
250
- }
251
-
252
- func truncateContent(content string, maxLength int) string {
253
- if len(content) <= maxLength {
254
- return content
255
- }
256
- return content[:maxLength] + "... (truncated)"
257
- }
258
-
259
- func sanitizeJSON(data interface{}) string {
260
- sanitized, err := json.Marshal(data)
261
- if err != nil {
262
- return "Failed to marshal JSON"
263
- }
264
- content := string(sanitized)
265
- sensitivePattern := `"api_key":\s*"[^"]*"`
266
- content = regexp.MustCompile(sensitivePattern).ReplaceAllString(content, `"api_key":"****"`)
267
- return content
268
- }
269
-
270
- func extractRealAPIKey(fullKey string) string {
271
- parts := strings.Split(fullKey, "-")
272
- if len(parts) >= 3 && (parts[0] == "deep" || parts[0] == "openai") {
273
- return strings.Join(parts[2:], "-")
274
- }
275
- return fullKey
276
- }
277
-
278
- func extractChannelID(fullKey string) string {
279
- parts := strings.Split(fullKey, "-")
280
- if len(parts) >= 2 && (parts[0] == "deep" || parts[0] == "openai") {
281
- return parts[1]
282
- }
283
- return "1" // 默认渠道
284
- }
285
-
286
- func logAPIKey(key string) string {
287
- if len(key) <= 8 {
288
- return "****"
289
- }
290
- return key[:4] + "..." + key[len(key)-4:]
291
- }
292
-
293
- // ---------------------- Server 结构 ----------------------
294
-
295
- type Server struct {
296
- config *Config
297
- srv *http.Server
298
- }
299
-
300
- var (
301
- randMu sync.Mutex
302
- randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
303
- )
304
-
305
- func NewServer(config *Config) *Server {
306
- return &Server{
307
- config: config,
308
- }
309
- }
310
-
311
- func (s *Server) Start() error {
312
- mux := http.NewServeMux()
313
- mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
314
- mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
315
- mux.HandleFunc("/health", s.handleHealth)
316
-
317
- s.srv = &http.Server{
318
- Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
319
- Handler: mux,
320
- ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
321
- WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
322
- IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
323
- }
324
-
325
- log.Printf("Server starting on %s\n", s.srv.Addr)
326
- return s.srv.ListenAndServe()
327
- }
328
-
329
- func (s *Server) Shutdown(ctx context.Context) error {
330
- return s.srv.Shutdown(ctx)
331
- }
332
-
333
- func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
334
- w.WriteHeader(http.StatusOK)
335
- json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
336
- }
337
-
338
- func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
339
- logger := NewRequestLogger(s.config)
340
-
341
- fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
342
- apiKey := extractRealAPIKey(fullAPIKey)
343
- channelID := extractChannelID(fullAPIKey)
344
-
345
- logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
346
- logger.Log("Extracted channel ID: %s", channelID)
347
- logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))
348
-
349
- targetChannel, ok := s.config.Channels[channelID]
350
- if !ok {
351
- http.Error(w, "Invalid channel", http.StatusBadRequest)
352
- return
353
- }
354
-
355
- if r.URL.Path == "/v1/models" {
356
- if r.Method != http.MethodGet {
357
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
358
- return
359
- }
360
- req := &ChatCompletionRequest{APIKey: apiKey}
361
- s.forwardModelsRequest(w, r.Context(), req, targetChannel)
362
- return
363
- }
364
-
365
- if r.Method != http.MethodPost {
366
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
367
- return
368
- }
369
-
370
- body, err := io.ReadAll(r.Body)
371
- if err != nil {
372
- logger.Log("Error reading request body: %v", err)
373
- http.Error(w, "Failed to read request", http.StatusBadRequest)
374
- return
375
- }
376
- r.Body.Close()
377
- r.Body = io.NopCloser(bytes.NewBuffer(body))
378
-
379
- if s.config.Global.Log.Debug.PrintRequest {
380
- logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
381
- }
382
-
383
- var req ChatCompletionRequest
384
- if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
385
- http.Error(w, "Invalid request body", http.StatusBadRequest)
386
- return
387
- }
388
- req.APIKey = apiKey
389
-
390
- thinkingService := s.getWeightedRandomThinkingService()
391
- logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))
392
-
393
- if req.Stream {
394
- handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
395
- if err != nil {
396
- http.Error(w, "Streaming not supported", http.StatusInternalServerError)
397
- return
398
- }
399
- if err := handler.HandleRequest(r.Context(), &req); err != nil {
400
- logger.Log("Stream handler error: %v", err)
401
- }
402
- } else {
403
- thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
404
- if err != nil {
405
- logger.Log("Error processing thinking content: %v", err)
406
- http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
407
- return
408
- }
409
- enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
410
- s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
411
- }
412
- }
413
-
414
- func (s *Server) getWeightedRandomThinkingService() ThinkingService {
415
- thinkingServices := s.config.ThinkingServices
416
- if len(thinkingServices) == 0 {
417
- return ThinkingService{}
418
- }
419
- totalWeight := 0
420
- for _, svc := range thinkingServices {
421
- totalWeight += svc.Weight
422
- }
423
- if totalWeight <= 0 {
424
- log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
425
- return thinkingServices[0]
426
- }
427
- randMu.Lock()
428
- randNum := randGen.Intn(totalWeight)
429
- randMu.Unlock()
430
- currentSum := 0
431
- for _, svc := range thinkingServices {
432
- currentSum += svc.Weight
433
- if randNum < currentSum {
434
- return svc
435
- }
436
- }
437
- return thinkingServices[0]
438
- }
439
-
440
- // ---------------------- 非流式处理思考服务 ----------------------
441
-
442
- type ThinkingResponse struct {
443
- Content string
444
- ReasoningContent string
445
- }
446
-
447
- func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
448
- logger := NewRequestLogger(s.config)
449
- log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)
450
-
451
- thinkingReq := *req
452
- thinkingReq.Model = svc.Model
453
- thinkingReq.APIKey = svc.APIKey
454
-
455
- var systemPrompt string
456
- if svc.Mode == "full" {
457
- systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
458
- } else {
459
- systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
460
- }
461
- thinkingReq.Messages = append([]ChatCompletionMessage{
462
- {Role: "system", Content: systemPrompt},
463
- }, thinkingReq.Messages...)
464
-
465
- temp := 0.7
466
- if svc.Temperature != nil {
467
- temp = *svc.Temperature
468
- }
469
- payload := map[string]interface{}{
470
- "model": svc.Model,
471
- "messages": thinkingReq.Messages,
472
- "stream": false,
473
- "temperature": temp,
474
- }
475
- if isValidReasoningEffort(svc.ReasoningEffort) {
476
- payload["reasoning_effort"] = svc.ReasoningEffort
477
- }
478
- if isValidReasoningFormat(svc.ReasoningFormat) {
479
- payload["reasoning_format"] = svc.ReasoningFormat
480
- }
481
-
482
- if s.config.Global.Log.Debug.PrintRequest {
483
- logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
484
- }
485
-
486
- jsonData, err := json.Marshal(payload)
487
- if err != nil {
488
- return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
489
- }
490
- client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
491
- if err != nil {
492
- return nil, fmt.Errorf("failed to create HTTP client: %v", err)
493
- }
494
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
495
- if err != nil {
496
- return nil, fmt.Errorf("failed to create request: %v", err)
497
- }
498
- httpReq.Header.Set("Content-Type", "application/json")
499
- httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)
500
-
501
- resp, err := client.Do(httpReq)
502
- if err != nil {
503
- return nil, fmt.Errorf("failed to send thinking request: %v", err)
504
- }
505
- defer resp.Body.Close()
506
-
507
- respBody, err := io.ReadAll(resp.Body)
508
- if err != nil {
509
- return nil, fmt.Errorf("failed to read response body: %v", err)
510
- }
511
- if s.config.Global.Log.Debug.PrintResponse {
512
- logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
513
- }
514
- if resp.StatusCode != http.StatusOK {
515
- return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
516
- }
517
-
518
- var thinkingResp ChatCompletionResponse
519
- if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
520
- return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
521
- }
522
- if len(thinkingResp.Choices) == 0 {
523
- return nil, fmt.Errorf("thinking service returned no choices")
524
- }
525
-
526
- result := &ThinkingResponse{}
527
- choice := thinkingResp.Choices[0]
528
-
529
- if svc.Mode == "full" {
530
- result.ReasoningContent = choice.Message.Content
531
- result.Content = "Based on the above detailed analysis."
532
- } else {
533
- if choice.Message.ReasoningContent != nil {
534
- switch v := choice.Message.ReasoningContent.(type) {
535
- case string:
536
- result.ReasoningContent = v
537
- case map[string]interface{}:
538
- if j, err := json.Marshal(v); err == nil {
539
- result.ReasoningContent = string(j)
540
- }
541
- }
542
- }
543
- if result.ReasoningContent == "" {
544
- result.ReasoningContent = choice.Message.Content
545
- }
546
- result.Content = "Based on the above reasoning."
547
- }
548
- return result, nil
549
- }
550
-
551
- func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
552
- newReq := *originalReq
553
- var systemPrompt string
554
- if svc.Mode == "full" {
555
- systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
556
- %s
557
-
558
- Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
559
- } else {
560
- systemPrompt = fmt.Sprintf(`Previous thinking process:
561
- %s
562
- Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
563
- }
564
- newReq.Messages = append([]ChatCompletionMessage{
565
- {Role: "system", Content: systemPrompt},
566
- }, newReq.Messages...)
567
- return &newReq
568
- }
569
-
570
- func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
571
- logger := NewRequestLogger(s.config)
572
- if s.config.Global.Log.Debug.PrintRequest {
573
- logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
574
- }
575
- jsonData, err := json.Marshal(req)
576
- if err != nil {
577
- http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
578
- return
579
- }
580
-
581
- client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
582
- if err != nil {
583
- http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
584
- return
585
- }
586
-
587
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
588
- if err != nil {
589
- http.Error(w, "Failed to create request", http.StatusInternalServerError)
590
- return
591
- }
592
- httpReq.Header.Set("Content-Type", "application/json")
593
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
594
-
595
- resp, err := client.Do(httpReq)
596
- if err != nil {
597
- http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
598
- return
599
- }
600
- defer resp.Body.Close()
601
-
602
- respBody, err := io.ReadAll(resp.Body)
603
- if err != nil {
604
- http.Error(w, "Failed to read response", http.StatusInternalServerError)
605
- return
606
- }
607
- if s.config.Global.Log.Debug.PrintResponse {
608
- logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
609
- }
610
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
611
- http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
612
- return
613
- }
614
-
615
- for k, vals := range resp.Header {
616
- for _, v := range vals {
617
- w.Header().Add(k, v)
618
- }
619
- }
620
- w.WriteHeader(resp.StatusCode)
621
- w.Write(respBody)
622
- }
623
-
624
- func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
625
- logger := NewRequestLogger(s.config)
626
- if s.config.Global.Log.Debug.PrintRequest {
627
- logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
628
- }
629
- fullChatURL := targetChannel.GetFullURL()
630
- parsedURL, err := url.Parse(fullChatURL)
631
- if err != nil {
632
- http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
633
- return
634
- }
635
- baseURL := parsedURL.Scheme + "://" + parsedURL.Host
636
- modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"
637
-
638
- client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
639
- if err != nil {
640
- http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
641
- return
642
- }
643
-
644
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
645
- if err != nil {
646
- http.Error(w, "Failed to create request", http.StatusInternalServerError)
647
- return
648
- }
649
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
650
-
651
- resp, err := client.Do(httpReq)
652
- if err != nil {
653
- http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
654
- return
655
- }
656
- defer resp.Body.Close()
657
-
658
- respBody, err := io.ReadAll(resp.Body)
659
- if err != nil {
660
- http.Error(w, "Failed to read response", http.StatusInternalServerError)
661
- return
662
- }
663
- if s.config.Global.Log.Debug.PrintResponse {
664
- logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
665
- }
666
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
667
- http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
668
- return
669
- }
670
-
671
- for k, vals := range resp.Header {
672
- for _, v := range vals {
673
- w.Header().Add(k, v)
674
- }
675
- }
676
- w.WriteHeader(resp.StatusCode)
677
- w.Write(respBody)
678
- }
679
-
680
- // ---------------------- 流式处理 ----------------------
681
-
682
- // collectedReasoningBuffer 用于收集思考服务返回的 reasoning_content(标准模式只收集 reasoning_content;full 模式收集全部)
683
- type collectedReasoningBuffer struct {
684
- builder strings.Builder
685
- mode string
686
- }
687
-
688
- func (cb *collectedReasoningBuffer) append(text string) {
689
- cb.builder.WriteString(text)
690
- }
691
-
692
- func (cb *collectedReasoningBuffer) get() string {
693
- return cb.builder.String()
694
- }
695
-
696
- type StreamHandler struct {
697
- thinkingService ThinkingService
698
- targetChannel Channel
699
- writer http.ResponseWriter
700
- flusher http.Flusher
701
- config *Config
702
- }
703
-
704
- func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
705
- flusher, ok := w.(http.Flusher)
706
- if !ok {
707
- return nil, fmt.Errorf("streaming not supported by this writer")
708
- }
709
- return &StreamHandler{
710
- thinkingService: thinkingService,
711
- targetChannel: targetChannel,
712
- writer: w,
713
- flusher: flusher,
714
- config: config,
715
- }, nil
716
- }
717
-
718
- // HandleRequest 分两阶段:先流式接收思考服务 SSE 并转发给客户端(只转发包含 reasoning_content 的 chunk,标准模式遇到纯 content 则中断);
719
- // 然后使用收集到的 reasoning_content 构造最终请求,发起目标 Channel 的 SSE 并转发给客户端。
720
- func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
721
- h.writer.Header().Set("Content-Type", "text/event-stream")
722
- h.writer.Header().Set("Cache-Control", "no-cache")
723
- h.writer.Header().Set("Connection", "keep-alive")
724
-
725
- logger := NewRequestLogger(h.config)
726
- reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}
727
-
728
- // 阶段一:流式接收思考服务 SSE
729
- if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
730
- return err
731
- }
732
-
733
- // 阶段二:构造最终请求,并流式转发目标 Channel SSE
734
- finalReq := h.prepareFinalRequest(req, reasonBuf.get())
735
- if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
736
- return err
737
- }
738
- return nil
739
- }
740
-
741
- // streamThinkingService 连接思考服务 SSE,原样转发包含 reasoning_content 的 SSE chunk给客户端,同时收集 reasoning_content。
742
- // 在标准模式下,一旦遇到只返回 non-empty 的 content(且 reasoning_content 为空),则立即中断,不转发该 chunk。
743
- func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
744
- thinkingReq := *req
745
- thinkingReq.Model = h.thinkingService.Model
746
- thinkingReq.APIKey = h.thinkingService.APIKey
747
-
748
- var systemPrompt string
749
- if h.thinkingService.Mode == "full" {
750
- systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
751
- } else {
752
- systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
753
- }
754
- messages := append([]ChatCompletionMessage{
755
- {Role: "system", Content: systemPrompt},
756
- }, thinkingReq.Messages...)
757
-
758
- temp := 0.7
759
- if h.thinkingService.Temperature != nil {
760
- temp = *h.thinkingService.Temperature
761
- }
762
- bodyMap := map[string]interface{}{
763
- "model": thinkingReq.Model,
764
- "messages": messages,
765
- "temperature": temp,
766
- "stream": true,
767
- }
768
- if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
769
- bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
770
- }
771
- if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
772
- bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
773
- }
774
-
775
- jsonData, err := json.Marshal(bodyMap)
776
- if err != nil {
777
- return err
778
- }
779
-
780
- client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
781
- if err != nil {
782
- return err
783
- }
784
-
785
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
786
- if err != nil {
787
- return err
788
- }
789
- httpReq.Header.Set("Content-Type", "application/json")
790
- httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)
791
-
792
- log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
793
- resp, err := client.Do(httpReq)
794
- if err != nil {
795
- return err
796
- }
797
- defer resp.Body.Close()
798
-
799
- if resp.StatusCode != http.StatusOK {
800
- b, _ := io.ReadAll(resp.Body)
801
- return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
802
- }
803
-
804
- reader := bufio.NewReader(resp.Body)
805
- var lastLine string
806
- // 标识是否需中断思考 SSE(标准模式下)
807
- forceStop := false
808
-
809
- for {
810
- select {
811
- case <-ctx.Done():
812
- return ctx.Err()
813
- default:
814
- }
815
- line, err := reader.ReadString('\n')
816
- if err != nil {
817
- if err == io.EOF {
818
- break
819
- }
820
- return err
821
- }
822
- line = strings.TrimSpace(line)
823
- if line == "" || line == lastLine {
824
- continue
825
- }
826
- lastLine = line
827
-
828
- if !strings.HasPrefix(line, "data: ") {
829
- continue
830
- }
831
- dataPart := strings.TrimPrefix(line, "data: ")
832
- if dataPart == "[DONE]" {
833
- // 思考服务结束:直接中断(不发送特殊标记)
834
- break
835
- }
836
-
837
- var chunk struct {
838
- Choices []struct {
839
- Delta struct {
840
- Content string `json:"content,omitempty"`
841
- ReasoningContent string `json:"reasoning_content,omitempty"`
842
- } `json:"delta"`
843
- FinishReason *string `json:"finish_reason,omitempty"`
844
- } `json:"choices"`
845
- }
846
- if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
847
- // 如果解析失败,直接原样转发
848
- h.writer.Write([]byte("data: " + dataPart + "\n\n"))
849
- h.flusher.Flush()
850
- continue
851
- }
852
-
853
- if len(chunk.Choices) > 0 {
854
- c := chunk.Choices[0]
855
- if h.config.Global.Log.Debug.PrintResponse {
856
- logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
857
- }
858
-
859
- if h.thinkingService.Mode == "full" {
860
- // full 模式:收集所有内容
861
- if c.Delta.ReasoningContent != "" {
862
- reasonBuf.append(c.Delta.ReasoningContent)
863
- }
864
- if c.Delta.Content != "" {
865
- reasonBuf.append(c.Delta.Content)
866
- }
867
- // 原样转发整 chunk
868
- forwardLine := "data: " + dataPart + "\n\n"
869
- h.writer.Write([]byte(forwardLine))
870
- h.flusher.Flush()
871
- } else {
872
- // standard 模式:只收集 reasoning_content
873
- if c.Delta.ReasoningContent != "" {
874
- reasonBuf.append(c.Delta.ReasoningContent)
875
- // 转发该 chunk
876
- forwardLine := "data: " + dataPart + "\n\n"
877
- h.writer.Write([]byte(forwardLine))
878
- h.flusher.Flush()
879
- }
880
- // 如果遇到 chunk 中只有 content(且 reasoning_content 为空),认为思考链结束,不转发该 chunk
881
- if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
882
- forceStop = true
883
- }
884
- }
885
-
886
- // 如果 finishReason 非空,也认为结束
887
- if c.FinishReason != nil && *c.FinishReason != "" {
888
- forceStop = true
889
- }
890
- }
891
-
892
- if forceStop {
893
- break
894
- }
895
- }
896
- // 读空剩余数据
897
- io.Copy(io.Discard, reader)
898
- return nil
899
- }
900
-
901
- func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
902
- req := *originalReq
903
- var systemPrompt string
904
- if h.thinkingService.Mode == "full" {
905
- systemPrompt = fmt.Sprintf(
906
- `Consider the following detailed analysis (not shown to user):
907
- %s
908
-
909
- Provide a clear, concise response that incorporates insights from this analysis.`,
910
- reasoningCollected,
911
- )
912
- } else {
913
- systemPrompt = fmt.Sprintf(
914
- `Previous thinking process:
915
- %s
916
- Please consider the above thinking process in your response.`,
917
- reasoningCollected,
918
- )
919
- }
920
- req.Messages = append([]ChatCompletionMessage{
921
- {Role: "system", Content: systemPrompt},
922
- }, req.Messages...)
923
- return &req
924
- }
925
-
926
- func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
927
- if h.config.Global.Log.Debug.PrintRequest {
928
- logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
929
- }
930
-
931
- jsonData, err := json.Marshal(req)
932
- if err != nil {
933
- return err
934
- }
935
- client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
936
- if err != nil {
937
- return err
938
- }
939
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
940
- if err != nil {
941
- return err
942
- }
943
- httpReq.Header.Set("Content-Type", "application/json")
944
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
945
-
946
- resp, err := client.Do(httpReq)
947
- if err != nil {
948
- return err
949
- }
950
- defer resp.Body.Close()
951
-
952
- if resp.StatusCode != http.StatusOK {
953
- b, _ := io.ReadAll(resp.Body)
954
- return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
955
- }
956
-
957
- reader := bufio.NewReader(resp.Body)
958
- var lastLine string
959
-
960
- for {
961
- select {
962
- case <-ctx.Done():
963
- return ctx.Err()
964
- default:
965
- }
966
- line, err := reader.ReadString('\n')
967
- if err != nil {
968
- if err == io.EOF {
969
- break
970
- }
971
- return err
972
- }
973
- line = strings.TrimSpace(line)
974
- if line == "" || line == lastLine {
975
- continue
976
- }
977
- lastLine = line
978
-
979
- if !strings.HasPrefix(line, "data: ") {
980
- continue
981
- }
982
- data := strings.TrimPrefix(line, "data: ")
983
- if data == "[DONE]" {
984
- doneLine := "data: [DONE]\n\n"
985
- h.writer.Write([]byte(doneLine))
986
- h.flusher.Flush()
987
- break
988
- }
989
- forwardLine := "data: " + data + "\n\n"
990
- h.writer.Write([]byte(forwardLine))
991
- h.flusher.Flush()
992
-
993
- if h.config.Global.Log.Debug.PrintResponse {
994
- logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
995
- }
996
- }
997
- io.Copy(io.Discard, reader)
998
- return nil
999
- }
1000
-
1001
- // ---------------------- HTTP Client 工具 ----------------------
1002
-
1003
- func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
1004
- transport := &http.Transport{
1005
- DialContext: (&net.Dialer{
1006
- Timeout: 30 * time.Second,
1007
- KeepAlive: 30 * time.Second,
1008
- }).DialContext,
1009
- ForceAttemptHTTP2: true,
1010
- MaxIdleConns: 100,
1011
- IdleConnTimeout: 90 * time.Second,
1012
- TLSHandshakeTimeout: 10 * time.Second,
1013
- ExpectContinueTimeout: 1 * time.Second,
1014
- }
1015
- if proxyURL != "" {
1016
- parsedURL, err := url.Parse(proxyURL)
1017
- if err != nil {
1018
- return nil, fmt.Errorf("invalid proxy URL: %v", err)
1019
- }
1020
- switch parsedURL.Scheme {
1021
- case "http", "https":
1022
- transport.Proxy = http.ProxyURL(parsedURL)
1023
- case "socks5":
1024
- dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
1025
- if err != nil {
1026
- return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
1027
- }
1028
- transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1029
- return dialer.Dial(network, addr)
1030
- }
1031
- default:
1032
- return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
1033
- }
1034
- }
1035
- return &http.Client{
1036
- Transport: transport,
1037
- Timeout: timeout,
1038
- }, nil
1039
- }
1040
-
1041
- func maskSensitiveHeaders(headers http.Header) http.Header {
1042
- masked := make(http.Header)
1043
- for k, vals := range headers {
1044
- if strings.ToLower(k) == "authorization" {
1045
- masked[k] = []string{"Bearer ****"}
1046
- } else {
1047
- masked[k] = vals
1048
- }
1049
- }
1050
- return masked
1051
- }
1052
-
1053
- func isValidReasoningEffort(effort string) bool {
1054
- switch strings.ToLower(effort) {
1055
- case "low", "medium", "high":
1056
- return true
1057
- }
1058
- return false
1059
- }
1060
-
1061
- func isValidReasoningFormat(format string) bool {
1062
- switch strings.ToLower(format) {
1063
- case "parsed", "raw", "hidden":
1064
- return true
1065
- }
1066
- return false
1067
- }
1068
-
1069
- // ---------------------- 配置加载与验证 ----------------------
1070
-
1071
- func loadConfig() (*Config, error) {
1072
- var configFile string
1073
- flag.StringVar(&configFile, "config", "", "path to config file")
1074
- flag.Parse()
1075
- viper.SetConfigType("yaml")
1076
-
1077
- if configFile != "" {
1078
- viper.SetConfigFile(configFile)
1079
- } else {
1080
- ex, err := os.Executable()
1081
- if err != nil {
1082
- return nil, err
1083
- }
1084
- exePath := filepath.Dir(ex)
1085
- defaultPaths := []string{
1086
- filepath.Join(exePath, "config.yaml"),
1087
- filepath.Join(exePath, "conf", "config.yaml"),
1088
- "./config.yaml",
1089
- "./conf/config.yaml",
1090
- }
1091
- if os.PathSeparator == '\\' {
1092
- programData := os.Getenv("PROGRAMDATA")
1093
- if programData != "" {
1094
- defaultPaths = append(defaultPaths, filepath.Join(programData, "DeepAI", "config.yaml"))
1095
- }
1096
- } else {
1097
- defaultPaths = append(defaultPaths, "/etc/deepai/config.yaml")
1098
- }
1099
- for _, p := range defaultPaths {
1100
- viper.AddConfigPath(filepath.Dir(p))
1101
- if strings.Contains(p, ".yaml") {
1102
- viper.SetConfigName(strings.TrimSuffix(filepath.Base(p), ".yaml"))
1103
- }
1104
- }
1105
- }
1106
-
1107
- if err := viper.ReadInConfig(); err != nil {
1108
- return nil, fmt.Errorf("failed to read config: %v", err)
1109
- }
1110
-
1111
- var cfg Config
1112
- if err := viper.Unmarshal(&cfg); err != nil {
1113
- return nil, fmt.Errorf("unmarshal config error: %v", err)
1114
- }
1115
- if err := validateConfig(&cfg); err != nil {
1116
- return nil, fmt.Errorf("config validation error: %v", err)
1117
- }
1118
- return &cfg, nil
1119
- }
1120
-
1121
- func validateConfig(config *Config) error {
1122
- if len(config.ThinkingServices) == 0 {
1123
- return fmt.Errorf("no thinking services configured")
1124
- }
1125
- if len(config.Channels) == 0 {
1126
- return fmt.Errorf("no channels configured")
1127
- }
1128
- for i, svc := range config.ThinkingServices {
1129
- if svc.BaseURL == "" {
1130
- return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
1131
- }
1132
- if svc.APIKey == "" {
1133
- return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
1134
- }
1135
- if svc.Timeout <= 0 {
1136
- return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
1137
- }
1138
- if svc.Model == "" {
1139
- return fmt.Errorf("thinking service %s has empty model", svc.Name)
1140
- }
1141
- if svc.Mode == "" {
1142
- config.ThinkingServices[i].Mode = "standard"
1143
- } else if svc.Mode != "standard" && svc.Mode != "full" {
1144
- return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
1145
- }
1146
- }
1147
- return nil
1148
- }
1149
-
1150
- func main() {
1151
- log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
1152
-
1153
- cfg, err := loadConfig()
1154
- if err != nil {
1155
- log.Fatalf("Failed to load config: %v", err)
1156
- }
1157
- log.Printf("Using config file: %s", viper.ConfigFileUsed())
1158
-
1159
- server := NewServer(cfg)
1160
-
1161
- done := make(chan os.Signal, 1)
1162
- signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
1163
-
1164
- go func() {
1165
- if err := server.Start(); err != nil && err != http.ErrServerClosed {
1166
- log.Fatalf("start server error: %v", err)
1167
- }
1168
- }()
1169
- log.Printf("Server started successfully")
1170
-
1171
- <-done
1172
- log.Print("Server stopping...")
1173
-
1174
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1175
- defer cancel()
1176
-
1177
- if err := server.Shutdown(ctx); err != nil {
1178
- log.Printf("Server forced to shutdown: %v", err)
1179
- }
1180
- log.Print("Server stopped")
1181
- }