from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, Union
import requests
import json
import logging

app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")

# TypeGPT API settings
API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "pixtral-large-latest"

# Load model -> system prompt mappings
with open("model_map.json", "r") as f:
    MODEL_PROMPTS = json.load(f)
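
# Illustrative example of the expected model_map.json contents (assumption;
# the real file is not shown here). Keys are the model names clients request,
# values are the system prompts enforced for them:
#
#   {
#       "gpt-4o": "You are a concise, expert assistant.",
#       "creative-writer": "You are an imaginative creative-writing assistant."
#   }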

# Request schema
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0

# Construct payload with enforced system prompt
def build_payload(chat: ChatRequest):
    # Use internal system prompt
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")

    # Remove any user-provided system messages
    filtered_messages = [msg for msg in chat.messages if msg.role != "system"]

    # Insert enforced system prompt
    payload_messages = [{"role": "system", "content": system_prompt}] + [
        {"role": msg.role, "content": msg.content} for msg in filtered_messages
    ]

    return {
        "model": BACKEND_MODEL,
        "messages": payload_messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty
    }

# Streaming response handler
def stream_generator(requested_model: str, payload: dict, headers: dict):
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=True):
            if not line:
                continue
            if line.startswith("data:"):
                content = line[len("data:"):].strip()
                if content == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue
                try:
                    json_obj = json.loads(content)
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    yield f"data: {json.dumps(json_obj)}\n\n"
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
            else:
                logger.debug("Non-data stream line skipped: %s", line)
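
# For reference, an upstream SSE line typically looks like the following
# (illustrative only; the exact fields depend on the backend):
#
#   data: {"id": "chatcmpl-abc123", "object": "chat.completion.chunk",
#          "model": "pixtral-large-latest",
#          "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}
#
# stream_generator rewrites the "model" field before forwarding each chunk.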

# Proxy endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
        payload = build_payload(chat_request)

        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }

        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream"
            )
        else:
            response = requests.post(API_URL, headers=headers, json=payload)
            data = response.json()
            if "model" in data and data["model"] == BACKEND_MODEL:
                data["model"] = chat_request.model
            return JSONResponse(content=data)

    except Exception as e:
        logger.error("Error in /v1/chat/completions: %s", str(e))
        return JSONResponse(content={"error": "Internal server error."}, status_code=500)
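
# Minimal local run sketch (assumes uvicorn is installed; host and port are
# arbitrary examples, adjust as needed):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call against the proxy (illustrative; "gpt-4o" must exist as
# a key in model_map.json):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}], "stream": false}'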