|
package main |
|
|
|
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"
	"unicode/utf8"

	"github.com/google/uuid"
	"github.com/spf13/viper"
	"golang.org/x/net/proxy"
)
|
|
|
|
|
|
|
// Config is the top-level application configuration, loaded either from
// CONFIG_JSON or from individual environment variables (see loadConfig).
type Config struct {
	// ThinkingServices are the upstream reasoning backends queried before
	// a request is forwarded; one is chosen by weighted random selection.
	ThinkingServices []ThinkingService `mapstructure:"thinking_services"`
	// Channels are the final target backends, keyed by channel ID as
	// embedded in the client's API key ("deep-<id>-<key>").
	Channels map[string]Channel `mapstructure:"channels"`
	// Global holds server-wide settings (HTTP server, logging, proxy, ...).
	Global GlobalConfig `mapstructure:"global"`
}
|
|
|
// ThinkingService describes one upstream backend used to produce a
// reasoning/chain-of-thought pass before the final channel request.
type ThinkingService struct {
	ID      int    `mapstructure:"id"`
	Name    string `mapstructure:"name"`
	Model   string `mapstructure:"model"`
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`
	APIKey  string `mapstructure:"api_key"`
	// Timeout is the per-request HTTP client timeout in seconds.
	Timeout int `mapstructure:"timeout"`
	// Retry is declared in config but not consulted in this file —
	// NOTE(review): confirm whether retry handling exists elsewhere.
	Retry int `mapstructure:"retry"`
	// Weight is the relative probability of this service being picked by
	// getWeightedRandomThinkingService.
	Weight int `mapstructure:"weight"`
	// Proxy is an optional http(s):// or socks5:// proxy URL.
	Proxy string `mapstructure:"proxy"`
	// Mode selects how reasoning is extracted: "full" treats the whole
	// completion as reasoning; any other value expects a separate
	// reasoning_content field in the response.
	Mode string `mapstructure:"mode"`
	// ReasoningEffort is forwarded when it is one of low/medium/high.
	ReasoningEffort string `mapstructure:"reasoning_effort"`
	// ReasoningFormat is forwarded when it is one of parsed/raw/hidden.
	ReasoningFormat string `mapstructure:"reasoning_format"`
	// Temperature overrides the default (0.7) when non-nil.
	Temperature *float64 `mapstructure:"temperature"`
	// ForceStopDeepThinking is logged in the SSE path; beyond logging its
	// effect is not visible in this file — TODO confirm intended behavior.
	ForceStopDeepThinking bool `mapstructure:"force_stop_deep_thinking"`
}
|
|
|
func (s *ThinkingService) GetFullURL() string { |
|
return s.BaseURL + s.APIPath |
|
} |
|
|
|
// Channel is a final target backend that receives the user's request after
// the thinking phase (e.g. an OpenAI-compatible chat completions endpoint).
type Channel struct {
	Name    string `mapstructure:"name"`
	BaseURL string `mapstructure:"base_url"`
	APIPath string `mapstructure:"api_path"`
	// Timeout is the per-request HTTP client timeout in seconds.
	Timeout int `mapstructure:"timeout"`
	// Proxy is an optional http(s):// or socks5:// proxy URL.
	Proxy string `mapstructure:"proxy"`
}
|
|
|
func (c *Channel) GetFullURL() string { |
|
return c.BaseURL + c.APIPath |
|
} |
|
|
|
// LogConfig controls application logging.
// NOTE(review): only the Debug sub-config is consulted in this file;
// Level/Format/Output/FilePath handling is not visible here.
type LogConfig struct {
	Level    string      `mapstructure:"level"`
	Format   string      `mapstructure:"format"`
	Output   string      `mapstructure:"output"`
	FilePath string      `mapstructure:"file_path"`
	Debug    DebugConfig `mapstructure:"debug"`
}
|
|
|
// DebugConfig gates verbose request/response logging.
type DebugConfig struct {
	// Enabled must be true for LogContent to emit anything at all.
	Enabled bool `mapstructure:"enabled"`
	// PrintRequest/PrintResponse additionally gate each direction.
	PrintRequest  bool `mapstructure:"print_request"`
	PrintResponse bool `mapstructure:"print_response"`
	// MaxContentLength truncates logged payloads (see truncateContent).
	MaxContentLength int `mapstructure:"max_content_length"`
}
|
|
|
// ProxyConfig holds global proxy defaults.
// NOTE(review): not referenced in this file — per-service/channel Proxy
// fields are used instead; confirm whether this is applied elsewhere.
type ProxyConfig struct {
	Enabled       bool   `mapstructure:"enabled"`
	Default       string `mapstructure:"default"`
	AllowInsecure bool   `mapstructure:"allow_insecure"`
}
|
|
|
// GlobalConfig groups server-wide settings.
type GlobalConfig struct {
	MaxRetries     int `mapstructure:"max_retries"`
	DefaultTimeout int `mapstructure:"default_timeout"`
	// ErrorCodes.RetryOn lists HTTP status codes that should trigger a
	// retry — NOTE(review): not consulted in this file.
	ErrorCodes struct {
		RetryOn []int `mapstructure:"retry_on"`
	} `mapstructure:"error_codes"`
	Log    LogConfig    `mapstructure:"log"`
	Server ServerConfig `mapstructure:"server"`
	Proxy  ProxyConfig  `mapstructure:"proxy"`
	// ConfigPaths are candidate locations for a config file.
	ConfigPaths []string       `mapstructure:"config_paths"`
	Thinking    ThinkingConfig `mapstructure:"thinking"`
}
|
|
|
// ServerConfig configures the HTTP listener; all timeouts are in seconds.
type ServerConfig struct {
	Port         int    `mapstructure:"port"`
	Host         string `mapstructure:"host"`
	ReadTimeout  int    `mapstructure:"read_timeout"`
	WriteTimeout int    `mapstructure:"write_timeout"`
	IdleTimeout  int    `mapstructure:"idle_timeout"`
}
|
|
|
// ThinkingConfig holds global thinking-phase switches.
// NOTE(review): not consulted in this file; confirm usage elsewhere.
type ThinkingConfig struct {
	Enabled          bool `mapstructure:"enabled"`
	AddToAllRequests bool `mapstructure:"add_to_all_requests"`
	Timeout          int  `mapstructure:"timeout"`
}
|
|
|
|
|
|
|
// ChatCompletionRequest mirrors the OpenAI chat completions request body.
type ChatCompletionRequest struct {
	Model       string                  `json:"model"`
	Messages    []ChatCompletionMessage `json:"messages"`
	Temperature float64                 `json:"temperature,omitempty"`
	MaxTokens   int                     `json:"max_tokens,omitempty"`
	Stream      bool                    `json:"stream,omitempty"`
	// APIKey is carried alongside the request internally and never
	// serialized (json:"-"); it is sent as a Bearer token instead.
	APIKey string `json:"-"`
}
|
|
|
// ChatCompletionMessage is a single chat message.
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	// ReasoningContent is interface{} because upstreams return either a
	// plain string or a JSON object (see processThinkingContent's switch).
	ReasoningContent interface{} `json:"reasoning_content,omitempty"`
}
|
|
|
// ChatCompletionResponse mirrors the OpenAI chat completions response body.
type ChatCompletionResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}
|
|
|
// Choice is one completion alternative within a ChatCompletionResponse.
type Choice struct {
	Index        int                   `json:"index"`
	Message      ChatCompletionMessage `json:"message"`
	FinishReason string                `json:"finish_reason"`
}
|
|
|
// Usage reports token accounting for a completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
|
|
|
|
|
// RequestLogger tags log lines with a per-request UUID and buffers them
// in memory (the in-memory buffer is only appended to in this file).
type RequestLogger struct {
	RequestID string
	Model     string
	StartTime time.Time
	logs      []string
	config    *Config
}
|
|
|
func NewRequestLogger(config *Config) *RequestLogger { |
|
return &RequestLogger{ |
|
RequestID: uuid.New().String(), |
|
StartTime: time.Now(), |
|
logs: make([]string, 0), |
|
config: config, |
|
} |
|
} |
|
|
|
func (l *RequestLogger) Log(format string, args ...interface{}) { |
|
msg := fmt.Sprintf(format, args...) |
|
l.logs = append(l.logs, fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339), msg)) |
|
log.Printf("[RequestID: %s] %s", l.RequestID, msg) |
|
} |
|
|
|
func (l *RequestLogger) LogContent(contentType string, content interface{}, maxLength int) { |
|
if !l.config.Global.Log.Debug.Enabled { |
|
return |
|
} |
|
sanitizedContent := sanitizeJSON(content) |
|
truncatedContent := truncateContent(sanitizedContent, maxLength) |
|
l.Log("%s Content:\n%s", contentType, truncatedContent) |
|
} |
|
|
|
// truncateContent limits content to at most maxLength bytes, appending a
// "(truncated)" marker when it cuts. Unlike a naive byte slice, it backs
// up to the nearest rune boundary so truncation never emits invalid UTF-8
// (log payloads are JSON and frequently contain multi-byte characters).
// A non-positive maxLength truncates to the empty string instead of
// panicking on a negative slice index.
func truncateContent(content string, maxLength int) string {
	if len(content) <= maxLength {
		return content
	}
	if maxLength < 0 {
		maxLength = 0
	}
	cut := maxLength
	// Back off past any UTF-8 continuation bytes so the cut is rune-aligned.
	for cut > 0 && !utf8.RuneStart(content[cut]) {
		cut--
	}
	return content[:cut] + "... (truncated)"
}
|
|
|
// apiKeyPattern matches an "api_key" JSON field so its value can be masked.
// Compiled once at package scope: the original recompiled this regexp on
// every call, which is wasteful on the per-request logging path.
var apiKeyPattern = regexp.MustCompile(`"api_key":\s*"[^"]*"`)

// sanitizeJSON marshals data to JSON with any "api_key" values replaced by
// "****". Returns a fixed error string when data cannot be marshalled.
func sanitizeJSON(data interface{}) string {
	sanitized, err := json.Marshal(data)
	if err != nil {
		return "Failed to marshal JSON"
	}
	return apiKeyPattern.ReplaceAllString(string(sanitized), `"api_key":"****"`)
}
|
|
|
// extractRealAPIKey strips the "deep-<channel>-" or "openai-<channel>-"
// prefix from a client key, returning the remainder unchanged. Keys
// without a recognized prefix are returned as-is.
func extractRealAPIKey(fullKey string) string {
	segments := strings.SplitN(fullKey, "-", 3)
	if len(segments) == 3 && (segments[0] == "deep" || segments[0] == "openai") {
		return segments[2]
	}
	return fullKey
}
|
|
|
// extractChannelID pulls the channel ID out of a prefixed client key
// ("deep-<id>-..." / "openai-<id>-..."), defaulting to channel "1" when
// the key has no recognized prefix.
func extractChannelID(fullKey string) string {
	if segments := strings.SplitN(fullKey, "-", 3); len(segments) >= 2 && (segments[0] == "deep" || segments[0] == "openai") {
		return segments[1]
	}
	return "1"
}
|
|
|
// logAPIKey renders a key safely for logs: fully masked when it is 8
// characters or fewer, otherwise the first and last 4 characters with the
// middle elided.
func logAPIKey(key string) string {
	const shown = 4
	if len(key) <= 2*shown {
		return "****"
	}
	return key[:shown] + "..." + key[len(key)-shown:]
}
|
|
|
|
|
|
|
// Server wires the HTTP listener to the configured thinking services and
// channels.
type Server struct {
	config *Config
	srv    *http.Server
}
|
|
|
// Package-level RNG for weighted service selection. *rand.Rand is not safe
// for concurrent use, so every call to randGen is guarded by randMu.
var (
	randMu  sync.Mutex
	randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
)
|
|
|
func NewServer(config *Config) *Server { |
|
return &Server{ |
|
config: config, |
|
} |
|
} |
|
|
|
func (s *Server) Start() error { |
|
mux := http.NewServeMux() |
|
mux.HandleFunc("/v1/chat/completions", s.handleOpenAIRequests) |
|
mux.HandleFunc("/v1/models", s.handleOpenAIRequests) |
|
mux.HandleFunc("/health", s.handleHealth) |
|
|
|
s.srv = &http.Server{ |
|
Addr: fmt.Sprintf("%s:%d", s.config.Global.Server.Host, s.config.Global.Server.Port), |
|
Handler: mux, |
|
ReadTimeout: time.Duration(s.config.Global.Server.ReadTimeout) * time.Second, |
|
WriteTimeout: time.Duration(s.config.Global.Server.WriteTimeout) * time.Second, |
|
IdleTimeout: time.Duration(s.config.Global.Server.IdleTimeout) * time.Second, |
|
} |
|
|
|
log.Printf("Server starting on %s\n", s.srv.Addr) |
|
return s.srv.ListenAndServe() |
|
} |
|
|
|
// Shutdown gracefully stops the HTTP server, honoring ctx's deadline.
func (s *Server) Shutdown(ctx context.Context) error {
	return s.srv.Shutdown(ctx)
}
|
|
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { |
|
w.WriteHeader(http.StatusOK) |
|
json.NewEncoder(w).Encode(map[string]string{"status": "healthy"}) |
|
} |
|
|
|
// handleOpenAIRequests is the main entry point for /v1/chat/completions
// and /v1/models. It parses the composite client key
// ("deep-<channelID>-<realKey>" or "openai-..."), resolves the target
// channel, and either proxies a models listing (GET) or runs the
// thinking-then-forward pipeline for a chat completion (POST), streaming
// or non-streaming depending on the request's "stream" flag.
func (s *Server) handleOpenAIRequests(w http.ResponseWriter, r *http.Request) {
	logger := NewRequestLogger(s.config)

	// Composite key: channel ID and real upstream key are embedded in the
	// Bearer token; unprefixed keys fall back to channel "1" (see
	// extractChannelID / extractRealAPIKey).
	fullAPIKey := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
	apiKey := extractRealAPIKey(fullAPIKey)
	channelID := extractChannelID(fullAPIKey)

	logger.Log("Received request for %s with API Key: %s", r.URL.Path, logAPIKey(fullAPIKey))
	logger.Log("Extracted channel ID: %s", channelID)
	logger.Log("Extracted real API Key: %s", logAPIKey(apiKey))

	targetChannel, ok := s.config.Channels[channelID]
	if !ok {
		http.Error(w, "Invalid channel", http.StatusBadRequest)
		return
	}

	// /v1/models is proxied straight through (GET only) — no thinking phase.
	if r.URL.Path == "/v1/models" {
		if r.Method != http.MethodGet {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		req := &ChatCompletionRequest{APIKey: apiKey}
		s.forwardModelsRequest(w, r.Context(), req, targetChannel)
		return
	}

	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Read the whole body once, then restore it so the request object stays
	// re-readable downstream.
	body, err := io.ReadAll(r.Body)
	if err != nil {
		logger.Log("Error reading request body: %v", err)
		http.Error(w, "Failed to read request", http.StatusBadRequest)
		return
	}
	r.Body.Close()
	r.Body = io.NopCloser(bytes.NewBuffer(body))

	if s.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("Request", string(body), s.config.Global.Log.Debug.MaxContentLength)
	}

	var req ChatCompletionRequest
	if err := json.NewDecoder(bytes.NewBuffer(body)).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	// Replace whatever key the body carried with the real upstream key.
	req.APIKey = apiKey

	// Pick a thinking service by configured weight.
	thinkingService := s.getWeightedRandomThinkingService()
	logger.Log("Using thinking service: %s with API Key: %s", thinkingService.Name, logAPIKey(thinkingService.APIKey))

	if req.Stream {
		// Streaming: SSE from the thinking service, then SSE from the channel.
		handler, err := NewStreamHandler(w, thinkingService, targetChannel, s.config)
		if err != nil {
			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
			return
		}
		if err := handler.HandleRequest(r.Context(), &req); err != nil {
			// Headers may already be written at this point, so only log.
			logger.Log("Stream handler error: %v", err)
		}
	} else {
		// Non-streaming: run the thinking pass, inject its reasoning as a
		// system message, then forward synchronously.
		thinkingResp, err := s.processThinkingContent(r.Context(), &req, thinkingService)
		if err != nil {
			logger.Log("Error processing thinking content: %v", err)
			http.Error(w, "Thinking service error: "+err.Error(), http.StatusInternalServerError)
			return
		}
		enhancedReq := s.prepareEnhancedRequest(&req, thinkingResp, thinkingService)
		s.forwardRequest(w, r.Context(), enhancedReq, targetChannel)
	}
}
|
|
|
func (s *Server) getWeightedRandomThinkingService() ThinkingService { |
|
thinkingServices := s.config.ThinkingServices |
|
if len(thinkingServices) == 0 { |
|
return ThinkingService{} |
|
} |
|
totalWeight := 0 |
|
for _, svc := range thinkingServices { |
|
totalWeight += svc.Weight |
|
} |
|
if totalWeight <= 0 { |
|
log.Println("Warning: Total weight of thinking services is not positive, using first service as default.") |
|
return thinkingServices[0] |
|
} |
|
randMu.Lock() |
|
randNum := randGen.Intn(totalWeight) |
|
randMu.Unlock() |
|
currentSum := 0 |
|
for _, svc := range thinkingServices { |
|
currentSum += svc.Weight |
|
if randNum < currentSum { |
|
return svc |
|
} |
|
} |
|
return thinkingServices[0] |
|
} |
|
|
|
|
|
|
|
// ThinkingResponse is the distilled result of the thinking phase:
// ReasoningContent is injected into the final request as a system prompt;
// Content is a fixed transition sentence set by processThinkingContent.
type ThinkingResponse struct {
	Content          string
	ReasoningContent string
}
|
|
|
// processThinkingContent performs the non-streaming thinking pass: it
// prepends a mode-specific system prompt, calls the thinking service
// synchronously, and extracts reasoning text from the response.
// In "full" mode the entire completion is treated as reasoning; otherwise
// the dedicated reasoning_content field is preferred with the plain
// content as fallback.
func (s *Server) processThinkingContent(ctx context.Context, req *ChatCompletionRequest, svc ThinkingService) (*ThinkingResponse, error) {
	// NOTE(review): this creates a second RequestLogger with a new UUID,
	// so log lines for one client request carry two different request IDs.
	logger := NewRequestLogger(s.config)
	log.Printf("Getting thinking content from service: %s (mode=%s)", svc.Name, svc.Mode)

	// Re-target a copy of the request at the thinking service.
	thinkingReq := *req
	thinkingReq.Model = svc.Model
	thinkingReq.APIKey = svc.APIKey

	var systemPrompt string
	if svc.Mode == "full" {
		systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
	} else {
		systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
	}
	thinkingReq.Messages = append([]ChatCompletionMessage{
		{Role: "system", Content: systemPrompt},
	}, thinkingReq.Messages...)

	// Temperature defaults to 0.7 unless the service config overrides it.
	temp := 0.7
	if svc.Temperature != nil {
		temp = *svc.Temperature
	}
	payload := map[string]interface{}{
		"model":       svc.Model,
		"messages":    thinkingReq.Messages,
		"stream":      false,
		"temperature": temp,
	}
	// Vendor-specific reasoning knobs, forwarded only when valid.
	if isValidReasoningEffort(svc.ReasoningEffort) {
		payload["reasoning_effort"] = svc.ReasoningEffort
	}
	if isValidReasoningFormat(svc.ReasoningFormat) {
		payload["reasoning_format"] = svc.ReasoningFormat
	}

	if s.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("Thinking Service Request", payload, s.config.Global.Log.Debug.MaxContentLength)
	}

	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal thinking request: %v", err)
	}
	client, err := createHTTPClient(svc.Proxy, time.Duration(svc.Timeout)*time.Second)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP client: %v", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, svc.GetFullURL(), bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %v", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+svc.APIKey)

	resp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send thinking request: %v", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %v", err)
	}
	if s.config.Global.Log.Debug.PrintResponse {
		logger.LogContent("Thinking Service Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("thinking service returned %d: %s", resp.StatusCode, string(respBody))
	}

	var thinkingResp ChatCompletionResponse
	if err := json.Unmarshal(respBody, &thinkingResp); err != nil {
		return nil, fmt.Errorf("failed to unmarshal thinking response: %v", err)
	}
	if len(thinkingResp.Choices) == 0 {
		return nil, fmt.Errorf("thinking service returned no choices")
	}

	result := &ThinkingResponse{}
	choice := thinkingResp.Choices[0]

	if svc.Mode == "full" {
		// Full mode: the whole visible answer becomes the hidden reasoning.
		result.ReasoningContent = choice.Message.Content
		result.Content = "Based on the above detailed analysis."
	} else {
		// Standard mode: reasoning_content may be a string or a JSON object.
		if choice.Message.ReasoningContent != nil {
			switch v := choice.Message.ReasoningContent.(type) {
			case string:
				result.ReasoningContent = v
			case map[string]interface{}:
				if j, err := json.Marshal(v); err == nil {
					result.ReasoningContent = string(j)
				}
			}
		}
		// Fall back to the plain content when no reasoning field was set.
		if result.ReasoningContent == "" {
			result.ReasoningContent = choice.Message.Content
		}
		result.Content = "Based on the above reasoning."
	}
	return result, nil
}
|
|
|
func (s *Server) prepareEnhancedRequest(originalReq *ChatCompletionRequest, thinkingResp *ThinkingResponse, svc ThinkingService) *ChatCompletionRequest { |
|
newReq := *originalReq |
|
var systemPrompt string |
|
if svc.Mode == "full" { |
|
systemPrompt = fmt.Sprintf(`Consider the following detailed analysis (not shown to user): |
|
%s |
|
|
|
Provide a clear, concise response that incorporates insights from this analysis.`, thinkingResp.ReasoningContent) |
|
} else { |
|
systemPrompt = fmt.Sprintf(`Previous thinking process: |
|
%s |
|
Please consider the above thinking process in your response.`, thinkingResp.ReasoningContent) |
|
} |
|
newReq.Messages = append([]ChatCompletionMessage{ |
|
{Role: "system", Content: systemPrompt}, |
|
}, newReq.Messages...) |
|
return &newReq |
|
} |
|
|
|
// forwardRequest sends the (reasoning-enhanced) non-streaming request to
// the target channel and relays the upstream response — headers, status,
// and body — back to the client. Upstream non-2xx responses are converted
// to a plain-text error (the upstream body is logged but not forwarded).
func (s *Server) forwardRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, channel Channel) {
	logger := NewRequestLogger(s.config)
	if s.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("Forward Request", req, s.config.Global.Log.Debug.MaxContentLength)
	}
	jsonData, err := json.Marshal(req)
	if err != nil {
		http.Error(w, "Failed to marshal request", http.StatusInternalServerError)
		return
	}

	client, err := createHTTPClient(channel.Proxy, time.Duration(channel.Timeout)*time.Second)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
		return
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, channel.GetFullURL(), bytes.NewBuffer(jsonData))
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	httpReq.Header.Set("Content-Type", "application/json")
	// The client's real upstream key (req.APIKey), not the composite key.
	httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)

	resp, err := client.Do(httpReq)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}
	if s.config.Global.Log.Debug.PrintResponse {
		logger.LogContent("Forward Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// NOTE(review): the upstream error body is discarded here; clients
		// only see the status line. Confirm this is intended.
		http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
		return
	}

	// Mirror upstream headers before the status; body length is unchanged
	// so Content-Length (if present) remains valid.
	for k, vals := range resp.Header {
		for _, v := range vals {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	w.Write(respBody)
}
|
|
|
// forwardModelsRequest proxies GET /v1/models to the target channel. The
// models URL is derived from the channel's chat URL by keeping only
// scheme://host and appending /v1/models.
// NOTE(review): this drops any base path in the channel URL (e.g. the
// default ".../hf" base would become host-root /v1/models) — confirm
// whether path-prefixed channels need the prefix preserved.
func (s *Server) forwardModelsRequest(w http.ResponseWriter, ctx context.Context, req *ChatCompletionRequest, targetChannel Channel) {
	logger := NewRequestLogger(s.config)
	if s.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("/v1/models Request", req, s.config.Global.Log.Debug.MaxContentLength)
	}
	fullChatURL := targetChannel.GetFullURL()
	parsedURL, err := url.Parse(fullChatURL)
	if err != nil {
		http.Error(w, "Failed to parse channel URL", http.StatusInternalServerError)
		return
	}
	baseURL := parsedURL.Scheme + "://" + parsedURL.Host
	modelsURL := strings.TrimSuffix(baseURL, "/") + "/v1/models"

	client, err := createHTTPClient(targetChannel.Proxy, time.Duration(targetChannel.Timeout)*time.Second)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create HTTP client: %v", err), http.StatusInternalServerError)
		return
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)

	resp, err := client.Do(httpReq)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}
	if s.config.Global.Log.Debug.PrintResponse {
		logger.LogContent("/v1/models Response", string(respBody), s.config.Global.Log.Debug.MaxContentLength)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		http.Error(w, fmt.Sprintf("Target server error: %s", resp.Status), resp.StatusCode)
		return
	}

	// Mirror upstream headers, then relay status and body unchanged.
	for k, vals := range resp.Header {
		for _, v := range vals {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	w.Write(respBody)
}
|
|
|
|
|
|
|
|
|
// collectedReasoningBuffer accumulates reasoning text emitted by the
// thinking service's SSE stream so it can be replayed into the final
// request. Not safe for concurrent use; the stream loop is single-threaded.
type collectedReasoningBuffer struct {
	builder strings.Builder
	// mode records the thinking service mode that produced the content.
	mode string
}
|
|
|
// append accumulates one fragment of reasoning text.
func (cb *collectedReasoningBuffer) append(text string) {
	cb.builder.WriteString(text)
}
|
|
|
// get returns everything accumulated so far.
func (cb *collectedReasoningBuffer) get() string {
	return cb.builder.String()
}
|
|
|
// StreamHandler runs the streaming (SSE) variant of the pipeline: it
// relays the thinking service's stream to the client while collecting
// reasoning, then streams the target channel's response.
type StreamHandler struct {
	thinkingService ThinkingService
	targetChannel   Channel
	writer          http.ResponseWriter
	// flusher pushes each SSE line to the client immediately.
	flusher http.Flusher
	config  *Config
}
|
|
|
func NewStreamHandler(w http.ResponseWriter, thinkingService ThinkingService, targetChannel Channel, config *Config) (*StreamHandler, error) { |
|
flusher, ok := w.(http.Flusher) |
|
if !ok { |
|
return nil, fmt.Errorf("streaming not supported by this writer") |
|
} |
|
return &StreamHandler{ |
|
thinkingService: thinkingService, |
|
targetChannel: targetChannel, |
|
writer: w, |
|
flusher: flusher, |
|
config: config, |
|
}, nil |
|
} |
|
|
|
|
|
|
|
func (h *StreamHandler) HandleRequest(ctx context.Context, req *ChatCompletionRequest) error { |
|
h.writer.Header().Set("Content-Type", "text/event-stream") |
|
h.writer.Header().Set("Cache-Control", "no-cache") |
|
h.writer.Header().Set("Connection", "keep-alive") |
|
|
|
logger := NewRequestLogger(h.config) |
|
reasonBuf := &collectedReasoningBuffer{mode: h.thinkingService.Mode} |
|
|
|
|
|
if err := h.streamThinkingService(ctx, req, reasonBuf, logger); err != nil { |
|
return err |
|
} |
|
|
|
|
|
finalReq := h.prepareFinalRequest(req, reasonBuf.get()) |
|
if err := h.streamFinalChannel(ctx, finalReq, logger); err != nil { |
|
return err |
|
} |
|
return nil |
|
} |
|
|
|
|
|
|
|
// streamThinkingService opens an SSE stream against the thinking service,
// forwards reasoning chunks to the client as they arrive, and accumulates
// them in reasonBuf for the final request. In "full" mode every delta
// (content and reasoning_content) counts as reasoning; otherwise only
// reasoning_content is collected and the stream is cut short ("forceStop")
// as soon as regular content begins or a finish_reason is seen.
func (h *StreamHandler) streamThinkingService(ctx context.Context, req *ChatCompletionRequest, reasonBuf *collectedReasoningBuffer, logger *RequestLogger) error {
	// Re-target a copy of the request at the thinking service.
	thinkingReq := *req
	thinkingReq.Model = h.thinkingService.Model
	thinkingReq.APIKey = h.thinkingService.APIKey

	var systemPrompt string
	if h.thinkingService.Mode == "full" {
		systemPrompt = "Provide a detailed step-by-step analysis of the question. Your entire response will be used as reasoning and won't be shown to the user directly."
	} else {
		systemPrompt = "Please provide a detailed reasoning process for your response. Think step by step."
	}
	messages := append([]ChatCompletionMessage{
		{Role: "system", Content: systemPrompt},
	}, thinkingReq.Messages...)

	// Temperature defaults to 0.7 unless the service config overrides it.
	temp := 0.7
	if h.thinkingService.Temperature != nil {
		temp = *h.thinkingService.Temperature
	}
	bodyMap := map[string]interface{}{
		"model":       thinkingReq.Model,
		"messages":    messages,
		"temperature": temp,
		"stream":      true,
	}
	// Vendor-specific reasoning knobs, forwarded only when valid.
	if isValidReasoningEffort(h.thinkingService.ReasoningEffort) {
		bodyMap["reasoning_effort"] = h.thinkingService.ReasoningEffort
	}
	if isValidReasoningFormat(h.thinkingService.ReasoningFormat) {
		bodyMap["reasoning_format"] = h.thinkingService.ReasoningFormat
	}

	jsonData, err := json.Marshal(bodyMap)
	if err != nil {
		return err
	}

	client, err := createHTTPClient(h.thinkingService.Proxy, time.Duration(h.thinkingService.Timeout)*time.Second)
	if err != nil {
		return err
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.thinkingService.GetFullURL(), bytes.NewBuffer(jsonData))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+h.thinkingService.APIKey)

	log.Printf("Starting ThinkingService SSE => %s (mode=%s, force_stop=%v)", h.thinkingService.GetFullURL(), h.thinkingService.Mode, h.thinkingService.ForceStopDeepThinking)
	resp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		b, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("thinking service status=%d, body=%s", resp.StatusCode, string(b))
	}

	reader := bufio.NewReader(resp.Body)
	// lastLine is used to drop exact consecutive duplicates.
	var lastLine string

	// forceStop ends the loop early once reasoning is considered complete.
	forceStop := false

	for {
		// Honor request cancellation between reads.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		line, err := reader.ReadString('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return err
		}
		line = strings.TrimSpace(line)
		if line == "" || line == lastLine {
			continue
		}
		lastLine = line

		// Only SSE data lines are processed; other fields are ignored.
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		dataPart := strings.TrimPrefix(line, "data: ")
		if dataPart == "[DONE]" {
			break
		}

		var chunk struct {
			Choices []struct {
				Delta struct {
					Content          string `json:"content,omitempty"`
					ReasoningContent string `json:"reasoning_content,omitempty"`
				} `json:"delta"`
				FinishReason *string `json:"finish_reason,omitempty"`
			} `json:"choices"`
		}
		if err := json.Unmarshal([]byte(dataPart), &chunk); err != nil {
			// Unparseable chunk: pass it through to the client verbatim.
			h.writer.Write([]byte("data: " + dataPart + "\n\n"))
			h.flusher.Flush()
			continue
		}

		if len(chunk.Choices) > 0 {
			c := chunk.Choices[0]
			if h.config.Global.Log.Debug.PrintResponse {
				logger.LogContent("Thinking SSE chunk", chunk, h.config.Global.Log.Debug.MaxContentLength)
			}

			if h.thinkingService.Mode == "full" {
				// Full mode: everything counts as reasoning and is relayed.
				if c.Delta.ReasoningContent != "" {
					reasonBuf.append(c.Delta.ReasoningContent)
				}
				if c.Delta.Content != "" {
					reasonBuf.append(c.Delta.Content)
				}

				forwardLine := "data: " + dataPart + "\n\n"
				h.writer.Write([]byte(forwardLine))
				h.flusher.Flush()
			} else {
				// Standard mode: only reasoning_content is collected/relayed.
				if c.Delta.ReasoningContent != "" {
					reasonBuf.append(c.Delta.ReasoningContent)

					forwardLine := "data: " + dataPart + "\n\n"
					h.writer.Write([]byte(forwardLine))
					h.flusher.Flush()
				}

				// Regular answer content means reasoning has ended.
				if c.Delta.Content != "" && strings.TrimSpace(c.Delta.ReasoningContent) == "" {
					forceStop = true
				}
			}

			// Any finish_reason also terminates the thinking phase.
			if c.FinishReason != nil && *c.FinishReason != "" {
				forceStop = true
			}
		}

		if forceStop {
			break
		}
	}

	// Drain the remainder so the connection can be reused by the transport.
	io.Copy(io.Discard, reader)
	return nil
}
|
|
|
func (h *StreamHandler) prepareFinalRequest(originalReq *ChatCompletionRequest, reasoningCollected string) *ChatCompletionRequest { |
|
req := *originalReq |
|
var systemPrompt string |
|
if h.thinkingService.Mode == "full" { |
|
systemPrompt = fmt.Sprintf( |
|
`Consider the following detailed analysis (not shown to user): |
|
%s |
|
|
|
Provide a clear, concise response that incorporates insights from this analysis.`, |
|
reasoningCollected, |
|
) |
|
} else { |
|
systemPrompt = fmt.Sprintf( |
|
`Previous thinking process: |
|
%s |
|
Please consider the above thinking process in your response.`, |
|
reasoningCollected, |
|
) |
|
} |
|
req.Messages = append([]ChatCompletionMessage{ |
|
{Role: "system", Content: systemPrompt}, |
|
}, req.Messages...) |
|
return &req |
|
} |
|
|
|
// streamFinalChannel sends the reasoning-enhanced request to the target
// channel with stream semantics and relays its SSE data lines to the
// client verbatim, ending when "[DONE]" is seen or the stream closes.
func (h *StreamHandler) streamFinalChannel(ctx context.Context, req *ChatCompletionRequest, logger *RequestLogger) error {
	if h.config.Global.Log.Debug.PrintRequest {
		logger.LogContent("Final Request to Channel", req, h.config.Global.Log.Debug.MaxContentLength)
	}

	jsonData, err := json.Marshal(req)
	if err != nil {
		return err
	}
	client, err := createHTTPClient(h.targetChannel.Proxy, time.Duration(h.targetChannel.Timeout)*time.Second)
	if err != nil {
		return err
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, h.targetChannel.GetFullURL(), bytes.NewBuffer(jsonData))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	// The client's real upstream key, not the composite key.
	httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)

	resp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		b, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("target channel status=%d, body=%s", resp.StatusCode, string(b))
	}

	reader := bufio.NewReader(resp.Body)
	// lastLine is used to drop exact consecutive duplicates.
	var lastLine string

	for {
		// Honor request cancellation between reads.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		line, err := reader.ReadString('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return err
		}
		line = strings.TrimSpace(line)
		if line == "" || line == lastLine {
			continue
		}
		lastLine = line

		// Only SSE data lines are relayed.
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			doneLine := "data: [DONE]\n\n"
			h.writer.Write([]byte(doneLine))
			h.flusher.Flush()
			break
		}
		forwardLine := "data: " + data + "\n\n"
		h.writer.Write([]byte(forwardLine))
		h.flusher.Flush()

		if h.config.Global.Log.Debug.PrintResponse {
			logger.LogContent("Channel SSE chunk", forwardLine, h.config.Global.Log.Debug.MaxContentLength)
		}
	}
	// Drain the remainder so the connection can be reused by the transport.
	io.Copy(io.Discard, reader)
	return nil
}
|
|
|
|
|
|
|
func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) { |
|
transport := &http.Transport{ |
|
DialContext: (&net.Dialer{ |
|
Timeout: 30 * time.Second, |
|
KeepAlive: 30 * time.Second, |
|
}).DialContext, |
|
ForceAttemptHTTP2: true, |
|
MaxIdleConns: 100, |
|
IdleConnTimeout: 90 * time.Second, |
|
TLSHandshakeTimeout: 10 * time.Second, |
|
ExpectContinueTimeout: 1 * time.Second, |
|
} |
|
if proxyURL != "" { |
|
parsedURL, err := url.Parse(proxyURL) |
|
if err != nil { |
|
return nil, fmt.Errorf("invalid proxy URL: %v", err) |
|
} |
|
switch parsedURL.Scheme { |
|
case "http", "https": |
|
transport.Proxy = http.ProxyURL(parsedURL) |
|
case "socks5": |
|
dialer, err := proxy.FromURL(parsedURL, proxy.Direct) |
|
if err != nil { |
|
return nil, fmt.Errorf("failed to create SOCKS5 dialer: %v", err) |
|
} |
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { |
|
return dialer.Dial(network, addr) |
|
} |
|
default: |
|
return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) |
|
} |
|
} |
|
return &http.Client{ |
|
Transport: transport, |
|
Timeout: timeout, |
|
}, nil |
|
} |
|
|
|
func maskSensitiveHeaders(headers http.Header) http.Header { |
|
masked := make(http.Header) |
|
for k, vals := range headers { |
|
if strings.ToLower(k) == "authorization" { |
|
masked[k] = []string{"Bearer ****"} |
|
} else { |
|
masked[k] = vals |
|
} |
|
} |
|
return masked |
|
} |
|
|
|
// isValidReasoningEffort reports whether effort is one of the accepted
// values (low, medium, high), case-insensitively.
func isValidReasoningEffort(effort string) bool {
	allowed := map[string]bool{
		"low":    true,
		"medium": true,
		"high":   true,
	}
	return allowed[strings.ToLower(effort)]
}
|
|
|
// isValidReasoningFormat reports whether format is one of the accepted
// values (parsed, raw, hidden), case-insensitively.
func isValidReasoningFormat(format string) bool {
	allowed := map[string]bool{
		"parsed": true,
		"raw":    true,
		"hidden": true,
	}
	return allowed[strings.ToLower(format)]
}
|
|
|
|
|
|
|
|
|
func getEnvOrDefault(key, defaultValue string) string { |
|
if value := os.Getenv(key); value != "" { |
|
return value |
|
} |
|
return defaultValue |
|
} |
|
|
|
func getEnvIntOrDefault(key string, defaultValue int) int { |
|
if value := os.Getenv(key); value != "" { |
|
if intValue, err := strconv.Atoi(value); err == nil { |
|
return intValue |
|
} |
|
} |
|
return defaultValue |
|
} |
|
|
|
func getEnvBoolOrDefault(key string, defaultValue bool) bool { |
|
if value := os.Getenv(key); value != "" { |
|
if boolValue, err := strconv.ParseBool(value); err == nil { |
|
return boolValue |
|
} |
|
} |
|
return defaultValue |
|
} |
|
|
|
func getEnvFloatPtr(key string, defaultValue float64) *float64 { |
|
if value := os.Getenv(key); value != "" { |
|
if floatValue, err := strconv.ParseFloat(value, 64); err == nil { |
|
return &floatValue |
|
} |
|
} |
|
return &defaultValue |
|
} |
|
|
|
|
|
func loadConfig() (*Config, error) { |
|
|
|
if envConfig := os.Getenv("CONFIG_JSON"); envConfig != "" { |
|
var cfg Config |
|
if err := json.Unmarshal([]byte(envConfig), &cfg); err == nil { |
|
if err := validateConfig(&cfg); err == nil { |
|
log.Println("Using configuration from CONFIG_JSON environment variable") |
|
return &cfg, nil |
|
} |
|
} |
|
} |
|
|
|
|
|
log.Println("Building configuration from individual environment variables") |
|
cfg := &Config{ |
|
ThinkingServices: []ThinkingService{ |
|
{ |
|
ID: 1, |
|
Name: getEnvOrDefault("THINKING_SERVICE_NAME", "modelscope-deepseek-thinking"), |
|
Mode: getEnvOrDefault("THINKING_SERVICE_MODE", "standard"), |
|
Model: getEnvOrDefault("THINKING_SERVICE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"), |
|
BaseURL: getEnvOrDefault("THINKING_SERVICE_BASE_URL", "https://api-inference.modelscope.cn"), |
|
APIPath: getEnvOrDefault("THINKING_SERVICE_API_PATH", "/v1/chat/completions"), |
|
APIKey: os.Getenv("THINKING_SERVICE_API_KEY"), |
|
Timeout: getEnvIntOrDefault("THINKING_SERVICE_TIMEOUT", 6000), |
|
Weight: getEnvIntOrDefault("THINKING_SERVICE_WEIGHT", 100), |
|
Proxy: getEnvOrDefault("THINKING_SERVICE_PROXY", ""), |
|
ReasoningEffort: getEnvOrDefault("THINKING_SERVICE_REASONING_EFFORT", "high"), |
|
ReasoningFormat: getEnvOrDefault("THINKING_SERVICE_REASONING_FORMAT", "parsed"), |
|
Temperature: getEnvFloatPtr("THINKING_SERVICE_TEMPERATURE", 0.8), |
|
ForceStopDeepThinking: getEnvBoolOrDefault("THINKING_SERVICE_FORCE_STOP", false), |
|
}, |
|
}, |
|
Channels: map[string]Channel{ |
|
"1": { |
|
Name: getEnvOrDefault("CHANNEL_NAME", "Gemini-channel"), |
|
BaseURL: getEnvOrDefault("CHANNEL_BASE_URL", "https://Richardlsr-gemini-balance.hf.space/hf"), |
|
APIPath: getEnvOrDefault("CHANNEL_API_PATH", "/v1/chat/completions"), |
|
Timeout: getEnvIntOrDefault("CHANNEL_TIMEOUT", 600), |
|
Proxy: getEnvOrDefault("CHANNEL_PROXY", ""), |
|
}, |
|
}, |
|
Global: GlobalConfig{ |
|
Server: ServerConfig{ |
|
Port: getEnvIntOrDefault("SERVER_PORT", 7860), |
|
Host: getEnvOrDefault("SERVER_HOST", "0.0.0.0"), |
|
ReadTimeout: getEnvIntOrDefault("SERVER_READ_TIMEOUT", 600), |
|
WriteTimeout: getEnvIntOrDefault("SERVER_WRITE_TIMEOUT", 600), |
|
IdleTimeout: getEnvIntOrDefault("SERVER_IDLE_TIMEOUT", 600), |
|
}, |
|
Log: LogConfig{ |
|
Level: getEnvOrDefault("LOG_LEVEL", "info"), |
|
Format: getEnvOrDefault("LOG_FORMAT", "text"), |
|
Output: getEnvOrDefault("LOG_OUTPUT", "console"), |
|
FilePath: getEnvOrDefault("LOG_FILE_PATH", "./logs/deepai.log"), |
|
Debug: DebugConfig{ |
|
Enabled: getEnvBoolOrDefault("LOG_DEBUG_ENABLED", true), |
|
PrintRequest: getEnvBoolOrDefault("LOG_PRINT_REQUEST", true), |
|
PrintResponse: getEnvBoolOrDefault("LOG_PRINT_RESPONSE", true), |
|
MaxContentLength: getEnvIntOrDefault("LOG_MAX_CONTENT_LENGTH", 1000), |
|
}, |
|
}, |
|
}, |
|
} |
|
|
|
|
|
if cfg.ThinkingServices[0].APIKey == "" { |
|
return nil, fmt.Errorf("THINKING_SERVICE_API_KEY environment variable is required") |
|
} |
|
|
|
if err := validateConfig(cfg); err != nil { |
|
return nil, fmt.Errorf("config validation error: %v", err) |
|
} |
|
|
|
return cfg, nil |
|
} |
|
|
|
func validateConfig(config *Config) error { |
|
if len(config.ThinkingServices) == 0 { |
|
return fmt.Errorf("no thinking services configured") |
|
} |
|
if len(config.Channels) == 0 { |
|
return fmt.Errorf("no channels configured") |
|
} |
|
for i, svc := range config.ThinkingServices { |
|
if svc.BaseURL == "" { |
|
return fmt.Errorf("thinking service %s has empty baseURL", svc.Name) |
|
} |
|
if svc.APIKey == "" { |
|
return fmt.Errorf("thinking service %s has empty apiKey", svc.Name) |
|
} |
|
if svc.Timeout <= 0 { |
|
return fmt.Errorf("thinking service %s has invalid timeout", svc.Name) |
|
} |
|
if svc.Model == "" { |
|
return fmt.Errorf("thinking service %s has empty model", svc.Name) |
|
} |
|
if svc.Mode == "" { |
|
config.ThinkingServices[i].Mode = "standard" |
|
} else if svc.Mode != "standard" && svc.Mode != "full" { |
|
return fmt.Errorf("thinking service %s unknown mode=%s", svc.Name, svc.Mode) |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
func main() { |
|
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile) |
|
|
|
cfg, err := loadConfig() |
|
if err != nil { |
|
log.Fatalf("Failed to load config: %v", err) |
|
} |
|
log.Printf("Using config file: %s", viper.ConfigFileUsed()) |
|
|
|
server := NewServer(cfg) |
|
|
|
done := make(chan os.Signal, 1) |
|
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) |
|
|
|
go func() { |
|
if err := server.Start(); err != nil && err != http.ErrServerClosed { |
|
log.Fatalf("start server error: %v", err) |
|
} |
|
}() |
|
log.Printf("Server started successfully") |
|
|
|
<-done |
|
log.Print("Server stopping...") |
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
|
defer cancel() |
|
|
|
if err := server.Shutdown(ctx); err != nil { |
|
log.Printf("Server forced to shutdown: %v", err) |
|
} |
|
log.Print("Server stopped") |
|
} |