Update main.py
Browse files
main.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import re
|
| 2 |
import random
|
| 3 |
import string
|
|
@@ -5,13 +6,14 @@ import uuid
|
|
| 5 |
import json
|
| 6 |
import logging
|
| 7 |
import asyncio
|
| 8 |
-
import
|
| 9 |
-
from
|
| 10 |
-
from fastapi import FastAPI, HTTPException, Request
|
| 11 |
-
from pydantic import BaseModel
|
| 12 |
from typing import List, Dict, Any, Optional, AsyncGenerator
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
from fastapi.responses import StreamingResponse
|
|
|
|
| 15 |
|
| 16 |
# Configure logging
|
| 17 |
logging.basicConfig(
|
|
@@ -23,6 +25,47 @@ logging.basicConfig(
|
|
| 23 |
)
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# Custom exception for model not working
|
| 27 |
class ModelNotWorkingException(Exception):
|
| 28 |
def __init__(self, model: str):
|
|
@@ -30,23 +73,14 @@ class ModelNotWorkingException(Exception):
|
|
| 30 |
self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
|
| 31 |
super().__init__(self.message)
|
| 32 |
|
| 33 |
-
#
|
| 34 |
class ImageResponse:
|
| 35 |
-
def __init__(self,
|
| 36 |
-
self.
|
| 37 |
self.alt = alt
|
| 38 |
|
| 39 |
-
def to_data_uri(image:
|
| 40 |
-
|
| 41 |
-
return f"data:{mime_type};base64,{encoded}"
|
| 42 |
-
|
| 43 |
-
def decode_base64_image(data_uri: str) -> bytes:
|
| 44 |
-
try:
|
| 45 |
-
header, encoded = data_uri.split(",", 1)
|
| 46 |
-
return base64.b64decode(encoded)
|
| 47 |
-
except Exception as e:
|
| 48 |
-
logger.error(f"Error decoding base64 image: {e}")
|
| 49 |
-
raise e
|
| 50 |
|
| 51 |
class Blackbox:
|
| 52 |
url = "https://www.blackbox.ai"
|
|
@@ -158,7 +192,7 @@ class Blackbox:
|
|
| 158 |
if model in cls.models:
|
| 159 |
return model
|
| 160 |
elif model in cls.userSelectedModel:
|
| 161 |
-
return
|
| 162 |
elif model in cls.model_aliases:
|
| 163 |
return cls.model_aliases[model]
|
| 164 |
else:
|
|
@@ -168,9 +202,9 @@ class Blackbox:
|
|
| 168 |
async def create_async_generator(
|
| 169 |
cls,
|
| 170 |
model: str,
|
| 171 |
-
messages: List[Dict[str,
|
| 172 |
proxy: Optional[str] = None,
|
| 173 |
-
image:
|
| 174 |
image_name: Optional[str] = None,
|
| 175 |
webSearchMode: bool = False,
|
| 176 |
**kwargs
|
|
@@ -200,39 +234,24 @@ class Blackbox:
|
|
| 200 |
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
|
| 201 |
}
|
| 202 |
|
| 203 |
-
if model in cls.model_prefixes
|
| 204 |
prefix = cls.model_prefixes[model]
|
| 205 |
if not messages[0]['content'].startswith(prefix):
|
| 206 |
logger.debug(f"Adding prefix '{prefix}' to the first message.")
|
| 207 |
messages[0]['content'] = f"{prefix} {messages[0]['content']}"
|
| 208 |
|
| 209 |
random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
"role": 'user',
|
| 213 |
-
"content": 'Hi' # This should be dynamically set based on input
|
| 214 |
-
}
|
| 215 |
if image is not None:
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
}
|
| 224 |
-
user_message['content'] = 'FILE:BB\n$#$\n\n$#$\n' + user_message['content']
|
| 225 |
-
logger.debug("Image data added to the message.")
|
| 226 |
-
except Exception as e:
|
| 227 |
-
logger.error(f"Failed to decode base64 image: {e}")
|
| 228 |
-
raise HTTPException(status_code=400, detail="Invalid image data provided.")
|
| 229 |
|
| 230 |
-
# Update the last message with user_message
|
| 231 |
-
if messages:
|
| 232 |
-
messages[-1] = user_message
|
| 233 |
-
else:
|
| 234 |
-
messages.append(user_message)
|
| 235 |
-
|
| 236 |
data = {
|
| 237 |
"messages": messages,
|
| 238 |
"id": random_id,
|
|
@@ -280,15 +299,7 @@ class Blackbox:
|
|
| 280 |
if url_match:
|
| 281 |
image_url = url_match.group(0)
|
| 282 |
logger.info(f"Image URL found: {image_url}")
|
| 283 |
-
|
| 284 |
-
# Fetch the image data
|
| 285 |
-
async with session.get(image_url) as img_response:
|
| 286 |
-
img_response.raise_for_status()
|
| 287 |
-
image_bytes = await img_response.read()
|
| 288 |
-
data_uri = to_data_uri(image_bytes)
|
| 289 |
-
logger.info("Image converted to base64 data URI.")
|
| 290 |
-
|
| 291 |
-
yield ImageResponse(data_uri, alt=messages[-1]['content'])
|
| 292 |
else:
|
| 293 |
logger.error("Image URL not found in the response.")
|
| 294 |
raise Exception("Image URL not found in the response")
|
|
@@ -349,7 +360,6 @@ class ChatRequest(BaseModel):
|
|
| 349 |
messages: List[Message]
|
| 350 |
stream: Optional[bool] = False
|
| 351 |
webSearchMode: Optional[bool] = False
|
| 352 |
-
image: Optional[str] = None # Add image field for base64 data
|
| 353 |
|
| 354 |
def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
|
| 355 |
return {
|
|
@@ -367,32 +377,25 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
|
|
| 367 |
"usage": None,
|
| 368 |
}
|
| 369 |
|
| 370 |
-
@app.post("/niansuhai/v1/chat/completions")
|
| 371 |
-
async def chat_completions(request: ChatRequest, req: Request):
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
try:
|
| 374 |
-
# Validate that
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
status_code=422,
|
| 380 |
-
detail=[{
|
| 381 |
-
"type": "string_type",
|
| 382 |
-
"loc": ["body", "messages", idx, "content"],
|
| 383 |
-
"msg": "Input should be a valid string",
|
| 384 |
-
"input": msg.content
|
| 385 |
-
}]
|
| 386 |
-
)
|
| 387 |
-
|
| 388 |
-
# Convert Pydantic messages to dicts
|
| 389 |
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
| 390 |
|
| 391 |
async_generator = Blackbox.create_async_generator(
|
| 392 |
model=request.model,
|
| 393 |
messages=messages,
|
| 394 |
-
|
| 395 |
-
image=request.image, # Pass the base64 image
|
| 396 |
image_name=None,
|
| 397 |
webSearchMode=request.webSearchMode
|
| 398 |
)
|
|
@@ -402,8 +405,7 @@ async def chat_completions(request: ChatRequest, req: Request):
|
|
| 402 |
try:
|
| 403 |
async for chunk in async_generator:
|
| 404 |
if isinstance(chunk, ImageResponse):
|
| 405 |
-
|
| 406 |
-
image_markdown = f""
|
| 407 |
response_chunk = create_response(image_markdown, request.model)
|
| 408 |
else:
|
| 409 |
response_chunk = create_response(chunk, request.model)
|
|
@@ -426,11 +428,11 @@ async def chat_completions(request: ChatRequest, req: Request):
|
|
| 426 |
response_content = ""
|
| 427 |
async for chunk in async_generator:
|
| 428 |
if isinstance(chunk, ImageResponse):
|
| 429 |
-
response_content += f"![
|
| 430 |
else:
|
| 431 |
response_content += chunk
|
| 432 |
|
| 433 |
-
logger.info("Completed non-streaming response generation
|
| 434 |
return {
|
| 435 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 436 |
"object": "chat.completion",
|
|
@@ -462,26 +464,33 @@ async def chat_completions(request: ChatRequest, req: Request):
|
|
| 462 |
logger.exception("An unexpected error occurred while processing the chat completions request.")
|
| 463 |
raise HTTPException(status_code=500, detail=str(e))
|
| 464 |
|
| 465 |
-
@app.get("/niansuhai/v1/models")
|
| 466 |
-
async def get_models():
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
return {"data": [{"id": model} for model in Blackbox.models]}
|
| 469 |
|
| 470 |
# Additional endpoints for better functionality
|
| 471 |
-
@app.get("/niansuhai/v1/health")
|
| 472 |
-
async def health_check():
|
| 473 |
"""Health check endpoint to verify the service is running."""
|
|
|
|
| 474 |
return {"status": "ok"}
|
| 475 |
|
| 476 |
-
@app.get("/niansuhai/v1/models/{model}/status")
|
| 477 |
-
async def model_status(model: str):
|
| 478 |
"""Check if a specific model is available."""
|
|
|
|
| 479 |
if model in Blackbox.models:
|
| 480 |
return {"model": model, "status": "available"}
|
| 481 |
elif model in Blackbox.model_aliases:
|
| 482 |
actual_model = Blackbox.model_aliases[model]
|
| 483 |
return {"model": actual_model, "status": "available via alias"}
|
| 484 |
else:
|
|
|
|
| 485 |
raise HTTPException(status_code=404, detail="Model not found")
|
| 486 |
|
| 487 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
import os
|
| 2 |
import re
|
| 3 |
import random
|
| 4 |
import string
|
|
|
|
| 6 |
import json
|
| 7 |
import logging
|
| 8 |
import asyncio
|
| 9 |
+
import time
|
| 10 |
+
from collections import defaultdict
|
|
|
|
|
|
|
| 11 |
from typing import List, Dict, Any, Optional, AsyncGenerator
|
| 12 |
+
|
| 13 |
+
from aiohttp import ClientSession, ClientTimeout, ClientError
|
| 14 |
+
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
| 15 |
from fastapi.responses import StreamingResponse
|
| 16 |
+
from pydantic import BaseModel
|
| 17 |
|
| 18 |
# Configure logging
|
| 19 |
logging.basicConfig(
|
|
|
|
| 25 |
)
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
+
# Load environment variables
|
| 29 |
+
API_KEYS = os.getenv('API_KEYS', '').split(',') # Comma-separated API keys
|
| 30 |
+
RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60')) # Requests per minute
|
| 31 |
+
|
| 32 |
+
if not API_KEYS or API_KEYS == ['']:
|
| 33 |
+
logger.error("No API keys found. Please set the API_KEYS environment variable.")
|
| 34 |
+
raise Exception("API_KEYS environment variable not set.")
|
| 35 |
+
|
| 36 |
+
# Simple in-memory rate limiter
|
| 37 |
+
rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
|
| 38 |
+
|
| 39 |
+
async def get_api_key(authorization: str = Header(...)) -> str:
    """
    Dependency that extracts and validates the API key from the Authorization header.

    Expects the header in the format: ``Authorization: Bearer <API_KEY>``.

    Returns:
        The validated API key string.

    Raises:
        HTTPException: 401 when the header is malformed or the key is unknown.
    """
    if not authorization.startswith('Bearer '):
        logger.warning("Invalid authorization header format.")
        raise HTTPException(status_code=401, detail='Invalid authorization header format')
    api_key = authorization[7:]  # strip the 'Bearer ' prefix
    if api_key not in API_KEYS:
        # Do not log the full credential: a rejected key is often a mistyped
        # valid one, and log files are a common leak vector. Log a short
        # masked prefix, which is enough to correlate repeated attempts.
        logger.warning(f"Invalid API key attempted: {api_key[:4]}***")
        raise HTTPException(status_code=401, detail='Invalid API key')
    return api_key
|
| 52 |
+
|
| 53 |
+
async def rate_limiter(api_key: str = Depends(get_api_key)):
    """
    Dependency enforcing a fixed-window rate limit per API key.

    Each key is allowed up to RATE_LIMIT requests within a 60-second
    window; once exceeded, further requests get HTTP 429 until the
    window rolls over.
    """
    now = time.time()
    entry = rate_limit_store[api_key]
    if now - entry["timestamp"] > 60:
        # Window elapsed: open a fresh one with this request as its first hit.
        rate_limit_store[api_key] = {"count": 1, "timestamp": now}
        return
    if entry["count"] >= RATE_LIMIT:
        logger.warning(f"Rate limit exceeded for API key: {api_key}")
        raise HTTPException(status_code=429, detail='Rate limit exceeded')
    # Still inside the window and under the cap: record the hit in place.
    entry["count"] += 1
|
| 68 |
+
|
| 69 |
# Custom exception for model not working
|
| 70 |
class ModelNotWorkingException(Exception):
|
| 71 |
def __init__(self, model: str):
|
|
|
|
| 73 |
self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
|
| 74 |
super().__init__(self.message)
|
| 75 |
|
| 76 |
+
# Mock implementations for ImageResponse and to_data_uri
|
| 77 |
class ImageResponse:
    """Lightweight container for a generated image's URL and its alt text."""

    def __init__(self, url: str, alt: str):
        # url: location (or data URI) of the image; alt: human-readable caption.
        self.url = url
        self.alt = alt

    def __repr__(self) -> str:
        # Debug-friendly representation; the original class had none.
        return f"{type(self).__name__}(url={self.url!r}, alt={self.alt!r})"
|
| 81 |
|
| 82 |
+
def to_data_uri(image: Any) -> str:
    """
    Convert raw image bytes into a base64-encoded PNG data URI.

    The original implementation was a stub that always returned the
    placeholder string "data:image/png;base64,..."; callers embed this
    value as the actual image payload, so real encoding is required.
    A ``str`` argument is assumed to already be a data URI and is
    returned unchanged; bytes-like input is base64-encoded.
    """
    import base64  # local import: keeps the module's top-level import block untouched
    if isinstance(image, str):
        return image
    encoded = base64.b64encode(bytes(image)).decode('ascii')
    return f"data:image/png;base64,{encoded}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
class Blackbox:
|
| 86 |
url = "https://www.blackbox.ai"
|
|
|
|
| 192 |
if model in cls.models:
|
| 193 |
return model
|
| 194 |
elif model in cls.userSelectedModel:
|
| 195 |
+
return model
|
| 196 |
elif model in cls.model_aliases:
|
| 197 |
return cls.model_aliases[model]
|
| 198 |
else:
|
|
|
|
| 202 |
async def create_async_generator(
|
| 203 |
cls,
|
| 204 |
model: str,
|
| 205 |
+
messages: List[Dict[str, str]],
|
| 206 |
proxy: Optional[str] = None,
|
| 207 |
+
image: Any = None,
|
| 208 |
image_name: Optional[str] = None,
|
| 209 |
webSearchMode: bool = False,
|
| 210 |
**kwargs
|
|
|
|
| 234 |
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
|
| 235 |
}
|
| 236 |
|
| 237 |
+
if model in cls.model_prefixes:
|
| 238 |
prefix = cls.model_prefixes[model]
|
| 239 |
if not messages[0]['content'].startswith(prefix):
|
| 240 |
logger.debug(f"Adding prefix '{prefix}' to the first message.")
|
| 241 |
messages[0]['content'] = f"{prefix} {messages[0]['content']}"
|
| 242 |
|
| 243 |
random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
|
| 244 |
+
messages[-1]['id'] = random_id
|
| 245 |
+
messages[-1]['role'] = 'user'
|
|
|
|
|
|
|
|
|
|
| 246 |
if image is not None:
|
| 247 |
+
messages[-1]['data'] = {
|
| 248 |
+
'fileText': '',
|
| 249 |
+
'imageBase64': to_data_uri(image),
|
| 250 |
+
'title': image_name
|
| 251 |
+
}
|
| 252 |
+
messages[-1]['content'] = 'FILE:BB\n$#$\n\n$#$\n' + messages[-1]['content']
|
| 253 |
+
logger.debug("Image data added to the message.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
data = {
|
| 256 |
"messages": messages,
|
| 257 |
"id": random_id,
|
|
|
|
| 299 |
if url_match:
|
| 300 |
image_url = url_match.group(0)
|
| 301 |
logger.info(f"Image URL found: {image_url}")
|
| 302 |
+
yield ImageResponse(image_url, alt=messages[-1]['content'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
else:
|
| 304 |
logger.error("Image URL not found in the response.")
|
| 305 |
raise Exception("Image URL not found in the response")
|
|
|
|
| 360 |
messages: List[Message]
|
| 361 |
stream: Optional[bool] = False
|
| 362 |
webSearchMode: Optional[bool] = False
|
|
|
|
| 363 |
|
| 364 |
def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
|
| 365 |
return {
|
|
|
|
| 377 |
"usage": None,
|
| 378 |
}
|
| 379 |
|
| 380 |
+
@app.post("/niansuhai/v1/chat/completions", dependencies=[Depends(rate_limiter)])
|
| 381 |
+
async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
|
| 382 |
+
"""
|
| 383 |
+
Endpoint to handle chat completions.
|
| 384 |
+
Protected by API key and rate limiter.
|
| 385 |
+
"""
|
| 386 |
+
logger.info(f"Received chat completions request from API key: {api_key} | Request: {request}")
|
| 387 |
try:
|
| 388 |
+
# Validate that the requested model is available
|
| 389 |
+
if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
|
| 390 |
+
logger.warning(f"Attempt to use unavailable model: {request.model}")
|
| 391 |
+
raise HTTPException(status_code=400, detail="Requested model is not available.")
|
| 392 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
| 394 |
|
| 395 |
async_generator = Blackbox.create_async_generator(
|
| 396 |
model=request.model,
|
| 397 |
messages=messages,
|
| 398 |
+
image=None,
|
|
|
|
| 399 |
image_name=None,
|
| 400 |
webSearchMode=request.webSearchMode
|
| 401 |
)
|
|
|
|
| 405 |
try:
|
| 406 |
async for chunk in async_generator:
|
| 407 |
if isinstance(chunk, ImageResponse):
|
| 408 |
+
image_markdown = f""
|
|
|
|
| 409 |
response_chunk = create_response(image_markdown, request.model)
|
| 410 |
else:
|
| 411 |
response_chunk = create_response(chunk, request.model)
|
|
|
|
| 428 |
response_content = ""
|
| 429 |
async for chunk in async_generator:
|
| 430 |
if isinstance(chunk, ImageResponse):
|
| 431 |
+
response_content += f"\n"
|
| 432 |
else:
|
| 433 |
response_content += chunk
|
| 434 |
|
| 435 |
+
logger.info(f"Completed non-streaming response generation for API key: {api_key}")
|
| 436 |
return {
|
| 437 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 438 |
"object": "chat.completion",
|
|
|
|
| 464 |
logger.exception("An unexpected error occurred while processing the chat completions request.")
|
| 465 |
raise HTTPException(status_code=500, detail=str(e))
|
| 466 |
|
| 467 |
+
@app.get("/niansuhai/v1/models", dependencies=[Depends(rate_limiter)])
async def get_models(api_key: str = Depends(get_api_key)):
    """
    Return the list of models this service exposes.
    Requires a valid API key and is subject to rate limiting.
    """
    logger.info(f"Fetching available models for API key: {api_key}")
    model_entries = []
    for model in Blackbox.models:
        model_entries.append({"id": model})
    return {"data": model_entries}
|
| 475 |
|
| 476 |
# Additional endpoints for better functionality
|
| 477 |
+
@app.get("/niansuhai/v1/health", dependencies=[Depends(rate_limiter)])
async def health_check(api_key: str = Depends(get_api_key)):
    """Liveness probe: confirms the service is up for an authenticated caller."""
    logger.info(f"Health check requested by API key: {api_key}")
    status_payload = {"status": "ok"}
    return status_payload
|
| 482 |
|
| 483 |
+
@app.get("/niansuhai/v1/models/{model}/status", dependencies=[Depends(rate_limiter)])
async def model_status(model: str, api_key: str = Depends(get_api_key)):
    """Report whether the named model is served directly or via an alias."""
    logger.info(f"Model status requested for '{model}' by API key: {api_key}")
    # Direct models take precedence over aliases, matching the lookup order
    # used elsewhere when resolving a requested model name.
    if model in Blackbox.models:
        return {"model": model, "status": "available"}
    if model in Blackbox.model_aliases:
        resolved = Blackbox.model_aliases[model]
        return {"model": resolved, "status": "available via alias"}
    logger.warning(f"Model not found: {model}")
    raise HTTPException(status_code=404, detail="Model not found")
|
| 495 |
|
| 496 |
if __name__ == "__main__":
|