nbugs committed on
Commit
924b5f2
·
verified ·
1 Parent(s): f3b2eee

Delete main.go

Browse files
Files changed (1) hide show
  1. main.go +0 -1125
main.go DELETED
@@ -1,1125 +0,0 @@
1
- package main
2
-
3
- import (
4
- "bufio"
5
- "bytes"
6
- "context"
7
- "encoding/json"
8
- "flag"
9
- "fmt"
10
- "io"
11
- "log"
12
- "math/rand"
13
- "net"
14
- "net/http"
15
- "net/url"
16
- "os"
17
- "os/signal"
18
- "path/filepath"
19
- "regexp"
20
- "strings"
21
- "sync"
22
- "syscall"
23
- "time"
24
-
25
-
26
- "gopkg.in/yaml.v2"
27
- "github.com/google/uuid"
28
- "github.com/spf13/viper"
29
- "golang.org/x/net/proxy"
30
- )
31
-
32
-
33
-
34
- // ---------------------- 配置结构 ----------------------
35
-
36
- type Config struct {
37
- ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
38
- Channels map[string]Channel `mapstructure:"channels"`
39
- Global GlobalConfig `mapstructure:"global"`
40
- }
41
-
42
// ThinkingService describes one upstream "thinking" endpoint whose output is
// collected as reasoning before the request is forwarded to a channel.
type ThinkingService struct {
	ID      int    `mapstructure:"id"`
	Name    string `mapstructure:"name"`
	Model   string `mapstructure:"model"`
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`
	APIKey  string `mapstructure:"api_key"`
	// Timeout is interpreted in seconds by the code that builds HTTP clients.
	Timeout int `mapstructure:"timeout"`
	Retry   int `mapstructure:"retry"`
	// Weight is the relative weight used by weighted random selection.
	Weight int    `mapstructure:"weight"`
	Proxy  string `mapstructure:"proxy"`
	// Mode is either "standard" or "full".
	Mode            string `mapstructure:"mode"`
	ReasoningEffort string `mapstructure:"reasoning_effort"`
	ReasoningFormat string `mapstructure:"reasoning_format"`
	// Temperature is optional; nil means "use the caller's default".
	Temperature *float64 `mapstructure:"temperature"`
	// ForceStopDeepThinking is the config switch for stopping immediately in
	// standard mode once plain content is encountered.
	ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"`
}

// GetFullURL returns BaseURL immediately followed by APIPath; no separator
// is inserted or normalised.
func (s *ThinkingService) GetFullURL() string {
	endpoint := s.BaseURL
	endpoint += s.APIPath
	return endpoint
}
63
-
64
// Channel describes a target endpoint that finally answers the request.
type Channel struct {
	Name    string `mapstructure:"name"`
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`
	// Timeout is interpreted in seconds by the code that builds HTTP clients.
	Timeout int    `mapstructure:"timeout"`
	Proxy   string `mapstructure:"proxy"`
}

// GetFullURL returns BaseURL immediately followed by APIPath; no separator
// is inserted or normalised.
func (c *Channel) GetFullURL() string {
	endpoint := c.BaseURL
	endpoint += c.APIPath
	return endpoint
}
75
-
76
- type LogConfig struct {
77
- Level string `mapstructure:"level"`
78
- Format string `mapstructure:"format"`
79
- Output string `mapstructure:"output"`
80
- FilePath string `mapstructure:"file_path"`
81
- Debug DebugConfig `mapstructure:"debug"`
82
- }
83
-
84
// DebugConfig toggles debug dumps of requests and responses.
type DebugConfig struct {
	// Enabled gates all content logging done through RequestLogger.LogContent.
	Enabled       bool `mapstructure:"enabled"`
	PrintRequest  bool `mapstructure:"print_request"`
	PrintResponse bool `mapstructure:"print_response"`
	// MaxContentLength is the truncation limit passed to content logging.
	MaxContentLength int `mapstructure:"max_content_length"`
}
90
-
91
// ProxyConfig holds the global proxy settings found under the "proxy" key.
type ProxyConfig struct {
	Enabled       bool   `mapstructure:"enabled"`
	Default       string `mapstructure:"default"`
	AllowInsecure bool   `mapstructure:"allow_insecure"`
}
96
-
97
- type GlobalConfig struct {
98
- MaxRetries int `mapstructure:"max_retries"`
99
- DefaultTimeout int `mapstructure:"default_timeout"`
100
- ErrorCodes struct {
101
- RetryOn []int `mapstructure:"retry_on"`
102
- } `mapstructure:"error_codes"`
103
- Log LogConfig `mapstructure:"log"`
104
- Server ServerConfig `mapstructure:"server"`
105
- Proxy ProxyConfig `mapstructure:"proxy"`
106
- ConfigPaths []string `mapstructure:"config_paths"`
107
- Thinking ThinkingConfig `mapstructure:"thinking"`
108
- }
109
-
110
// ServerConfig holds the listen address and the HTTP server timeouts.
// The three timeouts are converted to seconds when the server is built.
type ServerConfig struct {
	Port         int    `mapstructure:"port"`
	Host         string `mapstructure:"host"`
	ReadTimeout  int    `mapstructure:"read_timeout"`
	WriteTimeout int    `mapstructure:"write_timeout"`
	IdleTimeout  int    `mapstructure:"idle_timeout"`
}
117
-
118
// ThinkingConfig holds the global "thinking" settings.
type ThinkingConfig struct {
	Enabled          bool `mapstructure:"enabled"`
	AddToAllRequests bool `mapstructure:"add_to_all_requests"`
	Timeout          int  `mapstructure:"timeout"`
}
123
-
124
- // ---------------------- API 相关结构 ----------------------
125
-
126
- type ChatCompletionRequest struct {
127
- Model string `json:"model"`
128
- Messages []ChatCompletionMessage `json:"messages"`
129
- Temperature float64 `json:"temperature,omitempty"`
130
- MaxTokens int `json:"max_tokens,omitempty"`
131
- Stream bool `json:"stream,omitempty"`
132
- APIKey string `json:"-"` // 内部传递,不序列化
133
- }
134
-
135
// ChatCompletionMessage is a single chat message.
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	// ReasoningContent is left untyped; consumers in this file handle the
	// string and JSON-object forms.
	ReasoningContent interface{} `json:"reasoning_content,omitempty"`
}
140
-
141
- type ChatCompletionResponse struct {
142
- ID string `json:"id"`
143
- Object string `json:"object"`
144
- Created int64 `json:"created"`
145
- Model string `json:"model"`
146
- Choices []Choice `json:"choices"`
147
- Usage Usage `json:"usage"`
148
- }
149
-
150
- type Choice struct {
151
- Index int `json:"index"`
152
- Message ChatCompletionMessage `json:"message"`
153
- FinishReason string `json:"finish_reason"`
154
- }
155
-
156
// Usage carries the token accounting of a completion response.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
161
-
162
- // ---------------------- 日志工具 ----------------------
163
-
164
- type RequestLogger struct {
165
- RequestID string
166
- Model string
167
- StartTime time.Time
168
- logs []string
169
- config *Config
170
- }
171
-
172
- func NewRequestLogger(config *Config) *RequestLogger {
173
- return &RequestLogger{
174
- RequestID: uuid.New().String(),
175
- StartTime: time.Now(),
176
- logs: make([]string, 0),
177
- config: config,
178
- }
179
- }
180
-
181
- func (l *RequestLogger) Log(format string, args ...interface{}) {
182
- msg := fmt.Sprintf(format, args...)
183
- l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
184
- log.Printf("[RequestID: %s] %s", l.RequestID, msg)
185
- }
186
-
187
- func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
188
- if !l.config.Global.Log.Debug.Enabled {
189
- return
190
- }
191
- sanitizedContent := sanitizeJSON(content)
192
- truncatedContent := truncateContent(sanitizedContent, maxLength)
193
- l.Log("%s Content:\n%s", contentType, truncatedContent)
194
- }
195
-
196
// truncateContent returns content unchanged when it is at most maxLength
// bytes long; otherwise it returns a prefix of at most maxLength bytes
// followed by "... (truncated)".
//
// A negative maxLength is treated as 0 instead of panicking on the slice
// expression, and the cut point is moved back to a UTF-8 rune boundary so a
// multi-byte character is never split in the middle.
func truncateContent(content string, maxLength int) string {
	if maxLength < 0 {
		maxLength = 0
	}
	if len(content) <= maxLength {
		return content
	}
	cut := maxLength
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back
	// until the cut lands on the first byte of a rune.
	for cut > 0 && content[cut]&0xC0 == 0x80 {
		cut--
	}
	return content[:cut] + "... (truncated)"
}
202
-
203
// Patterns used by sanitizeJSON, compiled once at package initialisation
// instead of on every call.
var (
	// sanitizeAPIKeyRe matches an "api_key" member in plain JSON.
	sanitizeAPIKeyRe = regexp.MustCompile(`"api_key":\s*"[^"]*"`)
	// sanitizeEscapedAPIKeyRe matches the same member when the JSON text is
	// itself embedded in a JSON string, i.e. every quote appears as \".
	sanitizeEscapedAPIKeyRe = regexp.MustCompile(`\\"api_key\\":\s*\\"[^"\\]*\\"`)
)

// sanitizeJSON marshals data to JSON and masks the value of every "api_key"
// member with "****". When data cannot be marshalled it returns the fixed
// text "Failed to marshal JSON".
//
// Raw request/response bodies are passed in as Go strings, so after
// marshalling their quotes are escaped; the escaped form is masked too,
// otherwise keys inside such bodies would be logged in clear text.
func sanitizeJSON(data interface{}) string {
	sanitized, err := json.Marshal(data)
	if err != nil {
		return "Failed to marshal JSON"
	}
	content := sanitizeAPIKeyRe.ReplaceAllString(string(sanitized), `"api_key":"****"`)
	content = sanitizeEscapedAPIKeyRe.ReplaceAllString(content, `\"api_key\":\"****\"`)
	return content
}
213
-
214
// extractRealAPIKey strips the "<prefix>-<channel>-" head from keys whose
// prefix is "deep" or "openai" and returns the remainder. Keys without such
// a head, or with fewer than three dash-separated parts, are returned as is.
func extractRealAPIKey(fullKey string) string {
	// Splitting into at most three pieces leaves any further dashes inside
	// the last piece, which is exactly the real key.
	segments := strings.SplitN(fullKey, "-", 3)
	if len(segments) < 3 {
		return fullKey
	}
	switch segments[0] {
	case "deep", "openai":
		return segments[2]
	default:
		return fullKey
	}
}
221
-
222
// extractChannelID returns the second dash-separated part of keys whose
// first part is "deep" or "openai". Every other key maps to the default
// channel "1".
func extractChannelID(fullKey string) string {
	const defaultChannel = "1"
	segments := strings.SplitN(fullKey, "-", 3)
	if len(segments) < 2 {
		return defaultChannel
	}
	switch segments[0] {
	case "deep", "openai":
		return segments[1]
	default:
		return defaultChannel
	}
}
229
-
230
// logAPIKey returns a log-safe form of key: the first and last four bytes
// joined by "...". Keys of eight bytes or fewer are fully masked as "****".
func logAPIKey(key string) string {
	const visible = 4
	n := len(key)
	if n > 2*visible {
		return key[:visible] + "..." + key[n-visible:]
	}
	return "****"
}
236
-
237
- // ---------------------- Server 结构 ----------------------
238
-
239
- type Server struct {
240
- config *Config
241
- srv *http.Server
242
- }
243
-
244
// randMu guards randGen; callers lock it around every use of the generator.
var randMu sync.Mutex

// randGen is the random source used for weighted service selection, seeded
// from the wall clock at start-up.
var randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
248
-
249
- func NewServer(config *Config) *Server {
250
- return &Server{
251
- config: config,
252
- }
253
- }
254
-
255
- func (s *Server) Start() error {
256
- mux := http.NewServeMux()
257
- mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
258
- mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
259
- mux.HandleFunc("/health", s.handleHealth)
260
-
261
- s.srv = &http.Server{
262
- Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
263
- Handler: mux,
264
- ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
265
- WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
266
- IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
267
- }
268
-
269
- log.Printf("Server starting on %s\n", s.srv.Addr)
270
- return s.srv.ListenAndServe()
271
- }
272
-
273
- func (s *Server) Shutdown(ctx context.Context) error {
274
- return s.srv.Shutdown(ctx)
275
- }
276
-
277
- func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
278
- w.WriteHeader(http.StatusOK)
279
- json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
280
- }
281
-
282
- func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
283
- logger := NewRequestLogger(s.config)
284
-
285
- fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
286
- apiKey := extractRealAPIKey(fullAPIKey)
287
- channelID := extractChannelID(fullAPIKey)
288
-
289
- logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
290
- logger.Log("Extracted channel ID: %s", channelID)
291
- logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))
292
-
293
- targetChannel, ok := s.config.Channels[channelID]
294
- if !ok {
295
- http.Error(w, "Invalid channel", http.StatusBadRequest)
296
- return
297
- }
298
-
299
- if r.URL.Path == "/v1/models" {
300
- if r.Method != http.MethodGet {
301
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
302
- return
303
- }
304
- req := &ChatCompletionRequest{APIKey: apiKey}
305
- s.forwardModelsRequest(w, r.Context(), req, targetChannel)
306
- return
307
- }
308
-
309
- if r.Method != http.MethodPost {
310
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
311
- return
312
- }
313
-
314
- body, err := io.ReadAll(r.Body)
315
- if err != nil {
316
- logger.Log("Error reading request body: %v", err)
317
- http.Error(w, "Failed to read request", http.StatusBadRequest)
318
- return
319
- }
320
- r.Body.Close()
321
- r.Body = io.NopCloser(bytes.NewBuffer(body))
322
-
323
- if s.config.Global.Log.Debug.PrintRequest {
324
- logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
325
- }
326
-
327
- var req ChatCompletionRequest
328
- if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
329
- http.Error(w, "Invalid request body", http.StatusBadRequest)
330
- return
331
- }
332
- req.APIKey = apiKey
333
-
334
- thinkingService := s.getWeightedRandomThinkingService()
335
- logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))
336
-
337
- if req.Stream {
338
- handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
339
- if err != nil {
340
- http.Error(w, "Streaming not supported", http.StatusInternalServerError)
341
- return
342
- }
343
- if err := handler.HandleRequest(r.Context(), &req); err != nil {
344
- logger.Log("Stream handler error: %v", err)
345
- }
346
- } else {
347
- thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
348
- if err != nil {
349
- logger.Log("Error processing thinking content: %v", err)
350
- http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
351
- return
352
- }
353
- enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
354
- s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
355
- }
356
- }
357
-
358
- func (s *Server) getWeightedRandomThinkingService() ThinkingService {
359
- thinkingServices := s.config.ThinkingServices
360
- if len(thinkingServices) == 0 {
361
- return ThinkingService{}
362
- }
363
- totalWeight := 0
364
- for _, svc := range thinkingServices {
365
- totalWeight += svc.Weight
366
- }
367
- if totalWeight <= 0 {
368
- log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
369
- return thinkingServices[0]
370
- }
371
- randMu.Lock()
372
- randNum := randGen.Intn(totalWeight)
373
- randMu.Unlock()
374
- currentSum := 0
375
- for _, svc := range thinkingServices {
376
- currentSum += svc.Weight
377
- if randNum < currentSum {
378
- return svc
379
- }
380
- }
381
- return thinkingServices[0]
382
- }
383
-
384
- // ---------------------- 非流式处理思考服务 ----------------------
385
-
386
// ThinkingResponse is the outcome of a non-streaming call to a thinking
// service: the collected reasoning plus a short content placeholder.
type ThinkingResponse struct {
	Content          string
	ReasoningContent string
}
390
-
391
- func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
392
- logger := NewRequestLogger(s.config)
393
- log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)
394
-
395
- thinkingReq := *req
396
- thinkingReq.Model = svc.Model
397
- thinkingReq.APIKey = svc.APIKey
398
-
399
- var systemPrompt string
400
- if svc.Mode == "full" {
401
- systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
402
- } else {
403
- systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
404
- }
405
- thinkingReq.Messages = append([]ChatCompletionMessage{
406
- {Role: "system", Content: systemPrompt},
407
- }, thinkingReq.Messages...)
408
-
409
- temp := 0.7
410
- if svc.Temperature != nil {
411
- temp = *svc.Temperature
412
- }
413
- payload := map[string]interface{}{
414
- "model": svc.Model,
415
- "messages": thinkingReq.Messages,
416
- "stream": false,
417
- "temperature": temp,
418
- }
419
- if isValidReasoningEffort(svc.ReasoningEffort) {
420
- payload["reasoning_effort"] = svc.ReasoningEffort
421
- }
422
- if isValidReasoningFormat(svc.ReasoningFormat) {
423
- payload["reasoning_format"] = svc.ReasoningFormat
424
- }
425
-
426
- if s.config.Global.Log.Debug.PrintRequest {
427
- logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
428
- }
429
-
430
- jsonData, err := json.Marshal(payload)
431
- if err != nil {
432
- return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
433
- }
434
- client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
435
- if err != nil {
436
- return nil, fmt.Errorf("failed to create HTTP client: %v", err)
437
- }
438
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
439
- if err != nil {
440
- return nil, fmt.Errorf("failed to create request: %v", err)
441
- }
442
- httpReq.Header.Set("Content-Type", "application/json")
443
- httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)
444
-
445
- resp, err := client.Do(httpReq)
446
- if err != nil {
447
- return nil, fmt.Errorf("failed to send thinking request: %v", err)
448
- }
449
- defer resp.Body.Close()
450
-
451
- respBody, err := io.ReadAll(resp.Body)
452
- if err != nil {
453
- return nil, fmt.Errorf("failed to read response body: %v", err)
454
- }
455
- if s.config.Global.Log.Debug.PrintResponse {
456
- logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
457
- }
458
- if resp.StatusCode != http.StatusOK {
459
- return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
460
- }
461
-
462
- var thinkingResp ChatCompletionResponse
463
- if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
464
- return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
465
- }
466
- if len(thinkingResp.Choices) == 0 {
467
- return nil, fmt.Errorf("thinking service returned no choices")
468
- }
469
-
470
- result := &ThinkingResponse{}
471
- choice := thinkingResp.Choices[0]
472
-
473
- if svc.Mode == "full" {
474
- result.ReasoningContent = choice.Message.Content
475
- result.Content = "Based on the above detailed analysis."
476
- } else {
477
- if choice.Message.ReasoningContent != nil {
478
- switch v := choice.Message.ReasoningContent.(type) {
479
- case string:
480
- result.ReasoningContent = v
481
- case map[string]interface{}:
482
- if j, err := json.Marshal(v); err == nil {
483
- result.ReasoningContent = string(j)
484
- }
485
- }
486
- }
487
- if result.ReasoningContent == "" {
488
- result.ReasoningContent = choice.Message.Content
489
- }
490
- result.Content = "Based on the above reasoning."
491
- }
492
- return result, nil
493
- }
494
-
495
- func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
496
- newReq := *originalReq
497
- var systemPrompt string
498
- if svc.Mode == "full" {
499
- systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
500
- %s
501
-
502
- Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
503
- } else {
504
- systemPrompt = fmt.Sprintf(`Previous thinking process:
505
- %s
506
- Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
507
- }
508
- newReq.Messages = append([]ChatCompletionMessage{
509
- {Role: "system", Content: systemPrompt},
510
- }, newReq.Messages...)
511
- return &newReq
512
- }
513
-
514
- func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
515
- logger := NewRequestLogger(s.config)
516
- if s.config.Global.Log.Debug.PrintRequest {
517
- logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
518
- }
519
- jsonData, err := json.Marshal(req)
520
- if err != nil {
521
- http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
522
- return
523
- }
524
-
525
- client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
526
- if err != nil {
527
- http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
528
- return
529
- }
530
-
531
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
532
- if err != nil {
533
- http.Error(w, "Failed to create request", http.StatusInternalServerError)
534
- return
535
- }
536
- httpReq.Header.Set("Content-Type", "application/json")
537
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
538
-
539
- resp, err := client.Do(httpReq)
540
- if err != nil {
541
- http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
542
- return
543
- }
544
- defer resp.Body.Close()
545
-
546
- respBody, err := io.ReadAll(resp.Body)
547
- if err != nil {
548
- http.Error(w, "Failed to read response", http.StatusInternalServerError)
549
- return
550
- }
551
- if s.config.Global.Log.Debug.PrintResponse {
552
- logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
553
- }
554
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
555
- http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
556
- return
557
- }
558
-
559
- for k, vals := range resp.Header {
560
- for _, v := range vals {
561
- w.Header().Add(k, v)
562
- }
563
- }
564
- w.WriteHeader(resp.StatusCode)
565
- w.Write(respBody)
566
- }
567
-
568
- func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
569
- logger := NewRequestLogger(s.config)
570
- if s.config.Global.Log.Debug.PrintRequest {
571
- logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
572
- }
573
- fullChatURL := targetChannel.GetFullURL()
574
- parsedURL, err := url.Parse(fullChatURL)
575
- if err != nil {
576
- http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
577
- return
578
- }
579
- baseURL := parsedURL.Scheme + "://" + parsedURL.Host
580
- modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"
581
-
582
- client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
583
- if err != nil {
584
- http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
585
- return
586
- }
587
-
588
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
589
- if err != nil {
590
- http.Error(w, "Failed to create request", http.StatusInternalServerError)
591
- return
592
- }
593
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
594
-
595
- resp, err := client.Do(httpReq)
596
- if err != nil {
597
- http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
598
- return
599
- }
600
- defer resp.Body.Close()
601
-
602
- respBody, err := io.ReadAll(resp.Body)
603
- if err != nil {
604
- http.Error(w, "Failed to read response", http.StatusInternalServerError)
605
- return
606
- }
607
- if s.config.Global.Log.Debug.PrintResponse {
608
- logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
609
- }
610
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
611
- http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
612
- return
613
- }
614
-
615
- for k, vals := range resp.Header {
616
- for _, v := range vals {
617
- w.Header().Add(k, v)
618
- }
619
- }
620
- w.WriteHeader(resp.StatusCode)
621
- w.Write(respBody)
622
- }
623
-
624
- // ---------------------- 流式处理 ----------------------
625
-
626
// collectedReasoningBuffer accumulates the reasoning text returned by the
// thinking service. In standard mode only reasoning_content is collected;
// in full mode everything is collected.
type collectedReasoningBuffer struct {
	builder strings.Builder
	// mode records the thinking service mode the buffer was created for.
	mode string
}

// append adds text to the end of the collected reasoning.
func (cb *collectedReasoningBuffer) append(text string) {
	_, _ = cb.builder.WriteString(text) // strings.Builder writes never fail
}

// get returns everything collected so far.
func (cb *collectedReasoningBuffer) get() string {
	collected := cb.builder.String()
	return collected
}
639
-
640
- type StreamHandler struct {
641
- thinkingService ThinkingService
642
- targetChannel Channel
643
- writer http.ResponseWriter
644
- flusher http.Flusher
645
- config *Config
646
- }
647
-
648
- func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
649
- flusher, ok := w.(http.Flusher)
650
- if !ok {
651
- return nil, fmt.Errorf("streaming not supported by this writer")
652
- }
653
- return &StreamHandler{
654
- thinkingService: thinkingService,
655
- targetChannel: targetChannel,
656
- writer: w,
657
- flusher: flusher,
658
- config: config,
659
- }, nil
660
- }
661
-
662
- // HandleRequest 分两阶段:先流式接收思考服务 SSE 并转发给客户端(只转发包含 reasoning_content 的 chunk,标准模式遇到纯 content 则中断);
663
- // 然后使用收集到的 reasoning_content 构造最终请求,发起目标 Channel 的 SSE 并转发给客户端。
664
- func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
665
- h.writer.Header().Set("Content-Type", "text/event-stream")
666
- h.writer.Header().Set("Cache-Control", "no-cache")
667
- h.writer.Header().Set("Connection", "keep-alive")
668
-
669
- logger := NewRequestLogger(h.config)
670
- reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}
671
-
672
- // 阶段一:流式接收思考服务 SSE
673
- if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
674
- return err
675
- }
676
-
677
- // 阶段二:构造最终请求,并流式转发目标 Channel SSE
678
- finalReq := h.prepareFinalRequest(req, reasonBuf.get())
679
- if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
680
- return err
681
- }
682
- return nil
683
- }
684
-
685
- // streamThinkingService 连接思考服务 SSE,原样转发包含 reasoning_content 的 SSE chunk给客户端,同时收集 reasoning_content。
686
- // 在标准模式下,一旦遇到只返回 non-empty 的 content(且 reasoning_content 为空),则立即中断,不转发该 chunk。
687
- func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
688
- thinkingReq := *req
689
- thinkingReq.Model = h.thinkingService.Model
690
- thinkingReq.APIKey = h.thinkingService.APIKey
691
-
692
- var systemPrompt string
693
- if h.thinkingService.Mode == "full" {
694
- systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
695
- } else {
696
- systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
697
- }
698
- messages := append([]ChatCompletionMessage{
699
- {Role: "system", Content: systemPrompt},
700
- }, thinkingReq.Messages...)
701
-
702
- temp := 0.7
703
- if h.thinkingService.Temperature != nil {
704
- temp = *h.thinkingService.Temperature
705
- }
706
- bodyMap := map[string]interface{}{
707
- "model": thinkingReq.Model,
708
- "messages": messages,
709
- "temperature": temp,
710
- "stream": true,
711
- }
712
- if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
713
- bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
714
- }
715
- if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
716
- bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
717
- }
718
-
719
- jsonData, err := json.Marshal(bodyMap)
720
- if err != nil {
721
- return err
722
- }
723
-
724
- client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
725
- if err != nil {
726
- return err
727
- }
728
-
729
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
730
- if err != nil {
731
- return err
732
- }
733
- httpReq.Header.Set("Content-Type", "application/json")
734
- httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)
735
-
736
- log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
737
- resp, err := client.Do(httpReq)
738
- if err != nil {
739
- return err
740
- }
741
- defer resp.Body.Close()
742
-
743
- if resp.StatusCode != http.StatusOK {
744
- b, _ := io.ReadAll(resp.Body)
745
- return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
746
- }
747
-
748
- reader := bufio.NewReader(resp.Body)
749
- var lastLine string
750
- // 标识是否需中断思考 SSE(标准模式下)
751
- forceStop := false
752
-
753
- for {
754
- select {
755
- case <-ctx.Done():
756
- return ctx.Err()
757
- default:
758
- }
759
- line, err := reader.ReadString('\n')
760
- if err != nil {
761
- if err == io.EOF {
762
- break
763
- }
764
- return err
765
- }
766
- line = strings.TrimSpace(line)
767
- if line == "" || line == lastLine {
768
- continue
769
- }
770
- lastLine = line
771
-
772
- if !strings.HasPrefix(line, "data: ") {
773
- continue
774
- }
775
- dataPart := strings.TrimPrefix(line, "data: ")
776
- if dataPart == "[DONE]" {
777
- // 思考服务结束:直接中断(不发送特殊标记)
778
- break
779
- }
780
-
781
- var chunk struct {
782
- Choices []struct {
783
- Delta struct {
784
- Content string `json:"content,omitempty"`
785
- ReasoningContent string `json:"reasoning_content,omitempty"`
786
- } `json:"delta"`
787
- FinishReason *string `json:"finish_reason,omitempty"`
788
- } `json:"choices"`
789
- }
790
- if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
791
- // 如果解析失败,直接原样转发
792
- h.writer.Write([]byte("data: " + dataPart + "\n\n"))
793
- h.flusher.Flush()
794
- continue
795
- }
796
-
797
- if len(chunk.Choices) > 0 {
798
- c := chunk.Choices[0]
799
- if h.config.Global.Log.Debug.PrintResponse {
800
- logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
801
- }
802
-
803
- if h.thinkingService.Mode == "full" {
804
- // full 模式:收集所有内容
805
- if c.Delta.ReasoningContent != "" {
806
- reasonBuf.append(c.Delta.ReasoningContent)
807
- }
808
- if c.Delta.Content != "" {
809
- reasonBuf.append(c.Delta.Content)
810
- }
811
- // 原样转发整 chunk
812
- forwardLine := "data: " + dataPart + "\n\n"
813
- h.writer.Write([]byte(forwardLine))
814
- h.flusher.Flush()
815
- } else {
816
- // standard 模式:只收集 reasoning_content
817
- if c.Delta.ReasoningContent != "" {
818
- reasonBuf.append(c.Delta.ReasoningContent)
819
- // 转发该 chunk
820
- forwardLine := "data: " + dataPart + "\n\n"
821
- h.writer.Write([]byte(forwardLine))
822
- h.flusher.Flush()
823
- }
824
- // 如果遇到 chunk 中只有 content(且 reasoning_content 为空),认为思考链结束,不转发该 chunk
825
- if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
826
- forceStop = true
827
- }
828
- }
829
-
830
- // 如果 finishReason 非空,也认为结束
831
- if c.FinishReason != nil && *c.FinishReason != "" {
832
- forceStop = true
833
- }
834
- }
835
-
836
- if forceStop {
837
- break
838
- }
839
- }
840
- // 读空剩余数据
841
- io.Copy(io.Discard, reader)
842
- return nil
843
- }
844
-
845
- func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
846
- req := *originalReq
847
- var systemPrompt string
848
- if h.thinkingService.Mode == "full" {
849
- systemPrompt = fmt.Sprintf(
850
- `Consider the following detailed analysis (not shown to user):
851
- %s
852
-
853
- Provide a clear, concise response that incorporates insights from this analysis.`,
854
- reasoningCollected,
855
- )
856
- } else {
857
- systemPrompt = fmt.Sprintf(
858
- `Previous thinking process:
859
- %s
860
- Please consider the above thinking process in your response.`,
861
- reasoningCollected,
862
- )
863
- }
864
- req.Messages = append([]ChatCompletionMessage{
865
- {Role: "system", Content: systemPrompt},
866
- }, req.Messages...)
867
- return &req
868
- }
869
-
870
- func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
871
- if h.config.Global.Log.Debug.PrintRequest {
872
- logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
873
- }
874
-
875
- jsonData, err := json.Marshal(req)
876
- if err != nil {
877
- return err
878
- }
879
- client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
880
- if err != nil {
881
- return err
882
- }
883
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
884
- if err != nil {
885
- return err
886
- }
887
- httpReq.Header.Set("Content-Type", "application/json")
888
- httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
889
-
890
- resp, err := client.Do(httpReq)
891
- if err != nil {
892
- return err
893
- }
894
- defer resp.Body.Close()
895
-
896
- if resp.StatusCode != http.StatusOK {
897
- b, _ := io.ReadAll(resp.Body)
898
- return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
899
- }
900
-
901
- reader := bufio.NewReader(resp.Body)
902
- var lastLine string
903
-
904
- for {
905
- select {
906
- case <-ctx.Done():
907
- return ctx.Err()
908
- default:
909
- }
910
- line, err := reader.ReadString('\n')
911
- if err != nil {
912
- if err == io.EOF {
913
- break
914
- }
915
- return err
916
- }
917
- line = strings.TrimSpace(line)
918
- if line == "" || line == lastLine {
919
- continue
920
- }
921
- lastLine = line
922
-
923
- if !strings.HasPrefix(line, "data: ") {
924
- continue
925
- }
926
- data := strings.TrimPrefix(line, "data: ")
927
- if data == "[DONE]" {
928
- doneLine := "data: [DONE]\n\n"
929
- h.writer.Write([]byte(doneLine))
930
- h.flusher.Flush()
931
- break
932
- }
933
- forwardLine := "data: " + data + "\n\n"
934
- h.writer.Write([]byte(forwardLine))
935
- h.flusher.Flush()
936
-
937
- if h.config.Global.Log.Debug.PrintResponse {
938
- logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
939
- }
940
- }
941
- io.Copy(io.Discard, reader)
942
- return nil
943
- }
944
-
945
- // ---------------------- HTTP client utilities ----------------------
946
-
947
- func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
948
- transport := &http.Transport{
949
- DialContext: (&net.Dialer{
950
- Timeout: 30 * time.Second,
951
- KeepAlive: 30 * time.Second,
952
- }).DialContext,
953
- ForceAttemptHTTP2: true,
954
- MaxIdleConns: 100,
955
- IdleConnTimeout: 90 * time.Second,
956
- TLSHandshakeTimeout: 10 * time.Second,
957
- ExpectContinueTimeout: 1 * time.Second,
958
- }
959
- if proxyURL != "" {
960
- parsedURL, err := url.Parse(proxyURL)
961
- if err != nil {
962
- return nil, fmt.Errorf("invalid proxy URL: %v", err)
963
- }
964
- switch parsedURL.Scheme {
965
- case "http", "https":
966
- transport.Proxy = http.ProxyURL(parsedURL)
967
- case "socks5":
968
- dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
969
- if err != nil {
970
- return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
971
- }
972
- transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
973
- return dialer.Dial(network, addr)
974
- }
975
- default:
976
- return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
977
- }
978
- }
979
- return &http.Client{
980
- Transport: transport,
981
- Timeout: timeout,
982
- }, nil
983
- }
984
-
985
// maskSensitiveHeaders returns a new header map in which every entry whose
// name lower-cases to "authorization" is replaced by the single value
// "Bearer ****". All other entries are carried over with their original
// value slices. The input map is not modified.
func maskSensitiveHeaders(headers http.Header) http.Header {
	result := make(http.Header, len(headers))
	for name, values := range headers {
		if strings.ToLower(name) != "authorization" {
			result[name] = values
			continue
		}
		result[name] = []string{"Bearer ****"}
	}
	return result
}
996
-
997
// isValidReasoningEffort reports whether effort, compared case-insensitively,
// is one of "low", "medium" or "high".
func isValidReasoningEffort(effort string) bool {
	normalized := strings.ToLower(effort)
	return normalized == "low" || normalized == "medium" || normalized == "high"
}
1004
-
1005
// isValidReasoningFormat reports whether format, compared case-insensitively,
// is one of "parsed", "raw" or "hidden".
func isValidReasoningFormat(format string) bool {
	normalized := strings.ToLower(format)
	return normalized == "parsed" || normalized == "raw" || normalized == "hidden"
}
1012
-
1013
- // ---------------------- Config loading and validation ----------------------
1014
-
1015
- func loadConfig() (*Config, error) {
1016
- var configFile string
1017
- flag.StringVar(&configFile, "config", "", "path to config file")
1018
- flag.Parse()
1019
- viper.SetConfigType("yaml")
1020
-
1021
- if configFile != "" {
1022
- viper.SetConfigFile(configFile)
1023
- } else {
1024
- ex, err := os.Executable()
1025
- if err != nil {
1026
- return nil, err
1027
- }
1028
- exePath := filepath.Dir(ex)
1029
- defaultPaths := []string{
1030
- filepath.Join(exePath, "config.yaml"),
1031
- filepath.Join(exePath, "conf", "config.yaml"),
1032
- "./config.yaml",
1033
- "./conf/config.yaml",
1034
- }
1035
- if os.PathSeparator == '\\' {
1036
- programData := os.Getenv("PROGRAMDATA")
1037
- if programData != "" {
1038
- defaultPaths = append(defaultPaths, filepath.Join(programData, "DeepAI", "config.yaml"))
1039
- }
1040
- } else {
1041
- defaultPaths = append(defaultPaths, "/etc/deepai/config.yaml")
1042
- }
1043
- for _, p := range defaultPaths {
1044
- viper.AddConfigPath(filepath.Dir(p))
1045
- if strings.Contains(p, ".yaml") {
1046
- viper.SetConfigName(strings.TrimSuffix(filepath.Base(p), ".yaml"))
1047
- }
1048
- }
1049
- }
1050
-
1051
- if err := viper.ReadInConfig(); err != nil {
1052
- return nil, fmt.Errorf("failed to read config: %v", err)
1053
- }
1054
-
1055
- var cfg Config
1056
- if err := viper.Unmarshal(&cfg); err != nil {
1057
- return nil, fmt.Errorf("unmarshal config error: %v", err)
1058
- }
1059
- if err := validateConfig(&cfg); err != nil {
1060
- return nil, fmt.Errorf("config validation error: %v", err)
1061
- }
1062
- return &cfg, nil
1063
- }
1064
-
1065
- func validateConfig(config *Config) error {
1066
- if len(config.ThinkingServices) == 0 {
1067
- return fmt.Errorf("no thinking services configured")
1068
- }
1069
- if len(config.Channels) == 0 {
1070
- return fmt.Errorf("no channels configured")
1071
- }
1072
- for i, svc := range config.ThinkingServices {
1073
- if svc.BaseURL == "" {
1074
- return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
1075
- }
1076
- if svc.APIKey == "" {
1077
- return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
1078
- }
1079
- if svc.Timeout <= 0 {
1080
- return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
1081
- }
1082
- if svc.Model == "" {
1083
- return fmt.Errorf("thinking service %s has empty model", svc.Name)
1084
- }
1085
- if svc.Mode == "" {
1086
- config.ThinkingServices[i].Mode = "standard"
1087
- } else if svc.Mode != "standard" && svc.Mode != "full" {
1088
- return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
1089
- }
1090
- }
1091
- return nil
1092
- }
1093
-
1094
- func main() {
1095
- log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
1096
-
1097
- cfg, err := loadConfig()
1098
- if err != nil {
1099
- log.Fatalf("Failed to load config: %v", err)
1100
- }
1101
- log.Printf("Using config file: %s", viper.ConfigFileUsed())
1102
-
1103
- server := NewServer(cfg)
1104
-
1105
- done := make(chan os.Signal, 1)
1106
- signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
1107
-
1108
- go func() {
1109
- if err := server.Start(); err != nil && err != http.ErrServerClosed {
1110
- log.Fatalf("start server error: %v", err)
1111
- }
1112
- }()
1113
- log.Printf("Server started successfully")
1114
-
1115
- <-done
1116
- log.Print("Server stopping...")
1117
-
1118
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1119
- defer cancel()
1120
-
1121
- if err := server.Shutdown(ctx); err != nil {
1122
- log.Printf("Server forced to shutdown: %v", err)
1123
- }
1124
- log.Print("Server stopped")
1125
- }