nbugs committed on
Commit
47a75d4
·
verified ·
1 Parent(s): f5a0922

Upload main.go

Browse files
Files changed (1) hide show
  1. main.go +1184 -0
main.go ADDED
@@ -0,0 +1,1184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "bytes"
6
+ "context"
7
+ "encoding/json"
8
+
9
+ "fmt"
10
+ "io"
11
+ "log"
12
+ "math/rand"
13
+ "net"
14
+ "net/http"
15
+ "net/url"
16
+ "os"
17
+ "os/signal"
18
+
19
+ "regexp"
20
+ "strings"
21
+ "sync"
22
+ "syscall"
23
+ "time"
24
+ "strconv"
25
+ "github.com/google/uuid"
26
+ "github.com/spf13/viper"
27
+ "golang.org/x/net/proxy"
28
+ )
29
+
30
+ // ---------------------- 配置结构 ----------------------
31
+
32
+ type Config struct {
33
+ ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
34
+ Channels map[string]Channel `mapstructure:"channels"`
35
+ Global GlobalConfig `mapstructure:"global"`
36
+ }
37
+
38
// ThinkingService describes one upstream model endpoint that is asked to
// produce reasoning before the real request is forwarded to a Channel.
type ThinkingService struct {
	ID      int    `mapstructure:"id"`
	Name    string `mapstructure:"name"`
	Model   string `mapstructure:"model"` // model name sent to the service
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`
	APIKey  string `mapstructure:"api_key"` // sent as the Bearer token
	Timeout int    `mapstructure:"timeout"` // HTTP client timeout, in seconds
	Retry   int    `mapstructure:"retry"`
	Weight  int    `mapstructure:"weight"` // relative weight for random selection
	Proxy   string `mapstructure:"proxy"`
	// Mode is "standard" or "full". In "full" mode the service's whole
	// answer is treated as reasoning; otherwise only reasoning_content is.
	Mode            string `mapstructure:"mode"`
	ReasoningEffort string `mapstructure:"reasoning_effort"`
	ReasoningFormat string `mapstructure:"reasoning_format"`
	// Temperature overrides the default of 0.7 when non-nil.
	Temperature *float64 `mapstructure:"temperature"`
	// ForceStopDeepThinking: in standard mode, whether to stop immediately
	// once a content chunk is encountered.
	ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"`
}

// GetFullURL returns BaseURL joined with APIPath.
//
// When BaseURL ends with "/" and APIPath starts with "/", the duplicate slash
// is collapsed so the result does not contain "//" at the join point. No
// slash is inserted when neither side has one.
func (s *ThinkingService) GetFullURL() string {
	if strings.HasSuffix(s.BaseURL, "/") && strings.HasPrefix(s.APIPath, "/") {
		return s.BaseURL + strings.TrimPrefix(s.APIPath, "/")
	}
	return s.BaseURL + s.APIPath
}
59
+
60
// Channel describes a target endpoint that receives the final (enhanced)
// request. Channels are looked up by the ID embedded in the client's API key.
type Channel struct {
	Name    string `mapstructure:"name"`
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`
	Timeout int    `mapstructure:"timeout"` // HTTP client timeout, in seconds
	Proxy   string `mapstructure:"proxy"`
}

// GetFullURL returns BaseURL joined with APIPath.
//
// When BaseURL ends with "/" and APIPath starts with "/", the duplicate slash
// is collapsed so the result does not contain "//" at the join point. No
// slash is inserted when neither side has one.
func (c *Channel) GetFullURL() string {
	if strings.HasSuffix(c.BaseURL, "/") && strings.HasPrefix(c.APIPath, "/") {
		return c.BaseURL + strings.TrimPrefix(c.APIPath, "/")
	}
	return c.BaseURL + c.APIPath
}
71
+
72
+ type LogConfig struct {
73
+ Level string `mapstructure:"level"`
74
+ Format string `mapstructure:"format"`
75
+ Output string `mapstructure:"output"`
76
+ FilePath string `mapstructure:"file_path"`
77
+ Debug DebugConfig `mapstructure:"debug"`
78
+ }
79
+
80
// DebugConfig controls verbose content logging.
type DebugConfig struct {
	// Enabled gates RequestLogger.LogContent; nothing is dumped when false.
	Enabled bool `mapstructure:"enabled"`
	// PrintRequest enables dumping of outgoing/incoming request payloads.
	PrintRequest bool `mapstructure:"print_request"`
	// PrintResponse enables dumping of response payloads and SSE chunks.
	PrintResponse bool `mapstructure:"print_response"`
	// MaxContentLength is the truncation limit passed to the content logger.
	MaxContentLength int `mapstructure:"max_content_length"`
}
86
+
87
// ProxyConfig holds the settings found under the global "proxy" key.
type ProxyConfig struct {
	Enabled       bool   `mapstructure:"enabled"`
	Default       string `mapstructure:"default"`
	AllowInsecure bool   `mapstructure:"allow_insecure"`
}
92
+
93
+ type GlobalConfig struct {
94
+ MaxRetries int `mapstructure:"max_retries"`
95
+ DefaultTimeout int `mapstructure:"default_timeout"`
96
+ ErrorCodes struct {
97
+ RetryOn []int `mapstructure:"retry_on"`
98
+ } `mapstructure:"error_codes"`
99
+ Log LogConfig `mapstructure:"log"`
100
+ Server ServerConfig `mapstructure:"server"`
101
+ Proxy ProxyConfig `mapstructure:"proxy"`
102
+ ConfigPaths []string `mapstructure:"config_paths"`
103
+ Thinking ThinkingConfig `mapstructure:"thinking"`
104
+ }
105
+
106
// ServerConfig holds the listen address and timeouts of the HTTP server.
// The three timeouts are interpreted as whole seconds by Server.Start.
type ServerConfig struct {
	Port         int    `mapstructure:"port"`
	Host         string `mapstructure:"host"`
	ReadTimeout  int    `mapstructure:"read_timeout"`
	WriteTimeout int    `mapstructure:"write_timeout"`
	IdleTimeout  int    `mapstructure:"idle_timeout"`
}
113
+
114
// ThinkingConfig holds the settings found under the global "thinking" key.
type ThinkingConfig struct {
	Enabled          bool `mapstructure:"enabled"`
	AddToAllRequests bool `mapstructure:"add_to_all_requests"`
	Timeout          int  `mapstructure:"timeout"`
}
119
+
120
+ // ---------------------- API 相关结构 ----------------------
121
+
122
// ChatCompletionRequest is the JSON body of a chat-completions call.
type ChatCompletionRequest struct {
	Model       string                  `json:"model"`
	Messages    []ChatCompletionMessage `json:"messages"`
	Temperature float64                 `json:"temperature,omitempty"`
	MaxTokens   int                     `json:"max_tokens,omitempty"`
	Stream      bool                    `json:"stream,omitempty"`
	// APIKey is carried internally alongside the request and is never
	// serialized.
	APIKey string `json:"-"`
}

// ChatCompletionMessage is a single message of a conversation.
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	// ReasoningContent is left untyped; the non-streaming path handles it as
	// either a string or a JSON object.
	ReasoningContent interface{} `json:"reasoning_content,omitempty"`
}

// ChatCompletionResponse is the JSON body of a non-streaming completion.
type ChatCompletionResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}

// Choice is one candidate answer inside a ChatCompletionResponse.
type Choice struct {
	Index        int                   `json:"index"`
	Message      ChatCompletionMessage `json:"message"`
	FinishReason string                `json:"finish_reason"`
}

// Usage carries the token counters reported with a response.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
157
+
158
+ // ---------------------- 日志工具 ----------------------
159
+
160
+ type RequestLogger struct {
161
+ RequestID string
162
+ Model string
163
+ StartTime time.Time
164
+ logs []string
165
+ config *Config
166
+ }
167
+
168
+ func NewRequestLogger(config *Config) *RequestLogger {
169
+ return &RequestLogger{
170
+ RequestID: uuid.New().String(),
171
+ StartTime: time.Now(),
172
+ logs: make([]string, 0),
173
+ config: config,
174
+ }
175
+ }
176
+
177
+ func (l *RequestLogger) Log(format string, args ...interface{}) {
178
+ msg := fmt.Sprintf(format, args...)
179
+ l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg))
180
+ log.Printf("[RequestID: %s] %s", l.RequestID, msg)
181
+ }
182
+
183
+ func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) {
184
+ if !l.config.Global.Log.Debug.Enabled {
185
+ return
186
+ }
187
+ sanitizedContent := sanitizeJSON(content)
188
+ truncatedContent := truncateContent(sanitizedContent, maxLength)
189
+ l.Log("%s Content:\n%s", contentType, truncatedContent)
190
+ }
191
+
192
// truncateContent returns content unchanged when it is at most maxLength
// bytes long; otherwise it returns a prefix of at most maxLength bytes
// followed by "... (truncated)".
//
// A negative maxLength is treated as 0 instead of panicking on an invalid
// slice bound, and the cut point is moved back to the nearest UTF-8 rune
// boundary so a multi-byte character is never split in half.
func truncateContent(content string, maxLength int) string {
	if maxLength < 0 {
		maxLength = 0
	}
	if len(content) <= maxLength {
		return content
	}
	cut := maxLength
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back
	// until the cut lands on the first byte of a rune.
	for cut > 0 && content[cut]&0xC0 == 0x80 {
		cut--
	}
	return content[:cut] + "... (truncated)"
}
198
+
199
// apiKeyFieldPattern matches an "api_key" JSON member and its string value.
// The optional backslashes let it also match the escaped form (\"api_key\")
// that appears when a JSON document is itself marshalled as a JSON string.
// Compiled once at package scope rather than on every call.
var apiKeyFieldPattern = regexp.MustCompile(`(\\?)"api_key\\?":\s*\\?"[^"]*"`)

// sanitizeJSON marshals data to JSON and masks the value of every "api_key"
// member with "****". If marshalling fails it returns the fixed string
// "Failed to marshal JSON".
//
// Callers frequently pass an already-serialized body as a string; marshalling
// that escapes its quotes, so the pattern accepts both plain and escaped
// quotes and the replacement preserves whichever form was matched.
func sanitizeJSON(data interface{}) string {
	encoded, err := json.Marshal(data)
	if err != nil {
		return "Failed to marshal JSON"
	}
	// ${1} is the (possibly empty) backslash captured before the key quote.
	return apiKeyFieldPattern.ReplaceAllString(string(encoded), `${1}"api_key${1}":${1}"****${1}"`)
}
209
+
210
// extractRealAPIKey strips the routing prefix from a composite key of the
// form "<deep|openai>-<channel>-<key>" and returns the <key> part (which may
// itself contain dashes). Keys that do not have that shape are returned
// unchanged.
func extractRealAPIKey(fullKey string) string {
	segments := strings.Split(fullKey, "-")
	if len(segments) < 3 {
		return fullKey
	}
	switch segments[0] {
	case "deep", "openai":
		return strings.Join(segments[2:], "-")
	default:
		return fullKey
	}
}
217
+
218
// extractChannelID returns the channel segment of a composite key of the form
// "<deep|openai>-<channel>[-<key>]". Any other key maps to the default
// channel "1".
func extractChannelID(fullKey string) string {
	const defaultChannel = "1"

	segments := strings.Split(fullKey, "-")
	if len(segments) < 2 {
		return defaultChannel
	}
	switch segments[0] {
	case "deep", "openai":
		return segments[1]
	default:
		return defaultChannel
	}
}
225
+
226
// logAPIKey returns a log-safe rendering of key: the first and last four
// bytes joined by "...". Keys of eight bytes or fewer are fully masked as
// "****".
func logAPIKey(key string) string {
	const visible = 4
	if n := len(key); n > 2*visible {
		return key[:visible] + "..." + key[n-visible:]
	}
	return "****"
}
232
+
233
+ // ---------------------- Server 结构 ----------------------
234
+
235
+ type Server struct {
236
+ config *Config
237
+ srv *http.Server
238
+ }
239
+
240
// randMu guards randGen; callers lock it around every draw.
var randMu sync.Mutex

// randGen is the shared pseudo-random source, seeded from the wall clock at
// start-up, used for weighted selection of a thinking service.
var randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
244
+
245
+ func NewServer(config *Config) *Server {
246
+ return &Server{
247
+ config: config,
248
+ }
249
+ }
250
+
251
+ func (s *Server) Start() error {
252
+ mux := http.NewServeMux()
253
+ mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests)
254
+ mux.HandleFunc("/v1/models", s.handleOpenAIRequests)
255
+ mux.HandleFunc("/health", s.handleHealth)
256
+
257
+ s.srv = &http.Server{
258
+ Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port),
259
+ Handler: mux,
260
+ ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second,
261
+ WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second,
262
+ IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second,
263
+ }
264
+
265
+ log.Printf("Server starting on %s\n", s.srv.Addr)
266
+ return s.srv.ListenAndServe()
267
+ }
268
+
269
+ func (s *Server) Shutdown(ctx context.Context) error {
270
+ return s.srv.Shutdown(ctx)
271
+ }
272
+
273
+ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
274
+ w.WriteHeader(http.StatusOK)
275
+ json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
276
+ }
277
+
278
+ func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
279
+ logger := NewRequestLogger(s.config)
280
+
281
+ fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
282
+ apiKey := extractRealAPIKey(fullAPIKey)
283
+ channelID := extractChannelID(fullAPIKey)
284
+
285
+ logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
286
+ logger.Log("Extracted channel ID: %s", channelID)
287
+ logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))
288
+
289
+ targetChannel, ok := s.config.Channels[channelID]
290
+ if !ok {
291
+ http.Error(w, "Invalid channel", http.StatusBadRequest)
292
+ return
293
+ }
294
+
295
+ if r.URL.Path == "/v1/models" {
296
+ if r.Method != http.MethodGet {
297
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
298
+ return
299
+ }
300
+ req := &ChatCompletionRequest{APIKey: apiKey}
301
+ s.forwardModelsRequest(w, r.Context(), req, targetChannel)
302
+ return
303
+ }
304
+
305
+ if r.Method != http.MethodPost {
306
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
307
+ return
308
+ }
309
+
310
+ body, err := io.ReadAll(r.Body)
311
+ if err != nil {
312
+ logger.Log("Error reading request body: %v", err)
313
+ http.Error(w, "Failed to read request", http.StatusBadRequest)
314
+ return
315
+ }
316
+ r.Body.Close()
317
+ r.Body = io.NopCloser(bytes.NewBuffer(body))
318
+
319
+ if s.config.Global.Log.Debug.PrintRequest {
320
+ logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
321
+ }
322
+
323
+ var req ChatCompletionRequest
324
+ if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
325
+ http.Error(w, "Invalid request body", http.StatusBadRequest)
326
+ return
327
+ }
328
+ req.APIKey = apiKey
329
+
330
+ thinkingService := s.getWeightedRandomThinkingService()
331
+ logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))
332
+
333
+ if req.Stream {
334
+ handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
335
+ if err != nil {
336
+ http.Error(w, "Streaming not supported", http.StatusInternalServerError)
337
+ return
338
+ }
339
+ if err := handler.HandleRequest(r.Context(), &req); err != nil {
340
+ logger.Log("Stream handler error: %v", err)
341
+ }
342
+ } else {
343
+ thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
344
+ if err != nil {
345
+ logger.Log("Error processing thinking content: %v", err)
346
+ http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
347
+ return
348
+ }
349
+ enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
350
+ s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
351
+ }
352
+ }
353
+
354
+ func (s *Server) getWeightedRandomThinkingService() ThinkingService {
355
+ thinkingServices := s.config.ThinkingServices
356
+ if len(thinkingServices) == 0 {
357
+ return ThinkingService{}
358
+ }
359
+ totalWeight := 0
360
+ for _, svc := range thinkingServices {
361
+ totalWeight += svc.Weight
362
+ }
363
+ if totalWeight <= 0 {
364
+ log.Println("Warning: Total weight of thinking services is not positive, using first service as default.")
365
+ return thinkingServices[0]
366
+ }
367
+ randMu.Lock()
368
+ randNum := randGen.Intn(totalWeight)
369
+ randMu.Unlock()
370
+ currentSum := 0
371
+ for _, svc := range thinkingServices {
372
+ currentSum += svc.Weight
373
+ if randNum < currentSum {
374
+ return svc
375
+ }
376
+ }
377
+ return thinkingServices[0]
378
+ }
379
+
380
+ // ---------------------- 非流式处理思考服务 ----------------------
381
+
382
// ThinkingResponse is the result of a non-streaming thinking-service call.
type ThinkingResponse struct {
	// Content is a fixed placeholder sentence set by processThinkingContent.
	Content string
	// ReasoningContent is the reasoning text extracted from the service reply.
	ReasoningContent string
}
386
+
387
+ func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
388
+ logger := NewRequestLogger(s.config)
389
+ log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)
390
+
391
+ thinkingReq := *req
392
+ thinkingReq.Model = svc.Model
393
+ thinkingReq.APIKey = svc.APIKey
394
+
395
+ var systemPrompt string
396
+ if svc.Mode == "full" {
397
+ systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
398
+ } else {
399
+ systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
400
+ }
401
+ thinkingReq.Messages = append([]ChatCompletionMessage{
402
+ {Role: "system", Content: systemPrompt},
403
+ }, thinkingReq.Messages...)
404
+
405
+ temp := 0.7
406
+ if svc.Temperature != nil {
407
+ temp = *svc.Temperature
408
+ }
409
+ payload := map[string]interface{}{
410
+ "model": svc.Model,
411
+ "messages": thinkingReq.Messages,
412
+ "stream": false,
413
+ "temperature": temp,
414
+ }
415
+ if isValidReasoningEffort(svc.ReasoningEffort) {
416
+ payload["reasoning_effort"] = svc.ReasoningEffort
417
+ }
418
+ if isValidReasoningFormat(svc.ReasoningFormat) {
419
+ payload["reasoning_format"] = svc.ReasoningFormat
420
+ }
421
+
422
+ if s.config.Global.Log.Debug.PrintRequest {
423
+ logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
424
+ }
425
+
426
+ jsonData, err := json.Marshal(payload)
427
+ if err != nil {
428
+ return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
429
+ }
430
+ client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
431
+ if err != nil {
432
+ return nil, fmt.Errorf("failed to create HTTP client: %v", err)
433
+ }
434
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
435
+ if err != nil {
436
+ return nil, fmt.Errorf("failed to create request: %v", err)
437
+ }
438
+ httpReq.Header.Set("Content-Type", "application/json")
439
+ httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)
440
+
441
+ resp, err := client.Do(httpReq)
442
+ if err != nil {
443
+ return nil, fmt.Errorf("failed to send thinking request: %v", err)
444
+ }
445
+ defer resp.Body.Close()
446
+
447
+ respBody, err := io.ReadAll(resp.Body)
448
+ if err != nil {
449
+ return nil, fmt.Errorf("failed to read response body: %v", err)
450
+ }
451
+ if s.config.Global.Log.Debug.PrintResponse {
452
+ logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
453
+ }
454
+ if resp.StatusCode != http.StatusOK {
455
+ return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
456
+ }
457
+
458
+ var thinkingResp ChatCompletionResponse
459
+ if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
460
+ return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
461
+ }
462
+ if len(thinkingResp.Choices) == 0 {
463
+ return nil, fmt.Errorf("thinking service returned no choices")
464
+ }
465
+
466
+ result := &ThinkingResponse{}
467
+ choice := thinkingResp.Choices[0]
468
+
469
+ if svc.Mode == "full" {
470
+ result.ReasoningContent = choice.Message.Content
471
+ result.Content = "Based on the above detailed analysis."
472
+ } else {
473
+ if choice.Message.ReasoningContent != nil {
474
+ switch v := choice.Message.ReasoningContent.(type) {
475
+ case string:
476
+ result.ReasoningContent = v
477
+ case map[string]interface{}:
478
+ if j, err := json.Marshal(v); err == nil {
479
+ result.ReasoningContent = string(j)
480
+ }
481
+ }
482
+ }
483
+ if result.ReasoningContent == "" {
484
+ result.ReasoningContent = choice.Message.Content
485
+ }
486
+ result.Content = "Based on the above reasoning."
487
+ }
488
+ return result, nil
489
+ }
490
+
491
+ func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest {
492
+ newReq := *originalReq
493
+ var systemPrompt string
494
+ if svc.Mode == "full" {
495
+ systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user):
496
+ %s
497
+
498
+ Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent)
499
+ } else {
500
+ systemPrompt = fmt.Sprintf(`Previous thinking process:
501
+ %s
502
+ Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent)
503
+ }
504
+ newReq.Messages = append([]ChatCompletionMessage{
505
+ {Role: "system", Content: systemPrompt},
506
+ }, newReq.Messages...)
507
+ return &newReq
508
+ }
509
+
510
+ func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
511
+ logger := NewRequestLogger(s.config)
512
+ if s.config.Global.Log.Debug.PrintRequest {
513
+ logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
514
+ }
515
+ jsonData, err := json.Marshal(req)
516
+ if err != nil {
517
+ http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
518
+ return
519
+ }
520
+
521
+ client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
522
+ if err != nil {
523
+ http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
524
+ return
525
+ }
526
+
527
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
528
+ if err != nil {
529
+ http.Error(w, "Failed to create request", http.StatusInternalServerError)
530
+ return
531
+ }
532
+ httpReq.Header.Set("Content-Type", "application/json")
533
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
534
+
535
+ resp, err := client.Do(httpReq)
536
+ if err != nil {
537
+ http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
538
+ return
539
+ }
540
+ defer resp.Body.Close()
541
+
542
+ respBody, err := io.ReadAll(resp.Body)
543
+ if err != nil {
544
+ http.Error(w, "Failed to read response", http.StatusInternalServerError)
545
+ return
546
+ }
547
+ if s.config.Global.Log.Debug.PrintResponse {
548
+ logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
549
+ }
550
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
551
+ http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
552
+ return
553
+ }
554
+
555
+ for k, vals := range resp.Header {
556
+ for _, v := range vals {
557
+ w.Header().Add(k, v)
558
+ }
559
+ }
560
+ w.WriteHeader(resp.StatusCode)
561
+ w.Write(respBody)
562
+ }
563
+
564
+ func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
565
+ logger := NewRequestLogger(s.config)
566
+ if s.config.Global.Log.Debug.PrintRequest {
567
+ logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
568
+ }
569
+ fullChatURL := targetChannel.GetFullURL()
570
+ parsedURL, err := url.Parse(fullChatURL)
571
+ if err != nil {
572
+ http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
573
+ return
574
+ }
575
+ baseURL := parsedURL.Scheme + "://" + parsedURL.Host
576
+ modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"
577
+
578
+ client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
579
+ if err != nil {
580
+ http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
581
+ return
582
+ }
583
+
584
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
585
+ if err != nil {
586
+ http.Error(w, "Failed to create request", http.StatusInternalServerError)
587
+ return
588
+ }
589
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
590
+
591
+ resp, err := client.Do(httpReq)
592
+ if err != nil {
593
+ http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
594
+ return
595
+ }
596
+ defer resp.Body.Close()
597
+
598
+ respBody, err := io.ReadAll(resp.Body)
599
+ if err != nil {
600
+ http.Error(w, "Failed to read response", http.StatusInternalServerError)
601
+ return
602
+ }
603
+ if s.config.Global.Log.Debug.PrintResponse {
604
+ logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
605
+ }
606
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
607
+ http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
608
+ return
609
+ }
610
+
611
+ for k, vals := range resp.Header {
612
+ for _, v := range vals {
613
+ w.Header().Add(k, v)
614
+ }
615
+ }
616
+ w.WriteHeader(resp.StatusCode)
617
+ w.Write(respBody)
618
+ }
619
+
620
+ // ---------------------- 流式处理 ----------------------
621
+
622
// collectedReasoningBuffer accumulates the reasoning text streamed back by
// the thinking service. In standard mode callers append only
// reasoning_content; in full mode they append everything.
type collectedReasoningBuffer struct {
	builder strings.Builder
	mode    string // thinking-service mode this buffer was created for
}

// append adds text to the end of the collected reasoning.
func (cb *collectedReasoningBuffer) append(text string) {
	_, _ = cb.builder.WriteString(text) // strings.Builder writes never fail
}

// get returns everything collected so far as a single string.
func (cb *collectedReasoningBuffer) get() string {
	collected := cb.builder.String()
	return collected
}
635
+
636
+ type StreamHandler struct {
637
+ thinkingService ThinkingService
638
+ targetChannel Channel
639
+ writer http.ResponseWriter
640
+ flusher http.Flusher
641
+ config *Config
642
+ }
643
+
644
+ func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) {
645
+ flusher, ok := w.(http.Flusher)
646
+ if !ok {
647
+ return nil, fmt.Errorf("streaming not supported by this writer")
648
+ }
649
+ return &StreamHandler{
650
+ thinkingService: thinkingService,
651
+ targetChannel: targetChannel,
652
+ writer: w,
653
+ flusher: flusher,
654
+ config: config,
655
+ }, nil
656
+ }
657
+
658
+ // HandleRequest 分两阶段:先流式接收思考服务 SSE 并转发给客户端(只转发包含 reasoning_content 的 chunk,标准模式遇到纯 content 则中断);
659
+ // 然后使用收集到的 reasoning_content 构造最终请求,发起目标 Channel 的 SSE 并转发给客户端。
660
+ func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error {
661
+ h.writer.Header().Set("Content-Type", "text/event-stream")
662
+ h.writer.Header().Set("Cache-Control", "no-cache")
663
+ h.writer.Header().Set("Connection", "keep-alive")
664
+
665
+ logger := NewRequestLogger(h.config)
666
+ reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode}
667
+
668
+ // 阶段一:流式接收思考服务 SSE
669
+ if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil {
670
+ return err
671
+ }
672
+
673
+ // 阶段二:构造最终请求,并流式转发目标 Channel SSE
674
+ finalReq := h.prepareFinalRequest(req, reasonBuf.get())
675
+ if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil {
676
+ return err
677
+ }
678
+ return nil
679
+ }
680
+
681
+ // streamThinkingService 连接思考服务 SSE,原样转发包含 reasoning_content 的 SSE chunk给客户端,同时收集 reasoning_content。
682
+ // 在标准模式下,一旦遇到只返回 non-empty 的 content(且 reasoning_content 为空),则立即中断,不转发该 chunk。
683
+ func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
684
+ thinkingReq := *req
685
+ thinkingReq.Model = h.thinkingService.Model
686
+ thinkingReq.APIKey = h.thinkingService.APIKey
687
+
688
+ var systemPrompt string
689
+ if h.thinkingService.Mode == "full" {
690
+ systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
691
+ } else {
692
+ systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
693
+ }
694
+ messages := append([]ChatCompletionMessage{
695
+ {Role: "system", Content: systemPrompt},
696
+ }, thinkingReq.Messages...)
697
+
698
+ temp := 0.7
699
+ if h.thinkingService.Temperature != nil {
700
+ temp = *h.thinkingService.Temperature
701
+ }
702
+ bodyMap := map[string]interface{}{
703
+ "model": thinkingReq.Model,
704
+ "messages": messages,
705
+ "temperature": temp,
706
+ "stream": true,
707
+ }
708
+ if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
709
+ bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
710
+ }
711
+ if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
712
+ bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
713
+ }
714
+
715
+ jsonData, err := json.Marshal(bodyMap)
716
+ if err != nil {
717
+ return err
718
+ }
719
+
720
+ client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
721
+ if err != nil {
722
+ return err
723
+ }
724
+
725
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
726
+ if err != nil {
727
+ return err
728
+ }
729
+ httpReq.Header.Set("Content-Type", "application/json")
730
+ httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)
731
+
732
+ log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
733
+ resp, err := client.Do(httpReq)
734
+ if err != nil {
735
+ return err
736
+ }
737
+ defer resp.Body.Close()
738
+
739
+ if resp.StatusCode != http.StatusOK {
740
+ b, _ := io.ReadAll(resp.Body)
741
+ return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
742
+ }
743
+
744
+ reader := bufio.NewReader(resp.Body)
745
+ var lastLine string
746
+ // 标识是否需中断思考 SSE(标准模式下)
747
+ forceStop := false
748
+
749
+ for {
750
+ select {
751
+ case <-ctx.Done():
752
+ return ctx.Err()
753
+ default:
754
+ }
755
+ line, err := reader.ReadString('\n')
756
+ if err != nil {
757
+ if err == io.EOF {
758
+ break
759
+ }
760
+ return err
761
+ }
762
+ line = strings.TrimSpace(line)
763
+ if line == "" || line == lastLine {
764
+ continue
765
+ }
766
+ lastLine = line
767
+
768
+ if !strings.HasPrefix(line, "data: ") {
769
+ continue
770
+ }
771
+ dataPart := strings.TrimPrefix(line, "data: ")
772
+ if dataPart == "[DONE]" {
773
+ // 思考服务结束:直接中断(不发送特殊标记)
774
+ break
775
+ }
776
+
777
+ var chunk struct {
778
+ Choices []struct {
779
+ Delta struct {
780
+ Content string `json:"content,omitempty"`
781
+ ReasoningContent string `json:"reasoning_content,omitempty"`
782
+ } `json:"delta"`
783
+ FinishReason *string `json:"finish_reason,omitempty"`
784
+ } `json:"choices"`
785
+ }
786
+ if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
787
+ // 如果解析失败,直接原样转发
788
+ h.writer.Write([]byte("data: " + dataPart + "\n\n"))
789
+ h.flusher.Flush()
790
+ continue
791
+ }
792
+
793
+ if len(chunk.Choices) > 0 {
794
+ c := chunk.Choices[0]
795
+ if h.config.Global.Log.Debug.PrintResponse {
796
+ logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
797
+ }
798
+
799
+ if h.thinkingService.Mode == "full" {
800
+ // full 模式:收集所有内容
801
+ if c.Delta.ReasoningContent != "" {
802
+ reasonBuf.append(c.Delta.ReasoningContent)
803
+ }
804
+ if c.Delta.Content != "" {
805
+ reasonBuf.append(c.Delta.Content)
806
+ }
807
+ // 原样转发整 chunk
808
+ forwardLine := "data: " + dataPart + "\n\n"
809
+ h.writer.Write([]byte(forwardLine))
810
+ h.flusher.Flush()
811
+ } else {
812
+ // standard 模式:只收集 reasoning_content
813
+ if c.Delta.ReasoningContent != "" {
814
+ reasonBuf.append(c.Delta.ReasoningContent)
815
+ // 转发该 chunk
816
+ forwardLine := "data: " + dataPart + "\n\n"
817
+ h.writer.Write([]byte(forwardLine))
818
+ h.flusher.Flush()
819
+ }
820
+ // 如果遇到 chunk 中只有 content(且 reasoning_content 为空),认为思考链结束,不转发该 chunk
821
+ if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
822
+ forceStop = true
823
+ }
824
+ }
825
+
826
+ // 如果 finishReason 非空,也认为结束
827
+ if c.FinishReason != nil && *c.FinishReason != "" {
828
+ forceStop = true
829
+ }
830
+ }
831
+
832
+ if forceStop {
833
+ break
834
+ }
835
+ }
836
+ // 读空剩余数据
837
+ io.Copy(io.Discard, reader)
838
+ return nil
839
+ }
840
+
841
+ func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest {
842
+ req := *originalReq
843
+ var systemPrompt string
844
+ if h.thinkingService.Mode == "full" {
845
+ systemPrompt = fmt.Sprintf(
846
+ `Consider the following detailed analysis (not shown to user):
847
+ %s
848
+
849
+ Provide a clear, concise response that incorporates insights from this analysis.`,
850
+ reasoningCollected,
851
+ )
852
+ } else {
853
+ systemPrompt = fmt.Sprintf(
854
+ `Previous thinking process:
855
+ %s
856
+ Please consider the above thinking process in your response.`,
857
+ reasoningCollected,
858
+ )
859
+ }
860
+ req.Messages = append([]ChatCompletionMessage{
861
+ {Role: "system", Content: systemPrompt},
862
+ }, req.Messages...)
863
+ return &req
864
+ }
865
+
866
+ func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
867
+ if h.config.Global.Log.Debug.PrintRequest {
868
+ logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
869
+ }
870
+
871
+ jsonData, err := json.Marshal(req)
872
+ if err != nil {
873
+ return err
874
+ }
875
+ client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
876
+ if err != nil {
877
+ return err
878
+ }
879
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
880
+ if err != nil {
881
+ return err
882
+ }
883
+ httpReq.Header.Set("Content-Type", "application/json")
884
+ httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
885
+
886
+ resp, err := client.Do(httpReq)
887
+ if err != nil {
888
+ return err
889
+ }
890
+ defer resp.Body.Close()
891
+
892
+ if resp.StatusCode != http.StatusOK {
893
+ b, _ := io.ReadAll(resp.Body)
894
+ return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
895
+ }
896
+
897
+ reader := bufio.NewReader(resp.Body)
898
+ var lastLine string
899
+
900
+ for {
901
+ select {
902
+ case <-ctx.Done():
903
+ return ctx.Err()
904
+ default:
905
+ }
906
+ line, err := reader.ReadString('\n')
907
+ if err != nil {
908
+ if err == io.EOF {
909
+ break
910
+ }
911
+ return err
912
+ }
913
+ line = strings.TrimSpace(line)
914
+ if line == "" || line == lastLine {
915
+ continue
916
+ }
917
+ lastLine = line
918
+
919
+ if !strings.HasPrefix(line, "data: ") {
920
+ continue
921
+ }
922
+ data := strings.TrimPrefix(line, "data: ")
923
+ if data == "[DONE]" {
924
+ doneLine := "data: [DONE]\n\n"
925
+ h.writer.Write([]byte(doneLine))
926
+ h.flusher.Flush()
927
+ break
928
+ }
929
+ forwardLine := "data: " + data + "\n\n"
930
+ h.writer.Write([]byte(forwardLine))
931
+ h.flusher.Flush()
932
+
933
+ if h.config.Global.Log.Debug.PrintResponse {
934
+ logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
935
+ }
936
+ }
937
+ io.Copy(io.Discard, reader)
938
+ return nil
939
+ }
940
+
941
+ // ---------------------- HTTP Client 工具 ----------------------
942
+
943
+ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
944
+ transport := &http.Transport{
945
+ DialContext: (&net.Dialer{
946
+ Timeout: 30 * time.Second,
947
+ KeepAlive: 30 * time.Second,
948
+ }).DialContext,
949
+ ForceAttemptHTTP2: true,
950
+ MaxIdleConns: 100,
951
+ IdleConnTimeout: 90 * time.Second,
952
+ TLSHandshakeTimeout: 10 * time.Second,
953
+ ExpectContinueTimeout: 1 * time.Second,
954
+ }
955
+ if proxyURL != "" {
956
+ parsedURL, err := url.Parse(proxyURL)
957
+ if err != nil {
958
+ return nil, fmt.Errorf("invalid proxy URL: %v", err)
959
+ }
960
+ switch parsedURL.Scheme {
961
+ case "http", "https":
962
+ transport.Proxy = http.ProxyURL(parsedURL)
963
+ case "socks5":
964
+ dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
965
+ if err != nil {
966
+ return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err)
967
+ }
968
+ transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
969
+ return dialer.Dial(network, addr)
970
+ }
971
+ default:
972
+ return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
973
+ }
974
+ }
975
+ return &http.Client{
976
+ Transport: transport,
977
+ Timeout: timeout,
978
+ }, nil
979
+ }
980
+
981
// maskSensitiveHeaders returns a new header map in which the value of every
// key that lower-cases to "authorization" is replaced by the single entry
// "Bearer ****". All other keys are carried over with their original value
// slices (the slices are shared, not copied). The input map is not modified.
func maskSensitiveHeaders(headers http.Header) http.Header {
	redacted := make(http.Header, len(headers))
	for name, values := range headers {
		if strings.ToLower(name) != "authorization" {
			redacted[name] = values
			continue
		}
		redacted[name] = []string{"Bearer ****"}
	}
	return redacted
}
992
+
993
// isValidReasoningEffort reports whether effort is "low", "medium" or "high",
// compared case-insensitively.
func isValidReasoningEffort(effort string) bool {
	normalized := strings.ToLower(effort)
	return normalized == "low" || normalized == "medium" || normalized == "high"
}
1000
+
1001
// isValidReasoningFormat reports whether format is "parsed", "raw" or
// "hidden", compared case-insensitively.
func isValidReasoningFormat(format string) bool {
	normalized := strings.ToLower(format)
	return normalized == "parsed" || normalized == "raw" || normalized == "hidden"
}
1008
+
1009
+ // ---------------------- 配置加载与验证 ----------------------
1010
+
1011
// getEnvOrDefault returns the value of the environment variable key, or
// defaultValue when the variable is unset or empty.
func getEnvOrDefault(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
1018
+
1019
// getEnvIntOrDefault returns the environment variable key parsed as a base-10
// int. defaultValue is returned when the variable is unset, empty, or cannot
// be parsed by strconv.Atoi.
func getEnvIntOrDefault(key string, defaultValue int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return defaultValue
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return defaultValue
	}
	return parsed
}
1027
+
1028
// getEnvBoolOrDefault returns the environment variable key parsed with
// strconv.ParseBool. defaultValue is returned when the variable is unset,
// empty, or not a boolean literal ParseBool accepts.
func getEnvBoolOrDefault(key string, defaultValue bool) bool {
	raw := os.Getenv(key)
	if raw == "" {
		return defaultValue
	}
	parsed, err := strconv.ParseBool(raw)
	if err != nil {
		return defaultValue
	}
	return parsed
}
1036
+
1037
// getEnvFloatPtr returns a pointer to the environment variable key parsed as a
// float64. When the variable is unset, empty, or rejected by
// strconv.ParseFloat, the pointer refers to a copy of defaultValue instead.
// The result is never nil and each call returns a distinct pointer.
func getEnvFloatPtr(key string, defaultValue float64) *float64 {
	result := defaultValue
	if raw := os.Getenv(key); raw != "" {
		if parsed, err := strconv.ParseFloat(raw, 64); err == nil {
			result = parsed
		}
	}
	return &result
}
1045
+
1046
+ // 然后替换原有的 loadConfig 函数
1047
+ func loadConfig() (*Config, error) {
1048
+ // 优先从环境变量加载完整配置
1049
+ if envConfig := os.Getenv("CONFIG_JSON"); envConfig != "" {
1050
+ var cfg Config
1051
+ if err := json.Unmarshal([]byte(envConfig), &cfg); err == nil {
1052
+ if err := validateConfig(&cfg); err == nil {
1053
+ log.Println("Using configuration from CONFIG_JSON environment variable")
1054
+ return &cfg, nil
1055
+ }
1056
+ }
1057
+ }
1058
+
1059
+ // 如果没有完整的 JSON 配置,则从独立的环境变量构建配置
1060
+ log.Println("Building configuration from individual environment variables")
1061
+ cfg := &Config{
1062
+ ThinkingServices: []ThinkingService{
1063
+ {
1064
+ ID: 1,
1065
+ Name: getEnvOrDefault("THINKING_SERVICE_NAME", "modelscope-deepseek-thinking"),
1066
+ Mode: getEnvOrDefault("THINKING_SERVICE_MODE", "standard"),
1067
+ Model: getEnvOrDefault("THINKING_SERVICE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"),
1068
+ BaseURL: getEnvOrDefault("THINKING_SERVICE_BASE_URL", "https://api-inference.modelscope.cn"),
1069
+ APIPath: getEnvOrDefault("THINKING_SERVICE_API_PATH", "/v1/chat/completions"),
1070
+ APIKey: os.Getenv("THINKING_SERVICE_API_KEY"), // 敏感信息必须从环境变量读取
1071
+ Timeout: getEnvIntOrDefault("THINKING_SERVICE_TIMEOUT", 6000),
1072
+ Weight: getEnvIntOrDefault("THINKING_SERVICE_WEIGHT", 100),
1073
+ Proxy: getEnvOrDefault("THINKING_SERVICE_PROXY", ""),
1074
+ ReasoningEffort: getEnvOrDefault("THINKING_SERVICE_REASONING_EFFORT", "high"),
1075
+ ReasoningFormat: getEnvOrDefault("THINKING_SERVICE_REASONING_FORMAT", "parsed"),
1076
+ Temperature: getEnvFloatPtr("THINKING_SERVICE_TEMPERATURE", 0.8),
1077
+ ForceStopDeepThinking: getEnvBoolOrDefault("THINKING_SERVICE_FORCE_STOP", false),
1078
+ },
1079
+ },
1080
+ Channels: map[string]Channel{
1081
+ "1": {
1082
+ Name: getEnvOrDefault("CHANNEL_NAME", "Gemini-channel"),
1083
+ BaseURL: getEnvOrDefault("CHANNEL_BASE_URL", "https://Richardlsr-gemini-balance.hf.space/hf"),
1084
+ APIPath: getEnvOrDefault("CHANNEL_API_PATH", "/v1/chat/completions"),
1085
+ Timeout: getEnvIntOrDefault("CHANNEL_TIMEOUT", 600),
1086
+ Proxy: getEnvOrDefault("CHANNEL_PROXY", ""),
1087
+ },
1088
+ },
1089
+ Global: GlobalConfig{
1090
+ Server: ServerConfig{
1091
+ Port: getEnvIntOrDefault("SERVER_PORT", 7860), // HF 默认端口
1092
+ Host: getEnvOrDefault("SERVER_HOST", "0.0.0.0"),
1093
+ ReadTimeout: getEnvIntOrDefault("SERVER_READ_TIMEOUT", 600),
1094
+ WriteTimeout: getEnvIntOrDefault("SERVER_WRITE_TIMEOUT", 600),
1095
+ IdleTimeout: getEnvIntOrDefault("SERVER_IDLE_TIMEOUT", 600),
1096
+ },
1097
+ Log: LogConfig{
1098
+ Level: getEnvOrDefault("LOG_LEVEL", "info"),
1099
+ Format: getEnvOrDefault("LOG_FORMAT", "json"),
1100
+ Output: getEnvOrDefault("LOG_OUTPUT", "console"), // HF 环境下默认输出到控制台
1101
+ FilePath: getEnvOrDefault("LOG_FILE_PATH", "./logs/deepai.log"),
1102
+ Debug: DebugConfig{
1103
+ Enabled: getEnvBoolOrDefault("LOG_DEBUG_ENABLED", true),
1104
+ PrintRequest: getEnvBoolOrDefault("LOG_PRINT_REQUEST", true),
1105
+ PrintResponse: getEnvBoolOrDefault("LOG_PRINT_RESPONSE", true),
1106
+ MaxContentLength: getEnvIntOrDefault("LOG_MAX_CONTENT_LENGTH", 1000),
1107
+ },
1108
+ },
1109
+ },
1110
+ }
1111
+
1112
+ // 验证必需的环境变量
1113
+ if cfg.ThinkingServices[0].APIKey == "" {
1114
+ return nil, fmt.Errorf("THINKING_SERVICE_API_KEY environment variable is required")
1115
+ }
1116
+
1117
+ if err := validateConfig(cfg); err != nil {
1118
+ return nil, fmt.Errorf("config validation error: %v", err)
1119
+ }
1120
+
1121
+ return cfg, nil
1122
+ }
1123
+
1124
+ func validateConfig(config *Config) error {
1125
+ if len(config.ThinkingServices) == 0 {
1126
+ return fmt.Errorf("no thinking services configured")
1127
+ }
1128
+ if len(config.Channels) == 0 {
1129
+ return fmt.Errorf("no channels configured")
1130
+ }
1131
+ for i, svc := range config.ThinkingServices {
1132
+ if svc.BaseURL == "" {
1133
+ return fmt.Errorf("thinking service %s has empty baseURL", svc.Name)
1134
+ }
1135
+ if svc.APIKey == "" {
1136
+ return fmt.Errorf("thinking service %s has empty apiKey", svc.Name)
1137
+ }
1138
+ if svc.Timeout <= 0 {
1139
+ return fmt.Errorf("thinking service %s has invalid timeout", svc.Name)
1140
+ }
1141
+ if svc.Model == "" {
1142
+ return fmt.Errorf("thinking service %s has empty model", svc.Name)
1143
+ }
1144
+ if svc.Mode == "" {
1145
+ config.ThinkingServices[i].Mode = "standard"
1146
+ } else if svc.Mode != "standard" && svc.Mode != "full" {
1147
+ return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode)
1148
+ }
1149
+ }
1150
+ return nil
1151
+ }
1152
+
1153
+ func main() {
1154
+ log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
1155
+
1156
+ cfg, err := loadConfig()
1157
+ if err != nil {
1158
+ log.Fatalf("Failed to load config: %v", err)
1159
+ }
1160
+ log.Printf("Using config file: %s", viper.ConfigFileUsed())
1161
+
1162
+ server := NewServer(cfg)
1163
+
1164
+ done := make(chan os.Signal, 1)
1165
+ signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
1166
+
1167
+ go func() {
1168
+ if err := server.Start(); err != nil && err != http.ErrServerClosed {
1169
+ log.Fatalf("start server error: %v", err)
1170
+ }
1171
+ }()
1172
+ log.Printf("Server started successfully")
1173
+
1174
+ <-done
1175
+ log.Print("Server stopping...")
1176
+
1177
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1178
+ defer cancel()
1179
+
1180
+ if err := server.Shutdown(ctx); err != nil {
1181
+ log.Printf("Server forced to shutdown: %v", err)
1182
+ }
1183
+ log.Print("Server stopped")
1184
+ }