"""OpenAI-compatible chat-completions proxy that forwards requests to the
sider.ai text-completion API and re-emits the upstream stream as OpenAI-style
SSE chunks (``chat.completion.chunk``) or a single ``chat.completion`` object.
"""

import json
import logging
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import httpx
import uvicorn
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import Response, StreamingResponse

# Configure logging.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Load environment variables from a .env file, if present.
load_dotenv()

app = FastAPI()

# Configuration constants.
# NOTE(review): BASE_URL is defined but never referenced below — the upstream
# call hard-codes its own URL. Kept for backward compatibility.
BASE_URL = "https://aichatonlineorg.erweima.ai/aichatonline"
APP_SECRET = os.getenv("APP_SECRET", "666")
ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN", "")

# Request headers sent to the upstream sider.ai API. They imitate the official
# browser extension, including its chrome-extension origin and user agent.
headers = {
    'accept': '*/*',
    'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
    'authorization': f'Bearer {ACCESS_TOKEN}',
    'cache-control': 'no-cache',
    'origin': 'chrome-extension://dhoenijjpgpeimemopealfcbiecgceod',
    'pragma': 'no-cache',
    'priority': 'u=1, i',
    'sec-fetch-dest': 'empty',
    'sec-fetch-mode': 'cors',
    'sec-fetch-site': 'none',
    'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0',
}

# Models this proxy accepts; `id` is matched against the incoming request.
ALLOWED_MODELS = [
    {"id": "claude-3.5-sonnet", "name": "Claude 3.5 Sonnet"},
    {"id": "sider", "name": "Sider"},
    {"id": "gpt-4o-mini", "name": "GPT-4o Mini"},
    {"id": "claude-3-haiku", "name": "Claude 3 Haiku"},
    {"id": "claude-3.5-haiku", "name": "Claude 3.5 Haiku"},
    {"id": "gemini-1.5-flash", "name": "Gemini 1.5 Flash"},
    {"id": "llama-3", "name": "Llama 3"},
    {"id": "gpt-4o", "name": "GPT-4o"},
    {"id": "gemini-1.5-pro", "name": "Gemini 1.5 Pro"},
    {"id": "llama-3.1-405b", "name": "Llama 3.1 405b"},
]

# Configure CORS.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers per the CORS spec — restrict origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict this to specific origins in production
    allow_credentials=True,
    allow_methods=["*"],  # All methods allowed
    allow_headers=["*"],  # Allow all headers
)

# Security configuration: clients authenticate with a bearer token that must
# equal APP_SECRET.
security = HTTPBearer()


# Pydantic models describing the OpenAI-style request body.
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False


def create_chat_completion_data(
    content: str, model: str, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build a single OpenAI-style ``chat.completion.chunk`` payload.

    Args:
        content: Delta text for this chunk (may be empty for the stop chunk).
        model: Model id echoed back to the client.
        finish_reason: ``None`` for intermediate chunks, ``"stop"`` for the
            final one.

    Returns:
        A dict matching the OpenAI streaming-chunk schema.
    """
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
    }


def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: reject requests whose bearer token != APP_SECRET.

    Raises:
        HTTPException: 403 when the provided token does not match.
    """
    if credentials.credentials != APP_SECRET:
        # Deliberately do NOT log the supplied token — logging a credential
        # (even an invalid guess) leaks secrets into log storage.
        logger.warning("Invalid APP_SECRET provided.")
        raise HTTPException(status_code=403, detail="Invalid APP_SECRET")
    logger.info("APP_SECRET verified successfully.")
    return credentials.credentials


@app.options("/hf/v1/chat/completions")
async def chat_completions_options():
    """Answer CORS preflight requests for the completions endpoint."""
    return Response(
        status_code=200,
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "POST, OPTIONS",
            "Access-Control-Allow-Headers": "Content-Type, Authorization",
        },
    )


def replace_escaped_newlines(input_string: str) -> str:
    """Turn literal backslash-n sequences into real newlines.

    NOTE(review): currently unused by any endpoint; kept for compatibility.
    """
    return input_string.replace("\\n", "\n")


@app.get("/hf/v1/models")
async def list_models():
    """List the models this proxy accepts, in OpenAI list format."""
    return {"object": "list", "data": ALLOWED_MODELS}


@app.post("/hf/v1/chat/completions")
async def chat_completions(
    request: ChatRequest, app_secret: str = Depends(verify_app_secret)
):
    """Proxy an OpenAI-style chat-completion request to sider.ai.

    Streams SSE chunks when ``request.stream`` is true; otherwise drains the
    stream server-side and returns one aggregated ``chat.completion`` object.

    Raises:
        HTTPException: 400 for a disallowed model; upstream HTTP errors and
            connection failures are mapped to their status / 500.
    """
    logger.info(f"Received chat completion request for model: {request.model}")

    # Validate the selected model against the allow-list.
    if request.model not in [model['id'] for model in ALLOWED_MODELS]:
        allowed = ', '.join(model['id'] for model in ALLOWED_MODELS)
        logger.error(f"Model {request.model} is not allowed.")
        raise HTTPException(
            status_code=400,
            detail=f"Model '{request.model}' is not allowed. Allowed models are: {allowed}",
        )

    logger.info(f"Using model: {request.model}")

    # Generate a UUID.
    # NOTE(review): uuid_str is only debug-logged; it is never sent upstream
    # ('cid' below is empty). Kept to preserve existing log output.
    original_uuid = uuid.uuid4()
    uuid_str = str(original_uuid).replace("-", "")
    logger.debug(f"Generated UUID: {uuid_str}")

    # Payload for the external sider.ai API. The OpenAI message list is
    # flattened into a single "User:/Assistant:" transcript prompt.
    json_data = {
        'prompt': "\n".join(
            [
                f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
                for msg in request.messages
            ]
        ),
        'stream': True,
        'app_name': 'ChitChat_Edge_Ext',
        'app_version': '4.26.1',
        'tz_name': 'Asia/Karachi',
        'cid': '',
        'model': request.model,  # Use the selected model directly
        'search': False,
        'auto_search': False,
        'filter_search_history': False,
        'from': 'chat',
        'group_id': 'default',
        'chat_models': [request.model],  # Include the model in chat_models
        'files': [],
        'prompt_template': {
            'key': '',
            'attributes': {
                'lang': 'original',
            },
        },
        'tools': {
            'auto': [
                'search',
                'text_to_image',
                'data_analysis',
            ],
        },
        'extra_info': {
            'origin_url': '',
            'origin_title': '',
        },
    }
    logger.debug(f"JSON Data Sent to External API: {json.dumps(json_data, indent=2)}")

    async def generate():
        # Stream upstream SSE lines, re-wrap each text delta as an
        # OpenAI-style chunk, then emit a stop chunk and [DONE].
        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    'POST',
                    'https://sider.ai/api/v2/completion/text',
                    headers=headers,
                    json=json_data,
                    timeout=120.0,
                ) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        if line and ("[DONE]" not in line):
                            # Upstream lines look like 'data: {...}'.
                            if line.startswith("data: "):
                                json_line = line[6:]
                                if json_line.startswith("{"):
                                    try:
                                        data = json.loads(json_line)
                                        content = data.get("data", {}).get("text", "")
                                        logger.debug(f"Received content: {content}")
                                        yield f"data: {json.dumps(create_chat_completion_data(content, request.model))}\n\n"
                                    except json.JSONDecodeError as e:
                                        logger.error(f"JSON decode error: {e} - Line: {json_line}")
                # Send the stop signal
                yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error occurred: {e} - Response: {e.response.text}")
                # NOTE(review): raising inside a generator after streaming has
                # begun cannot change the HTTP status already sent; this only
                # works cleanly when the failure happens before the first yield.
                raise HTTPException(status_code=e.response.status_code, detail=str(e))
            except httpx.RequestError as e:
                logger.error(f"An error occurred while requesting: {e}")
                raise HTTPException(status_code=500, detail=str(e))

    if request.stream:
        logger.info("Streaming response initiated.")
        return StreamingResponse(generate(), media_type="text/event-stream")
    else:
        # Non-streaming: drain the generator ourselves and concatenate the
        # delta contents into one assistant message.
        logger.info("Non-streaming response initiated.")
        full_response = ""
        async for chunk in generate():
            if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
                try:
                    data = json.loads(chunk[6:])
                    if data["choices"][0]["delta"].get("content"):
                        full_response += data["choices"][0]["delta"]["content"]
                except json.JSONDecodeError:
                    logger.warning(f"Failed to decode JSON from chunk: {chunk}")

        # Final aggregated OpenAI-style response.
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": int(datetime.now().timestamp()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": full_response},
                    "finish_reason": "stop",
                }
            ],
            "usage": None,
        }


# Entry point
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)