from datetime import datetime
import json
from typing import Any, Dict, Optional
import uuid
import re

import httpx
from fastapi import HTTPException

from api.config import (
    MODEL_MAPPING,
    USER_SELECTED_MODEL,
    MODEL_PREFIXES,
    MODEL_REFERERS,
    MODEL_ALIASES,
    headers,
    AGENT_MODE,
    TRENDING_AGENT_MODE,
    BASE_URL,
)
from api.models import ChatRequest
from api.logger import setup_logger

logger = setup_logger(__name__)


def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build a single OpenAI-style `chat.completion.chunk` payload."""
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
    }


def message_to_dict(message):
    """Convert a request message to a plain dict, extracting image data if present."""
    if isinstance(message.content, str):
        return {"role": message.role, "content": message.content}
    elif isinstance(message.content, list) and len(message.content) == 2:
        return {
            "role": message.role,
            "content": message.content[0]["text"],
            "data": {
                "imageBase64": message.content[1]["image_url"]["url"],
                "fileText": "",
                "title": "snapshot",
            },
        }
    else:
        return {"role": message.role, "content": message.content}


def get_full_model_name(model: str) -> str:
    """Resolve a model alias to its full model name."""
    return MODEL_ALIASES.get(model, model)


def get_model_prefix(model: str) -> str:
    """Return the prompt prefix configured for the model, if any."""
    return MODEL_PREFIXES.get(model, "")


def get_referer_url(model: str) -> str:
    """Build the referer URL associated with the model."""
    referer_path = MODEL_REFERERS.get(model, f"/?model={model}")
    return f"{BASE_URL}{referer_path}"


async def process_streaming_response(request: ChatRequest):
    """Stream the upstream response as OpenAI-compatible server-sent events."""
    model = get_full_model_name(request.model)
    agent_mode = AGENT_MODE.get(model, {})
    trending_agent_mode = TRENDING_AGENT_MODE.get(model, {})
    prefix = get_model_prefix(model)

    # Construct formatted prompt
    formatted_prompt = ""
    for msg in request.messages:
        role = msg.role.capitalize()
        content = msg.content
        if isinstance(content, list) and len(content) == 2:
            # Handle image content
            content = f"FILE:BB\n$#$\n\n$#$\n{msg.content[0]['text']}"
        if role and content:
            formatted_prompt += f"{role}: {content}\n"
    if prefix:
        formatted_prompt = f"{prefix} {formatted_prompt}".strip()

    json_data = {
        "messages": [
            {
                "role": msg.role,
                "content": msg.content[0]["text"]
                if isinstance(msg.content, list)
                else msg.content,
                "data": msg.content[1]["image_url"]["url"]
                if isinstance(msg.content, list) and len(msg.content) == 2
                else None,
            }
            for msg in request.messages
        ],
        "previewToken": None,
        "userId": None,
        "codeModelMode": True,
        "agentMode": agent_mode,
        "trendingAgentMode": trending_agent_mode,
        "isMicMode": False,
        "userSystemPrompt": None,
        "maxTokens": request.max_tokens,
        "playgroundTopP": request.top_p,
        "playgroundTemperature": request.temperature,
        "isChromeExt": False,
        "githubToken": None,
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "visitFromDelta": False,
        "mobileClient": False,
        "userSelectedModel": USER_SELECTED_MODEL.get(model, model),
    }

    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                "POST",
                f"{BASE_URL}/api/chat",
                headers=headers,
                json=json_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    timestamp = int(datetime.now().timestamp())
                    if line:
                        content = line
                        # Clean the response if necessary
                        if content.startswith("$@$v=undefined-rv1$@$"):
                            content = content[21:]
                        yield f"data: {json.dumps(create_chat_completion_data(content, model, timestamp))}\n\n"

                # Indicate the end of the stream
                timestamp = int(datetime.now().timestamp())
                yield f"data: {json.dumps(create_chat_completion_data('', model, timestamp, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request: {e}")
            raise HTTPException(status_code=500, detail=str(e))


async def process_non_streaming_response(request: ChatRequest):
    """Return the upstream response as a single OpenAI-compatible completion."""
    model = get_full_model_name(request.model)
    agent_mode = AGENT_MODE.get(model, {})
    trending_agent_mode = TRENDING_AGENT_MODE.get(model, {})
    prefix = get_model_prefix(model)

    # Construct formatted prompt
    formatted_prompt = ""
    for msg in request.messages:
        role = msg.role.capitalize()
        content = msg.content
        if isinstance(content, list) and len(content) == 2:
            # Handle image content
            content = f"FILE:BB\n$#$\n\n$#$\n{msg.content[0]['text']}"
        if role and content:
            formatted_prompt += f"{role}: {content}\n"
    if prefix:
        formatted_prompt = f"{prefix} {formatted_prompt}".strip()

    json_data = {
        "messages": [
            {
                "role": msg.role,
                "content": msg.content[0]["text"]
                if isinstance(msg.content, list)
                else msg.content,
                "data": msg.content[1]["image_url"]["url"]
                if isinstance(msg.content, list) and len(msg.content) == 2
                else None,
            }
            for msg in request.messages
        ],
        "previewToken": None,
        "userId": None,
        "codeModelMode": True,
        "agentMode": agent_mode,
        "trendingAgentMode": trending_agent_mode,
        "isMicMode": False,
        "userSystemPrompt": None,
        "maxTokens": request.max_tokens,
        "playgroundTopP": request.top_p,
        "playgroundTemperature": request.temperature,
        "isChromeExt": False,
        "githubToken": None,
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "visitFromDelta": False,
        "mobileClient": False,
        "userSelectedModel": USER_SELECTED_MODEL.get(model, model),
    }

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                f"{BASE_URL}/api/chat",
                headers=headers,
                json=json_data,
                timeout=100,
            )
            response.raise_for_status()
            full_response = response.text

            # Clean the response if necessary
            if full_response.startswith("$@$v=undefined-rv1$@$"):
                full_response = full_response[21:]

            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(datetime.now().timestamp()),
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": full_response},
                        "finish_reason": "stop",
                    }
                ],
                "usage": None,
            }
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request: {e}")
            raise HTTPException(status_code=500, detail=str(e))