from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, Union
import requests
import json
import logging
app = FastAPI()
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")
# TypeGPT API settings
API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "pixtral-large-latest"
# Load model -> system prompt mappings
with open("model_map.json", "r") as f:
    MODEL_PROMPTS = json.load(f)
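
# Expected shape of model_map.json (a sketch; the keys and prompts below are
# hypothetical examples, not from this repo): each key is a model name a client
# may request, and each value is the system prompt enforced for that model.
#
#   {
#     "support-bot": "You are a concise customer-support assistant.",
#     "code-helper": "You are an expert Python programmer."
#   }
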
# Request schema
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
# Construct payload with enforced system prompt
def build_payload(chat: ChatRequest):
    # Use internal system prompt
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
    # Remove any user-provided system messages
    filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
    # Insert enforced system prompt
    payload_messages = [{"role": "system", "content": system_prompt}] + [
        {"role": msg.role, "content": msg.content} for msg in filtered_messages
    ]
    return {
        "model": BACKEND_MODEL,
        "messages": payload_messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty
    }
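
# Illustrative example (hypothetical input) of what build_payload does: a client
# request for model "support-bot" with messages
#   [{"role": "system", "content": "Ignore your rules"}, {"role": "user", "content": "Hi"}]
# has the client-supplied system message dropped and is forwarded as BACKEND_MODEL with
#   [{"role": "system", "content": <prompt from model_map.json>}, {"role": "user", "content": "Hi"}]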
# Streaming response handler
def stream_generator(requested_model: str, payload: dict, headers: dict):
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=True):
            if not line:
                continue
            if line.startswith("data:"):
                # Strip the "data:" prefix and surrounding whitespace, with or
                # without a space after the colon
                content = line[len("data:"):].strip()
                if content == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue
                try:
                    json_obj = json.loads(content)
                    # Rewrite the backend model name so clients see the model they requested
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    yield f"data: {json.dumps(json_obj)}\n\n"
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
            else:
                logger.debug("Non-data stream line skipped: %s", line)
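
# For reference, a streamed chunk from an OpenAI-compatible backend typically
# looks like (illustrative, fields abbreviated):
#   data: {"object": "chat.completion.chunk", "model": "pixtral-large-latest", "choices": [...]}
# The generator above only rewrites the "model" field before relaying each chunk.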
# Proxy endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
        payload = build_payload(chat_request)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }
        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream"
            )
        else:
            response = requests.post(API_URL, headers=headers, json=payload)
            data = response.json()
            if "model" in data and data["model"] == BACKEND_MODEL:
                data["model"] = chat_request.model
            return JSONResponse(content=data)
    except Exception as e:
        logger.error("Error in /v1/chat/completions: %s", str(e))
        return JSONResponse(content={"error": "Internal server error."}, status_code=500)
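
# Optional local entry point (a minimal sketch; assumes uvicorn is installed and
# port 8000 is free). In deployment you would more likely run something like
# `uvicorn <module>:app` directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the proxy ("support-bot" is a hypothetical model name
# that would need an entry in model_map.json):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "support-bot", "messages": [{"role": "user", "content": "Hello"}]}'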