import re
import random
import string
import uuid
import json
import logging
import asyncio
import base64
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator

from aiohttp import ClientSession, ClientTimeout, ClientError
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
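
# FastAPI service exposing an OpenAI-style chat-completions API backed by the
# Blackbox chat endpoint, with SSE streaming, optional web-search result
# formatting, and base64 image input/output support.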

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


class ModelNotWorkingException(Exception):
    """Raised when a requested model is unsupported or currently broken."""

    def __init__(self, model: str):
        self.model = model
        self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
        super().__init__(self.message)


class ImageResponse:
    """Carrier for a generated image, stored as a base64 data URI."""

    def __init__(self, data_uri: str, alt: str):
        self.data_uri = data_uri
        self.alt = alt


def to_data_uri(image: bytes, mime_type: str = "image/png") -> str:
    """Encode raw image bytes as a base64 data URI."""
    encoded = base64.b64encode(image).decode('utf-8')
    return f"data:{mime_type};base64,{encoded}"


def decode_base64_image(data_uri: str) -> bytes:
    """Decode a base64 data URI (or a bare base64 string) into raw bytes."""
    try:
        # Strip the "data:<mime>;base64," header when present.
        encoded = data_uri.split(",", 1)[1] if "," in data_uri else data_uri
        return base64.b64decode(encoded)
    except Exception as e:
        logger.error(f"Error decoding base64 image: {e}")
        raise


class Blackbox:
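    # NOTE: the class attributes and get_model() below are referenced by
    # create_async_generator() and the API routes but were missing from this
    # file; the values here are illustrative placeholders (assumptions), not
    # the real Blackbox configuration.
    working: bool = True
    api_endpoint: str = "https://www.blackbox.ai/api/chat"  # assumed endpoint
    models: List[str] = ["blackbox", "ImageGeneration"]     # assumed model ids
    model_aliases: Dict[str, str] = {}        # alias -> canonical model id
    model_prefixes: Dict[str, str] = {}       # model id -> required prompt prefix
    agentMode: Dict[str, Dict[str, Any]] = {}
    trendingAgentMode: Dict[str, Dict[str, Any]] = {}
    userSelectedModel: Dict[str, str] = {}

    @classmethod
    def get_model(cls, model: str) -> str:
        # Resolve an alias to its canonical model id; fall back to the input.
        return cls.model_aliases.get(model, model)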

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: List[Dict[str, str]],
        proxy: Optional[str] = None,
        image: Optional[str] = None,
        image_name: Optional[str] = None,
        webSearchMode: bool = False,
        **kwargs
    ) -> AsyncGenerator[Any, None]:
        model = cls.get_model(model)
        logger.info(f"Selected model: {model}")

        if not cls.working or model not in cls.models:
            logger.error(f"Model {model} is not working or not supported.")
            raise ModelNotWorkingException(model)

        # Headers for the upstream request; the original values were elided
        # here. aiohttp sets the JSON content type automatically for `json=`.
        headers: Dict[str, str] = {}

        if model in cls.model_prefixes:
            prefix = cls.model_prefixes[model]
            if not messages[0]['content'].startswith(prefix):
                logger.debug(f"Adding prefix '{prefix}' to the first message.")
                messages[0]['content'] = f"{prefix} {messages[0]['content']}"

        # Tag the last message with a random 7-character id and force the
        # 'user' role; the same id is reused as the request id below.
        random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
        messages[-1]['id'] = random_id
        messages[-1]['role'] = 'user'

        if image is not None:
            try:
                image_bytes = decode_base64_image(image)
                data_uri = to_data_uri(image_bytes)
                messages[-1]['data'] = {
                    'fileText': '',
                    'imageBase64': data_uri,
                    'title': image_name
                }
                # 'FILE:BB' plus the '$#$' sentinels flag an attached file in
                # the message content sent upstream.
                messages[-1]['content'] = 'FILE:BB\n$#$\n\n$#$\n' + messages[-1]['content']
                logger.debug("Image data added to the message.")
            except Exception as e:
                logger.error(f"Failed to decode base64 image: {e}")
                raise HTTPException(status_code=400, detail="Invalid image data provided.")

        data = {
            "messages": messages,
            "id": random_id,
            "previewToken": None,
            "userId": None,
            "codeModelMode": True,
            "agentMode": {},
            "trendingAgentMode": {},
            "isMicMode": False,
            "userSystemPrompt": None,
            "maxTokens": 99999999,
            "playgroundTopP": 0.9,
            "playgroundTemperature": 0.5,
            "isChromeExt": False,
            "githubToken": None,
            "clickedAnswer2": False,
            "clickedAnswer3": False,
            "clickedForceWebSearch": False,
            "visitFromDelta": False,
            "mobileClient": False,
            "userSelectedModel": None,
            "webSearchMode": webSearchMode,
        }
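
        # Per-model agent settings take precedence, then trending-agent
        # settings, then an explicit user-selected model.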
        if model in cls.agentMode:
            data["agentMode"] = cls.agentMode[model]
        elif model in cls.trendingAgentMode:
            data["trendingAgentMode"] = cls.trendingAgentMode[model]
        elif model in cls.userSelectedModel:
            data["userSelectedModel"] = cls.userSelectedModel[model]
        logger.info(f"Sending request to {cls.api_endpoint} with data: {data}")

        timeout = ClientTimeout(total=60)
        retry_attempts = 10

        # Retry transient failures with the same payload; convert to HTTP
        # errors once the attempts are exhausted.
        for attempt in range(retry_attempts):
            try:
                async with ClientSession(headers=headers, timeout=timeout) as session:
                    async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
                        response.raise_for_status()
                        logger.info(f"Received response with status {response.status}")
                        if model == 'ImageGeneration':
                            # Image models return a URL to the generated image;
                            # fetch it and re-encode as a data URI.
                            response_text = await response.text()
                            url_match = re.search(r'https://storage\.googleapis\.com/[^\s\)]+', response_text)
                            if url_match:
                                image_url = url_match.group(0)
                                logger.info(f"Image URL found: {image_url}")
                                async with session.get(image_url) as img_response:
                                    img_response.raise_for_status()
                                    image_bytes = await img_response.read()
                                data_uri = to_data_uri(image_bytes)
                                logger.info("Image converted to base64 data URI.")
                                yield ImageResponse(data_uri, alt=messages[-1]['content'])
                            else:
                                logger.error("Image URL not found in the response.")
                                raise Exception("Image URL not found in the response")
                        else:
                            full_response = ""
                            search_results_json = ""
                            try:
                                async for chunk, _ in response.content.iter_chunks():
                                    if chunk:
                                        decoded_chunk = chunk.decode(errors='ignore')
                                        # Drop the version marker embedded in the stream.
                                        decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
                                        if decoded_chunk.strip():
                                            if '$~~~$' in decoded_chunk:
                                                # Chunks between $~~~$ markers carry web-search results.
                                                search_results_json += decoded_chunk
                                            else:
                                                full_response += decoded_chunk
                                                yield decoded_chunk
                                logger.info("Finished streaming response chunks.")
                            except Exception:
                                logger.exception("Error while iterating over response chunks.")
                                raise
                            if data["webSearchMode"] and search_results_json:
                                match = re.search(r'\$~~~\$(.*?)\$~~~\$', search_results_json, re.DOTALL)
                                if match:
                                    try:
                                        search_results = json.loads(match.group(1))
                                        formatted_results = "\n\n**Sources:**\n"
                                        for i, result in enumerate(search_results[:5], 1):
                                            formatted_results += f"{i}. [{result['title']}]({result['link']})\n"
                                        logger.info("Formatted search results.")
                                        yield formatted_results
                                    except json.JSONDecodeError:
                                        logger.error("Failed to parse search results JSON.")
                                        raise
                break
            except ClientError as ce:
                logger.error(f"Client error occurred: {ce}. Retrying attempt {attempt + 1}/{retry_attempts}")
                if attempt == retry_attempts - 1:
                    raise HTTPException(status_code=502, detail="Error communicating with the external API. | NiansuhAI")
            except asyncio.TimeoutError:
                logger.error(f"Request timed out. Retrying attempt {attempt + 1}/{retry_attempts}")
                if attempt == retry_attempts - 1:
                    raise HTTPException(status_code=504, detail="External API request timed out. | NiansuhAI")
            except Exception as e:
                logger.error(f"Unexpected error: {e}. Retrying attempt {attempt + 1}/{retry_attempts}")
                if attempt == retry_attempts - 1:
                    raise HTTPException(status_code=500, detail=str(e))


app = FastAPI()


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    webSearchMode: Optional[bool] = False
    image: Optional[str] = None  # optional base64-encoded image attachment


def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
    """Build an OpenAI-style chat.completion.chunk payload for SSE streaming."""
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
    }
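
# Each SSE line emitted by the streaming endpoint below then looks like
#   data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", ...}
# terminated by a final "data: [DONE]" line.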


@app.post("/niansuhai/v1/chat/completions")
async def chat_completions(request: ChatRequest, req: Request):
    logger.info(f"Received chat completions request: {request}")
    try:
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        async_generator = Blackbox.create_async_generator(
            model=request.model,
            messages=messages,
            proxy=None,
            image=request.image,
            image_name=None,
            webSearchMode=request.webSearchMode
        )

        if request.stream:
            async def generate():
                try:
                    async for chunk in async_generator:
                        if isinstance(chunk, ImageResponse):
                            # Embed generated images as inline markdown.
                            image_markdown = f"![{chunk.alt}]({chunk.data_uri})"
                            response_chunk = create_response(image_markdown, request.model)
                        else:
                            response_chunk = create_response(chunk, request.model)
                        yield f"data: {json.dumps(response_chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                except HTTPException as he:
                    error_response = {"error": he.detail}
                    yield f"data: {json.dumps(error_response)}\n\n"
                except Exception as e:
                    logger.exception("Error during streaming response generation.")
                    error_response = {"error": str(e)}
                    yield f"data: {json.dumps(error_response)}\n\n"

            return StreamingResponse(generate(), media_type="text/event-stream")
        else:
            response_content = ""
            async for chunk in async_generator:
                if isinstance(chunk, ImageResponse):
                    response_content += f"![{chunk.alt}]({chunk.data_uri})\n"
                else:
                    response_content += chunk

            logger.info("Completed non-streaming response generation.")
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(datetime.now().timestamp()),
                "model": request.model,
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": response_content
                        },
                        "finish_reason": "stop",
                        "index": 0
                    }
                ],
                # Rough whitespace-split word counts, not real token counts.
                "usage": {
                    "prompt_tokens": sum(len(msg['content'].split()) for msg in messages),
                    "completion_tokens": len(response_content.split()),
                    "total_tokens": sum(len(msg['content'].split()) for msg in messages) + len(response_content.split())
                },
            }
    except ModelNotWorkingException as e:
        logger.warning(f"Model not working: {e}")
        raise HTTPException(status_code=503, detail=str(e))
    except HTTPException as he:
        logger.warning(f"HTTPException: {he.detail}")
        raise
    except Exception as e:
        logger.exception("An unexpected error occurred while processing the chat completions request.")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/niansuhai/v1/models")
async def get_models():
    logger.info("Fetching available models.")
    return {"data": [{"id": model} for model in Blackbox.models]}


@app.get("/niansuhai/v1/health")
async def health_check():
    """Health check endpoint to verify the service is running."""
    return {"status": "ok"}


@app.get("/niansuhai/v1/models/{model}/status")
async def model_status(model: str):
    """Check if a specific model is available."""
    if model in Blackbox.models:
        return {"model": model, "status": "available"}
    elif model in Blackbox.model_aliases:
        actual_model = Blackbox.model_aliases[model]
        return {"model": actual_model, "status": "available via alias"}
    else:
        raise HTTPException(status_code=404, detail="Model not found")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)