nbugs committed on
Commit
f5a0922
·
verified ·
1 Parent(s): 467940e

Delete main.go

Browse files
Files changed (1) hide show
  1. main.go +0 -1184
main.go DELETED
@@ -1,1184 +0,0 @@
1
- package main
2
-
3
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"syscall"
	"time"
	"unicode/utf8"

	"github.com/google/uuid"
	"github.com/spf13/viper"
	"golang.org/x/net/proxy"
)
29
-
30
- // ---------------------- 配置结构 ----------------------
31
-
32
// Config is the top-level application configuration, decoded with viper
// via the mapstructure tags.
type Config struct {
	ThinkingServices []ThinkingService  `mapstructure:"thinking_services"` // candidate reasoning backends, weighted-random selected
	Channels         map[string]Channel `mapstructure:"channels"`          // target channels keyed by channel ID (from the API key)
	Global           GlobalConfig       `mapstructure:"global"`            // server, logging, proxy and retry settings
}
37
-
38
// ThinkingService describes one reasoning ("thinking") backend that produces
// a chain of thought which is later prepended to the final request.
type ThinkingService struct {
	ID              int      `mapstructure:"id"`
	Name            string   `mapstructure:"name"`
	Model           string   `mapstructure:"model"`
	BaseURL         string   `mapstructure:"base_url"`
	APIPath         string   `mapstructure:"api_path"`
	APIKey          string   `mapstructure:"api_key"`
	Timeout         int      `mapstructure:"timeout"` // request timeout in seconds
	Retry           int      `mapstructure:"retry"`
	Weight          int      `mapstructure:"weight"` // weighted-random selection weight
	Proxy           string   `mapstructure:"proxy"`  // optional http/https/socks5 proxy URL
	Mode            string   `mapstructure:"mode"`   // "standard" or "full"
	ReasoningEffort string   `mapstructure:"reasoning_effort"`
	ReasoningFormat string   `mapstructure:"reasoning_format"`
	Temperature     *float64 `mapstructure:"temperature"` // nil means use the built-in default (0.7)
	// ForceStopDeepThinking: whether to stop immediately when plain content is
	// seen while in standard mode.
	ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"`
}
55
-
56
- func (s *ThinkingService) GetFullURL() string {
57
- return s.BaseURL + s.APIPath
58
- }
59
-
60
// Channel is a downstream OpenAI-compatible endpoint that receives the final,
// reasoning-enhanced request.
type Channel struct {
	Name    string `mapstructure:"name"`
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`
	Timeout int    `mapstructure:"timeout"` // request timeout in seconds
	Proxy   string `mapstructure:"proxy"`   // optional http/https/socks5 proxy URL
}
67
-
68
- func (c *Channel) GetFullURL() string {
69
- return c.BaseURL + c.APIPath
70
- }
71
-
72
// LogConfig controls application logging output and debug verbosity.
type LogConfig struct {
	Level    string      `mapstructure:"level"`
	Format   string      `mapstructure:"format"`
	Output   string      `mapstructure:"output"`
	FilePath string      `mapstructure:"file_path"`
	Debug    DebugConfig `mapstructure:"debug"` // request/response dump settings
}
79
-
80
// DebugConfig toggles dumping of request/response payloads to the log.
type DebugConfig struct {
	Enabled          bool `mapstructure:"enabled"`            // master switch; LogContent is a no-op when false
	PrintRequest     bool `mapstructure:"print_request"`      // dump outbound request bodies
	PrintResponse    bool `mapstructure:"print_response"`     // dump inbound response bodies
	MaxContentLength int  `mapstructure:"max_content_length"` // payloads longer than this are truncated
}
86
-
87
// ProxyConfig holds global proxy defaults.
// NOTE(review): none of these fields are read in the visible code — confirm
// they are used elsewhere (e.g. in main/config loading) before relying on them.
type ProxyConfig struct {
	Enabled       bool   `mapstructure:"enabled"`
	Default       string `mapstructure:"default"`
	AllowInsecure bool   `mapstructure:"allow_insecure"`
}
92
-
93
// GlobalConfig groups server-wide settings: retries, timeouts, logging,
// HTTP server parameters, proxy defaults and thinking behavior.
type GlobalConfig struct {
	MaxRetries     int `mapstructure:"max_retries"`
	DefaultTimeout int `mapstructure:"default_timeout"`
	// ErrorCodes lists upstream HTTP status codes that should trigger a retry.
	ErrorCodes struct {
		RetryOn []int `mapstructure:"retry_on"`
	} `mapstructure:"error_codes"`
	Log         LogConfig      `mapstructure:"log"`
	Server      ServerConfig   `mapstructure:"server"`
	Proxy       ProxyConfig    `mapstructure:"proxy"`
	ConfigPaths []string       `mapstructure:"config_paths"`
	Thinking    ThinkingConfig `mapstructure:"thinking"`
}
105
-
106
// ServerConfig holds listen address and timeout settings for the HTTP server.
// All timeouts are in seconds.
type ServerConfig struct {
	Port         int    `mapstructure:"port"`
	Host         string `mapstructure:"host"`
	ReadTimeout  int    `mapstructure:"read_timeout"`
	WriteTimeout int    `mapstructure:"write_timeout"`
	IdleTimeout  int    `mapstructure:"idle_timeout"`
}
113
-
114
// ThinkingConfig holds global toggles for the thinking pipeline.
// NOTE(review): these fields are not read in the visible code — confirm usage
// elsewhere in the file.
type ThinkingConfig struct {
	Enabled          bool `mapstructure:"enabled"`
	AddToAllRequests bool `mapstructure:"add_to_all_requests"`
	Timeout          int  `mapstructure:"timeout"`
}
119
-
120
- // ---------------------- API 相关结构 ----------------------
121
-
122
// ChatCompletionRequest mirrors the OpenAI /v1/chat/completions request body.
type ChatCompletionRequest struct {
	Model       string                  `json:"model"`
	Messages    []ChatCompletionMessage `json:"messages"`
	Temperature float64                 `json:"temperature,omitempty"`
	MaxTokens   int                     `json:"max_tokens,omitempty"`
	Stream      bool                    `json:"stream,omitempty"`
	APIKey      string                  `json:"-"` // passed internally only; never serialized
}
130
-
131
// ChatCompletionMessage is a single chat message.
// ReasoningContent is interface{} because upstream services return it either
// as a string or as a JSON object (see processThinkingContent's type switch).
type ChatCompletionMessage struct {
	Role             string      `json:"role"`
	Content          string      `json:"content"`
	ReasoningContent interface{} `json:"reasoning_content,omitempty"`
}
136
-
137
// ChatCompletionResponse mirrors the OpenAI non-streaming chat completion
// response body.
type ChatCompletionResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}
145
-
146
// Choice is one completion alternative in a ChatCompletionResponse.
type Choice struct {
	Index        int                   `json:"index"`
	Message      ChatCompletionMessage `json:"message"`
	FinishReason string                `json:"finish_reason"`
}
151
-
152
// Usage reports token accounting for a completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
157
-
158
- // ---------------------- 日志工具 ----------------------
159
-
160
// RequestLogger tags all log lines of one request with a shared request ID.
type RequestLogger struct {
	RequestID string    // UUID assigned at construction
	Model     string    // NOTE(review): never assigned in the visible code
	StartTime time.Time // construction time
	logs      []string  // in-memory copy of every formatted line
	config    *Config   // consulted for debug-dump settings
}
167
-
168
- func NewRequestLogger(config *Config) *RequestLogger {
169
- return &RequestLogger{
170
- RequestID: uuid.New().String(),
171
- StartTime: time.Now(),
172
- logs: make([]string, 0),
173
- config: config,
174
- }
175
- }
176
-
177
- func (l *RequestLogger) Log(format string, args ...interface{}) {
178
- msg := fmt.Sprintf(format, args...)
179
- l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
180
- log.Printf("[RequestID: %s] %s", l.RequestID, msg)
181
- }
182
-
183
- func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
184
- if !l.config.Global.Log.Debug.Enabled {
185
- return
186
- }
187
- sanitizedContent := sanitizeJSON(content)
188
- truncatedContent := truncateContent(sanitizedContent, maxLength)
189
- l.Log("%s Content:\n%s", contentType, truncatedContent)
190
- }
191
-
192
// truncateContent caps content at maxLength bytes, appending a marker when
// it truncates. The cut is backed off to a UTF-8 rune boundary so multi-byte
// characters (the log payloads may contain non-ASCII text) are never split,
// which would produce invalid UTF-8 in the log.
func truncateContent(content string, maxLength int) string {
	if len(content) <= maxLength {
		return content
	}
	cut := maxLength
	for cut > 0 && !utf8.RuneStart(content[cut]) {
		cut--
	}
	return content[:cut] + "... (truncated)"
}
198
-
199
// apiKeyPattern matches serialized "api_key" fields so their values can be
// masked. Compiled once at package init instead of on every call (the
// original recompiled the regexp per invocation, which is needless work on
// a per-request logging path).
var apiKeyPattern = regexp.MustCompile(`"api_key":\s*"[^"]*"`)

// sanitizeJSON marshals data to JSON and masks any "api_key" values so
// secrets never reach the logs. On marshal failure it returns a fixed
// placeholder string.
func sanitizeJSON(data interface{}) string {
	sanitized, err := json.Marshal(data)
	if err != nil {
		return "Failed to marshal JSON"
	}
	return apiKeyPattern.ReplaceAllString(string(sanitized), `"api_key":"****"`)
}
209
-
210
// extractRealAPIKey strips the "deep-<channel>-" / "openai-<channel>-" prefix
// from a composite key, returning the remainder. Keys without a recognized
// prefix are returned unchanged.
func extractRealAPIKey(fullKey string) string {
	parts := strings.Split(fullKey, "-")
	if len(parts) < 3 {
		return fullKey
	}
	switch parts[0] {
	case "deep", "openai":
		return strings.Join(parts[2:], "-")
	default:
		return fullKey
	}
}
217
-
218
// extractChannelID pulls the channel ID out of a composite
// "deep-<channel>-…" / "openai-<channel>-…" key, defaulting to channel "1".
func extractChannelID(fullKey string) string {
	parts := strings.Split(fullKey, "-")
	if len(parts) < 2 {
		return "1" // default channel
	}
	switch parts[0] {
	case "deep", "openai":
		return parts[1]
	default:
		return "1" // default channel
	}
}
225
-
226
// logAPIKey renders an API key safely for logging: keys of 8 bytes or fewer
// become "****"; longer keys show only the first and last four characters.
func logAPIKey(key string) string {
	const shown = 4
	if len(key) <= 2*shown {
		return "****"
	}
	return key[:shown] + "..." + key[len(key)-shown:]
}
232
-
233
- // ---------------------- Server 结构 ----------------------
234
-
235
// Server is the HTTP proxy front end; it owns the configuration and the
// underlying http.Server.
type Server struct {
	config *Config
	srv    *http.Server // set in Start; used by Shutdown
}
239
-
240
var (
	// randMu guards randGen: rand.Rand instances are not safe for
	// concurrent use, and handlers may run in parallel.
	randMu sync.Mutex
	// randGen is a process-wide seeded PRNG used for weighted service choice.
	randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
)
244
-
245
- func NewServer(config *Config) *Server {
246
- return &Server{
247
- config: config,
248
- }
249
- }
250
-
251
- func (s *Server) Start() error {
252
- mux := http.NewServeMux()
253
- mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
254
- mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
255
- mux.HandleFunc("/health", s.handleHealth)
256
-
257
- s.srv = &http.Server{
258
- Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
259
- Handler: mux,
260
- ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
261
- WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
262
- IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
263
- }
264
-
265
- log.Printf("Server starting on %s\n", s.srv.Addr)
266
- return s.srv.ListenAndServe()
267
- }
268
-
269
// Shutdown gracefully stops the HTTP server, honoring ctx's deadline.
func (s *Server) Shutdown(ctx context.Context) error {
	return s.srv.Shutdown(ctx)
}
272
-
273
- func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
274
- w.WriteHeader(http.StatusOK)
275
- json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
276
- }
277
-
278
// handleOpenAIRequests is the main entry point for /v1/chat/completions and
// /v1/models. It parses the composite API key ("deep-<channel>-<realkey>" or
// "openai-<channel>-<realkey>") to pick a target channel, then either proxies
// a models listing, runs the two-phase streaming flow, or performs the
// non-streaming think-then-forward flow.
func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
	logger := NewRequestLogger(s.config)

	// The Authorization bearer token encodes channel + real upstream key.
	fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
	apiKey := extractRealAPIKey(fullAPIKey)
	channelID := extractChannelID(fullAPIKey)

	logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
	logger.Log("Extracted channel ID: %s", channelID)
	logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))

	targetChannel, ok := s.config.Channels[channelID]
	if !ok {
		http.Error(w, "Invalid channel", http.StatusBadRequest)
		return
	}

	// /v1/models is a simple GET pass-through to the channel.
	if r.URL.Path == "/v1/models" {
		if r.Method != http.MethodGet {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		req := &ChatCompletionRequest{APIKey: apiKey}
		s.forwardModelsRequest(w, r.Context(), req, targetChannel)
		return
	}

	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Read the whole body up front so it can be both logged and decoded;
	// restore r.Body afterwards for any downstream reader.
	body, err := io.ReadAll(r.Body)
	if err != nil {
		logger.Log("Error reading request body: %v", err)
		http.Error(w, "Failed to read request", http.StatusBadRequest)
		return
	}
	r.Body.Close()
	r.Body = io.NopCloser(bytes.NewBuffer(body))

	if s.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
	}

	var req ChatCompletionRequest
	if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	req.APIKey = apiKey

	// Pick a thinking backend by weighted random choice.
	thinkingService := s.getWeightedRandomThinkingService()
	logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))

	if req.Stream {
		// Streaming: SSE from thinking service, then SSE from the channel.
		handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
		if err != nil {
			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
			return
		}
		if err := handler.HandleRequest(r.Context(), &req); err != nil {
			// Headers are likely already sent; log only.
			logger.Log("Stream handler error: %v", err)
		}
	} else {
		// Non-streaming: get the reasoning synchronously, prepend it, forward.
		thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
		if err != nil {
			logger.Log("Error processing thinking content: %v", err)
			http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
			return
		}
		enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
		s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
	}
}
353
-
354
- func (s *Server) getWeightedRandomThinkingService() ThinkingService {
355
- thinkingServices := s.config.ThinkingServices
356
- if len(thinkingServices) == 0 {
357
- return ThinkingService{}
358
- }
359
- totalWeight := 0
360
- for _, svc := range thinkingServices {
361
- totalWeight += svc.Weight
362
- }
363
- if totalWeight <= 0 {
364
- log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
365
- return thinkingServices[0]
366
- }
367
- randMu.Lock()
368
- randNum := randGen.Intn(totalWeight)
369
- randMu.Unlock()
370
- currentSum := 0
371
- for _, svc := range thinkingServices {
372
- currentSum += svc.Weight
373
- if randNum < currentSum {
374
- return svc
375
- }
376
- }
377
- return thinkingServices[0]
378
- }
379
-
380
- // ---------------------- 非流式处理思考服务 ----------------------
381
-
382
// ThinkingResponse carries the outcome of the thinking phase: the collected
// reasoning text plus a short fixed content sentence referencing it.
type ThinkingResponse struct {
	Content          string
	ReasoningContent string
}
386
-
387
// processThinkingContent performs the non-streaming thinking phase: it sends
// the user's messages (with a mode-specific system prompt prepended) to the
// thinking service and extracts the reasoning from the response.
// In "full" mode the entire answer counts as reasoning; otherwise the
// reasoning_content field is preferred, falling back to the plain content.
func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
	logger := NewRequestLogger(s.config)
	log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)

	// Shallow copy so the caller's request is not mutated.
	thinkingReq := *req
	thinkingReq.Model = svc.Model
	thinkingReq.APIKey = svc.APIKey

	var systemPrompt string
	if svc.Mode == "full" {
		systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
	} else {
		systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
	}
	thinkingReq.Messages = append([]ChatCompletionMessage{
		{Role: "system", Content: systemPrompt},
	}, thinkingReq.Messages...)

	// Default temperature 0.7 unless the service config overrides it.
	temp := 0.7
	if svc.Temperature != nil {
		temp = *svc.Temperature
	}
	payload := map[string]interface{}{
		"model":       svc.Model,
		"messages":    thinkingReq.Messages,
		"stream":      false,
		"temperature": temp,
	}
	// Optional reasoning knobs are only sent when their values are valid.
	if isValidReasoningEffort(svc.ReasoningEffort) {
		payload["reasoning_effort"] = svc.ReasoningEffort
	}
	if isValidReasoningFormat(svc.ReasoningFormat) {
		payload["reasoning_format"] = svc.ReasoningFormat
	}

	if s.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
	}

	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
	}
	client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP client: %v", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %v", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)

	resp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send thinking request: %v", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %v", err)
	}
	if s.config.Global.Log.Debug.PrintResponse {
		logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
	}

	var thinkingResp ChatCompletionResponse
	if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
		return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
	}
	if len(thinkingResp.Choices) == 0 {
		return nil, fmt.Errorf("thinking service returned no choices")
	}

	result := &ThinkingResponse{}
	choice := thinkingResp.Choices[0]

	if svc.Mode == "full" {
		// Full mode: the whole answer is the reasoning.
		result.ReasoningContent = choice.Message.Content
		result.Content = "Based on the above detailed analysis."
	} else {
		// Standard mode: reasoning_content may arrive as a string or an
		// object; normalize both to a string.
		if choice.Message.ReasoningContent != nil {
			switch v := choice.Message.ReasoningContent.(type) {
			case string:
				result.ReasoningContent = v
			case map[string]interface{}:
				if j, err := json.Marshal(v); err == nil {
					result.ReasoningContent = string(j)
				}
			}
		}
		// Fall back to the visible content when no reasoning field was given.
		if result.ReasoningContent == "" {
			result.ReasoningContent = choice.Message.Content
		}
		result.Content = "Based on the above reasoning."
	}
	return result, nil
}
490
-
491
// prepareEnhancedRequest returns a copy of originalReq with a system message
// prepended that carries the thinking phase's reasoning, phrased per mode.
// The original request is not mutated.
func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
	newReq := *originalReq
	var systemPrompt string
	if svc.Mode == "full" {
		systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
%s

Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
	} else {
		systemPrompt = fmt.Sprintf(`Previous thinking process:
%s
Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
	}
	newReq.Messages = append([]ChatCompletionMessage{
		{Role: "system", Content: systemPrompt},
	}, newReq.Messages...)
	return &newReq
}
509
-
510
// forwardRequest sends the (reasoning-enhanced) chat completion request to
// the target channel and relays status, headers and body back to the client.
// Non-2xx upstream responses are reported via http.Error; the upstream body
// is discarded in that case.
func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
	logger := NewRequestLogger(s.config)
	if s.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
	}
	jsonData, err := json.Marshal(req)
	if err != nil {
		http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
		return
	}

	client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
		return
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	httpReq.Header.Set("Content-Type", "application/json")
	// Use the caller's real upstream key, not the thinking service's.
	httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)

	resp, err := client.Do(httpReq)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}
	if s.config.Global.Log.Debug.PrintResponse {
		logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
		return
	}

	// Copy upstream headers before committing the status code.
	for k, vals := range resp.Header {
		for _, v := range vals {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	w.Write(respBody)
}
563
-
564
// forwardModelsRequest proxies GET /v1/models to the target channel. The
// models URL is derived from the channel's chat URL by keeping only its
// scheme and host and appending "/v1/models".
func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
	logger := NewRequestLogger(s.config)
	if s.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
	}
	fullChatURL := targetChannel.GetFullURL()
	parsedURL, err := url.Parse(fullChatURL)
	if err != nil {
		http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
		return
	}
	// Rebuild scheme://host and point at the models endpoint.
	baseURL := parsedURL.Scheme + "://" + parsedURL.Host
	modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"

	client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
		return
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)

	resp, err := client.Do(httpReq)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}
	if s.config.Global.Log.Debug.PrintResponse {
		logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
		return
	}

	// Copy upstream headers before committing the status code.
	for k, vals := range resp.Header {
		for _, v := range vals {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	w.Write(respBody)
}
619
-
620
- // ---------------------- 流式处理 ----------------------
621
-
622
// collectedReasoningBuffer accumulates reasoning_content returned by the
// thinking service (standard mode collects only reasoning_content; full mode
// collects everything).
type collectedReasoningBuffer struct {
	builder strings.Builder
	// mode records the thinking-service mode; callers set it, but nothing in
	// the visible code reads it.
	mode string
}
627
-
628
// append adds text to the collected reasoning.
func (cb *collectedReasoningBuffer) append(text string) {
	cb.builder.WriteString(text)
}
631
-
632
// get returns everything collected so far.
func (cb *collectedReasoningBuffer) get() string {
	return cb.builder.String()
}
635
-
636
// StreamHandler drives the two-phase SSE flow for one streaming request:
// thinking-service SSE first, then the target channel's SSE.
type StreamHandler struct {
	thinkingService ThinkingService
	targetChannel   Channel
	writer          http.ResponseWriter
	flusher         http.Flusher // same writer, asserted to support flushing
	config          *Config
}
643
-
644
- func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
645
- flusher, ok := w.(http.Flusher)
646
- if !ok {
647
- return nil, fmt.Errorf("streaming not supported by this writer")
648
- }
649
- return &StreamHandler{
650
- thinkingService: thinkingService,
651
- targetChannel: targetChannel,
652
- writer: w,
653
- flusher: flusher,
654
- config: config,
655
- }, nil
656
- }
657
-
658
// HandleRequest runs the streaming flow in two phases: first stream the
// thinking service's SSE to the client (forwarding only chunks carrying
// reasoning_content; in standard mode a plain-content chunk aborts the
// phase), then build the final request from the collected reasoning and
// relay the target channel's SSE to the client.
func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
	// SSE response headers must be set before any write.
	h.writer.Header().Set("Content-Type", "text/event-stream")
	h.writer.Header().Set("Cache-Control", "no-cache")
	h.writer.Header().Set("Connection", "keep-alive")

	logger := NewRequestLogger(h.config)
	reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}

	// Phase one: stream the thinking service SSE, collecting reasoning.
	if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
		return err
	}

	// Phase two: build the final request and relay the channel's SSE.
	finalReq := h.prepareFinalRequest(req, reasonBuf.get())
	if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
		return err
	}
	return nil
}
680
-
681
// streamThinkingService connects to the thinking service's SSE endpoint,
// forwards chunks carrying reasoning_content to the client verbatim, and
// collects the reasoning into reasonBuf. In standard mode, a chunk carrying
// only non-empty content (empty reasoning_content) marks the end of the
// thinking chain and is not forwarded.
func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
	// Clone the request and point it at the thinking service.
	thinkingReq := *req
	thinkingReq.Model = h.thinkingService.Model
	thinkingReq.APIKey = h.thinkingService.APIKey

	// Mode-specific system prompt: "full" treats the entire answer as
	// reasoning; standard mode just asks for step-by-step reasoning.
	var systemPrompt string
	if h.thinkingService.Mode == "full" {
		systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
	} else {
		systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
	}
	messages := append([]ChatCompletionMessage{
		{Role: "system", Content: systemPrompt},
	}, thinkingReq.Messages...)

	// Default temperature 0.7 unless overridden.
	temp := 0.7
	if h.thinkingService.Temperature != nil {
		temp = *h.thinkingService.Temperature
	}
	bodyMap := map[string]interface{}{
		"model":       thinkingReq.Model,
		"messages":    messages,
		"temperature": temp,
		"stream":      true,
	}
	if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
		bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
	}
	if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
		bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
	}

	jsonData, err := json.Marshal(bodyMap)
	if err != nil {
		return err
	}

	client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
	if err != nil {
		return err
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)

	log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
	resp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		b, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
	}

	reader := bufio.NewReader(resp.Body)
	var lastLine string
	// forceStop flags that the thinking SSE should be aborted (standard mode).
	forceStop := false

	for {
		// Honor client cancellation between reads.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		line, err := reader.ReadString('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return err
		}
		line = strings.TrimSpace(line)
		// NOTE(review): skipping a line equal to the previous one also drops
		// legitimately repeated SSE data lines (e.g. the same delta token
		// twice in a row) — confirm this dedupe is intentional.
		if line == "" || line == lastLine {
			continue
		}
		lastLine = line

		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		dataPart := strings.TrimPrefix(line, "data: ")
		if dataPart == "[DONE]" {
			// Thinking service finished: stop without emitting a marker.
			break
		}

		var chunk struct {
			Choices []struct {
				Delta struct {
					Content          string `json:"content,omitempty"`
					ReasoningContent string `json:"reasoning_content,omitempty"`
				} `json:"delta"`
				FinishReason *string `json:"finish_reason,omitempty"`
			} `json:"choices"`
		}
		if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
			// On parse failure, forward the raw chunk untouched.
			h.writer.Write([]byte("data: " + dataPart + "\n\n"))
			h.flusher.Flush()
			continue
		}

		if len(chunk.Choices) > 0 {
			c := chunk.Choices[0]
			if h.config.Global.Log.Debug.PrintResponse {
				logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
			}

			if h.thinkingService.Mode == "full" {
				// Full mode: collect everything the service emits.
				if c.Delta.ReasoningContent != "" {
					reasonBuf.append(c.Delta.ReasoningContent)
				}
				if c.Delta.Content != "" {
					reasonBuf.append(c.Delta.Content)
				}
				// Forward the whole chunk verbatim.
				forwardLine := "data: " + dataPart + "\n\n"
				h.writer.Write([]byte(forwardLine))
				h.flusher.Flush()
			} else {
				// Standard mode: only reasoning_content is collected/forwarded.
				if c.Delta.ReasoningContent != "" {
					reasonBuf.append(c.Delta.ReasoningContent)
					// Forward this chunk.
					forwardLine := "data: " + dataPart + "\n\n"
					h.writer.Write([]byte(forwardLine))
					h.flusher.Flush()
				}
				// A chunk carrying only content (no reasoning_content) means
				// the thinking chain is over; do not forward it.
				if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
					forceStop = true
				}
			}

			// A non-empty finish_reason also ends the stream.
			if c.FinishReason != nil && *c.FinishReason != "" {
				forceStop = true
			}
		}

		if forceStop {
			break
		}
	}
	// Drain any remaining data so the connection can be reused.
	io.Copy(io.Discard, reader)
	return nil
}
840
-
841
// prepareFinalRequest returns a copy of originalReq with a system message
// prepended that carries the reasoning collected during phase one, phrased
// per the thinking-service mode. The original request is not mutated.
func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
	req := *originalReq
	var systemPrompt string
	if h.thinkingService.Mode == "full" {
		systemPrompt = fmt.Sprintf(
			`Consider the following detailed analysis (not shown to user):
%s

Provide a clear, concise response that incorporates insights from this analysis.`,
			reasoningCollected,
		)
	} else {
		systemPrompt = fmt.Sprintf(
			`Previous thinking process:
%s
Please consider the above thinking process in your response.`,
			reasoningCollected,
		)
	}
	req.Messages = append([]ChatCompletionMessage{
		{Role: "system", Content: systemPrompt},
	}, req.Messages...)
	return &req
}
865
-
866
// streamFinalChannel posts the final (reasoning-enhanced) request to the
// target channel and relays its SSE data lines to the client until a
// [DONE] marker or EOF.
func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
	if h.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
	}

	jsonData, err := json.Marshal(req)
	if err != nil {
		return err
	}
	client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
	if err != nil {
		return err
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	// Use the caller's real upstream key for the channel.
	httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)

	resp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		b, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
	}

	reader := bufio.NewReader(resp.Body)
	var lastLine string

	for {
		// Honor client cancellation between reads.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		line, err := reader.ReadString('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return err
		}
		line = strings.TrimSpace(line)
		// NOTE(review): the lastLine dedupe also drops legitimately repeated
		// SSE data lines — confirm this is intentional.
		if line == "" || line == lastLine {
			continue
		}
		lastLine = line

		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			// Propagate the terminator and stop.
			doneLine := "data: [DONE]\n\n"
			h.writer.Write([]byte(doneLine))
			h.flusher.Flush()
			break
		}
		forwardLine := "data: " + data + "\n\n"
		h.writer.Write([]byte(forwardLine))
		h.flusher.Flush()

		if h.config.Global.Log.Debug.PrintResponse {
			logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
		}
	}
	// Drain any remainder so the connection can be reused.
	io.Copy(io.Discard, reader)
	return nil
}
940
-
941
- // ---------------------- HTTP Client 工具 ----------------------
942
-
943
- func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
944
- transport := &http.Transport{
945
- DialContext: (&net.Dialer{
946
- Timeout: 30 * time.Second,
947
- KeepAlive: 30 * time.Second,
948
- }).DialContext,
949
- ForceAttemptHTTP2: true,
950
- MaxIdleConns: 100,
951
- IdleConnTimeout: 90 * time.Second,
952
- TLSHandshakeTimeout: 10 * time.Second,
953
- ExpectContinueTimeout: 1 * time.Second,
954
- }
955
- if proxyURL != "" {
956
- parsedURL, err := url.Parse(proxyURL)
957
- if err != nil {
958
- return nil, fmt.Errorf("invalid proxy URL: %v", err)
959
- }
960
- switch parsedURL.Scheme {
961
- case "http", "https":
962
- transport.Proxy = http.ProxyURL(parsedURL)
963
- case "socks5":
964
- dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
965
- if err != nil {
966
- return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
967
- }
968
- transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
969
- return dialer.Dial(network, addr)
970
- }
971
- default:
972
- return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
973
- }
974
- }
975
- return &http.Client{
976
- Transport: transport,
977
- Timeout: timeout,
978
- }, nil
979
- }
980
-
981
- func maskSensitiveHeaders(headers http.Header) http.Header {
982
- masked := make(http.Header)
983
- for k, vals := range headers {
984
- if strings.ToLower(k) == "authorization" {
985
- masked[k] = []string{"Bearer ****"}
986
- } else {
987
- masked[k] = vals
988
- }
989
- }
990
- return masked
991
- }
992
-
993
// isValidReasoningEffort reports whether effort names one of the supported
// reasoning-effort levels ("low", "medium", "high"), case-insensitively.
func isValidReasoningEffort(effort string) bool {
	e := strings.ToLower(effort)
	return e == "low" || e == "medium" || e == "high"
}
1000
-
1001
// isValidReasoningFormat reports whether format names one of the supported
// reasoning output formats ("parsed", "raw", "hidden"), case-insensitively.
func isValidReasoningFormat(format string) bool {
	f := strings.ToLower(format)
	return f == "parsed" || f == "raw" || f == "hidden"
}
1008
-
1009
- // ---------------------- 配置加载与验证 ----------------------
1010
-
1011
- // 首先添加这些辅助函数
1012
- func getEnvOrDefault(key, defaultValue string) string {
1013
- if value := os.Getenv(key); value != "" {
1014
- return value
1015
- }
1016
- return defaultValue
1017
- }
1018
-
1019
- func getEnvIntOrDefault(key string, defaultValue int) int {
1020
- if value := os.Getenv(key); value != "" {
1021
- if intValue, err := strconv.Atoi(value); err == nil {
1022
- return intValue
1023
- }
1024
- }
1025
- return defaultValue
1026
- }
1027
-
1028
- func getEnvBoolOrDefault(key string, defaultValue bool) bool {
1029
- if value := os.Getenv(key); value != "" {
1030
- if boolValue, err := strconv.ParseBool(value); err == nil {
1031
- return boolValue
1032
- }
1033
- }
1034
- return defaultValue
1035
- }
1036
-
1037
- func getEnvFloatPtr(key string, defaultValue float64) *float64 {
1038
- if value := os.Getenv(key); value != "" {
1039
- if floatValue, err := strconv.ParseFloat(value, 64); err == nil {
1040
- return &floatValue
1041
- }
1042
- }
1043
- return &defaultValue
1044
- }
1045
-
1046
- // 然后替换原有的 loadConfig 函数
1047
- func loadConfig() (*Config, error) {
1048
- // 优先从环境变量加载完整配置
1049
- if envConfig := os.Getenv("CONFIG_JSON"); envConfig != "" {
1050
- var cfg Config
1051
- if err := json.Unmarshal([]byte(envConfig), &cfg); err == nil {
1052
- if err := validateConfig(&cfg); err == nil {
1053
- log.Println("Using configuration from CONFIG_JSON environment variable")
1054
- return &cfg, nil
1055
- }
1056
- }
1057
- }
1058
-
1059
- // 如果没有完整的 JSON 配置,则从独立的环境变量构建配置
1060
- log.Println("Building configuration from individual environment variables")
1061
- cfg := &Config{
1062
- ThinkingServices: []ThinkingService{
1063
- {
1064
- ID: 1,
1065
- Name: getEnvOrDefault("THINKING_SERVICE_NAME", "modelscope-deepseek-thinking"),
1066
- Mode: getEnvOrDefault("THINKING_SERVICE_MODE", "standard"),
1067
- Model: getEnvOrDefault("THINKING_SERVICE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"),
1068
- BaseURL: getEnvOrDefault("THINKING_SERVICE_BASE_URL", "https://api-inference.modelscope.cn"),
1069
- APIPath: getEnvOrDefault("THINKING_SERVICE_API_PATH", "/v1/chat/completions"),
1070
- APIKey: os.Getenv("THINKING_SERVICE_API_KEY"), // 敏感信息必须从环境变量读取
1071
- Timeout: getEnvIntOrDefault("THINKING_SERVICE_TIMEOUT", 6000),
1072
- Weight: getEnvIntOrDefault("THINKING_SERVICE_WEIGHT", 100),
1073
- Proxy: getEnvOrDefault("THINKING_SERVICE_PROXY", ""),
1074
- ReasoningEffort: getEnvOrDefault("THINKING_SERVICE_REASONING_EFFORT", "high"),
1075
- ReasoningFormat: getEnvOrDefault("THINKING_SERVICE_REASONING_FORMAT", "parsed"),
1076
- Temperature: getEnvFloatPtr("THINKING_SERVICE_TEMPERATURE", 0.8),
1077
- ForceStopDeepThinking: getEnvBoolOrDefault("THINKING_SERVICE_FORCE_STOP", false),
1078
- },
1079
- },
1080
- Channels: map[string]Channel{
1081
- "1": {
1082
- Name: getEnvOrDefault("CHANNEL_NAME", "Gemini-channel"),
1083
- BaseURL: getEnvOrDefault("CHANNEL_BASE_URL", "https://Richardlsr-gemini-balance.hf.space/hf"),
1084
- APIPath: getEnvOrDefault("CHANNEL_API_PATH", "/v1/chat/completions"),
1085
- Timeout: getEnvIntOrDefault("CHANNEL_TIMEOUT", 600),
1086
- Proxy: getEnvOrDefault("CHANNEL_PROXY", ""),
1087
- },
1088
- },
1089
- Global: GlobalConfig{
1090
- Server: ServerConfig{
1091
- Port: getEnvIntOrDefault("SERVER_PORT", 7860), // HF 默认端口
1092
- Host: getEnvOrDefault("SERVER_HOST", "0.0.0.0"),
1093
- ReadTimeout: getEnvIntOrDefault("SERVER_READ_TIMEOUT", 600),
1094
- WriteTimeout: getEnvIntOrDefault("SERVER_WRITE_TIMEOUT", 600),
1095
- IdleTimeout: getEnvIntOrDefault("SERVER_IDLE_TIMEOUT", 600),
1096
- },
1097
- Log: LogConfig{
1098
- Level: getEnvOrDefault("LOG_LEVEL", "info"),
1099
- Format: getEnvOrDefault("LOG_FORMAT", "json"),
1100
- Output: getEnvOrDefault("LOG_OUTPUT", "console"), // HF 环境下默认输出到控制台
1101
- FilePath: getEnvOrDefault("LOG_FILE_PATH", "./logs/deepai.log"),
1102
- Debug: DebugConfig{
1103
- Enabled: getEnvBoolOrDefault("LOG_DEBUG_ENABLED", true),
1104
- PrintRequest: getEnvBoolOrDefault("LOG_PRINT_REQUEST", true),
1105
- PrintResponse: getEnvBoolOrDefault("LOG_PRINT_RESPONSE", true),
1106
- MaxContentLength: getEnvIntOrDefault("LOG_MAX_CONTENT_LENGTH", 1000),
1107
- },
1108
- },
1109
- },
1110
- }
1111
-
1112
- // 验证必需的环境变量
1113
- if cfg.ThinkingServices[0].APIKey == "" {
1114
- return nil, fmt.Errorf("THINKING_SERVICE_API_KEY environment variable is required")
1115
- }
1116
-
1117
- if err := validateConfig(cfg); err != nil {
1118
- return nil, fmt.Errorf("config validation error: %v", err)
1119
- }
1120
-
1121
- return cfg, nil
1122
- }
1123
-
1124
- func validateConfig(config *Config) error {
1125
- if len(config.ThinkingServices) == 0 {
1126
- return fmt.Errorf("no thinking services configured")
1127
- }
1128
- if len(config.Channels) == 0 {
1129
- return fmt.Errorf("no channels configured")
1130
- }
1131
- for i, svc := range config.ThinkingServices {
1132
- if svc.BaseURL == "" {
1133
- return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
1134
- }
1135
- if svc.APIKey == "" {
1136
- return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
1137
- }
1138
- if svc.Timeout <= 0 {
1139
- return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
1140
- }
1141
- if svc.Model == "" {
1142
- return fmt.Errorf("thinking service %s has empty model", svc.Name)
1143
- }
1144
- if svc.Mode == "" {
1145
- config.ThinkingServices[i].Mode = "standard"
1146
- } else if svc.Mode != "standard" && svc.Mode != "full" {
1147
- return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
1148
- }
1149
- }
1150
- return nil
1151
- }
1152
-
1153
- func main() {
1154
- log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
1155
-
1156
- cfg, err := loadConfig()
1157
- if err != nil {
1158
- log.Fatalf("Failed to load config: %v", err)
1159
- }
1160
- log.Printf("Using config file: %s", viper.ConfigFileUsed())
1161
-
1162
- server := NewServer(cfg)
1163
-
1164
- done := make(chan os.Signal, 1)
1165
- signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
1166
-
1167
- go func() {
1168
- if err := server.Start(); err != nil && err != http.ErrServerClosed {
1169
- log.Fatalf("start server error: %v", err)
1170
- }
1171
- }()
1172
- log.Printf("Server started successfully")
1173
-
1174
- <-done
1175
- log.Print("Server stopping...")
1176
-
1177
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1178
- defer cancel()
1179
-
1180
- if err := server.Shutdown(ctx); err != nil {
1181
- log.Printf("Server forced to shutdown: %v", err)
1182
- }
1183
- log.Print("Server stopped")
1184
- }