import json
import logging
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import httpx
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

load_dotenv()

app = FastAPI()

# APP_SECRET is loaded for future use; no auth check is currently enforced
# on the endpoint.
APP_SECRET = os.getenv("APP_SECRET", "666")
ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN", "")

# Headers sent with every upstream request to Sider.
headers = {
    'authorization': f'Bearer {ACCESS_TOKEN}',
    'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0',
}

ALLOWED_MODELS = [
    "claude-3.5-sonnet",
    "sider",
    "gpt-4o-mini",
    "claude-3-haiku",
    "claude-3.5-haiku",
    "gemini-1.5-flash",
    "llama-3",
    "gpt-4o",
    "gemini-1.5-pro",
    "llama-3.1-405b",
]


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False


def create_chat_completion_data(
    content: str, model: str, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build one OpenAI-style chat.completion.chunk object for SSE streaming."""
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
    }


@app.post("/hf/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    logger.info(f"Received chat completion request for model: {request.model}")

    if request.model not in ALLOWED_MODELS:
        logger.error(f"Model {request.model} is not allowed.")
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model} is not allowed. "
            f"Allowed models are: {', '.join(ALLOWED_MODELS)}",
        )

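    # The request body below sends a single 'prompt' string, so the message
    # list is flattened into plain text. For illustration, a history of
    #   user: "Hi" / assistant: "Hello!" / user: "How are you?"
    # becomes:
    #   User: Hi
    #   Assistant: Hello!
    #   User: How are you?
    # (Any role other than "user", including "system", is labeled
    # "Assistant", since only role == "user" is special-cased.)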
    json_data = {
        'prompt': "\n".join(
            f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
            for msg in request.messages
        ),
        'stream': True,
        'model': request.model,
    }

    logger.info(f"Sending request to external API with data: {json_data}")

    async def generate():
        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    'POST',
                    'https://sider.ai/api/v3/completion/text',
                    headers=headers,
                    json=json_data,
                    timeout=120.0,
                ) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        # Upstream emits SSE lines of the form "data:{json}";
                        # skip blanks, non-data lines, and the [DONE] sentinel.
                        if line.startswith("data:") and "[DONE]" not in line:
                            content = json.loads(line[5:])["data"]
                            yield f"data: {json.dumps(create_chat_completion_data(content.get('text', ''), request.model))}\n\n"
                    # Close the stream with an empty "stop" chunk and [DONE].
                    yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
                    yield "data: [DONE]\n\n"
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error occurred: {e}")
                # NOTE: once bytes have been streamed to the client, FastAPI
                # can no longer turn this into a proper HTTP error response;
                # the stream simply terminates.
                raise HTTPException(status_code=e.response.status_code, detail=str(e))
            except httpx.RequestError as e:
                logger.error(f"An error occurred while requesting: {e}")
                raise HTTPException(status_code=500, detail=str(e))

    if request.stream:
        logger.info("Streaming response")
        return StreamingResponse(generate(), media_type="text/event-stream")
    else:
        logger.info("Non-streaming response")
        # Drain the same generator and stitch the deltas into one message.
        full_response = ""
        async for chunk in generate():
            if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
                data = json.loads(chunk[6:])
                if data["choices"][0]["delta"].get("content"):
                    full_response += data["choices"][0]["delta"]["content"]
        logger.info(f"Full response generated: {full_response}")
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": int(datetime.now().timestamp()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": full_response},
                    "finish_reason": "stop",
                }
            ],
            "usage": None,
        }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
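
# --- Example client call (illustrative sketch, not part of the server) ---
# Assuming the server is running locally on port 7860 and SD_ACCESS_TOKEN is
# set so the upstream call can succeed, the proxy accepts any OpenAI-style
# payload. The model and message values below are placeholders; run from a
# separate script:
#
#   import httpx
#
#   resp = httpx.post(
#       "http://localhost:7860/hf/v1/chat/completions",
#       json={
#           "model": "gpt-4o-mini",
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "stream": False,
#       },
#       timeout=120.0,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
#
# With "stream": True the endpoint instead returns text/event-stream chunks
# in the chat.completion.chunk format built by create_chat_completion_data.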