Update main.py
Browse files
main.py
CHANGED
@@ -28,7 +28,6 @@ IMAGE_GEN_API_URL = "https://www.chatwithmono.xyz/api/image"
|
|
28 |
MODERATION_API_URL = "https://www.chatwithmono.xyz/api/moderation"
|
29 |
|
30 |
# --- Model Definitions ---
|
31 |
-
# Added florence-2-ocr for the new endpoint
|
32 |
AVAILABLE_MODELS = [
|
33 |
{"id": "gpt-4-turbo", "object": "model", "created": int(time.time()), "owned_by": "system"},
|
34 |
{"id": "gpt-4o", "object": "model", "created": int(time.time()), "owned_by": "system"},
|
@@ -43,10 +42,9 @@ MODEL_ALIASES = {}
|
|
43 |
app = FastAPI(
|
44 |
title="OpenAI Compatible API",
|
45 |
description="An adapter for various services to be compatible with the OpenAI API specification.",
|
46 |
-
version="1.1.
|
47 |
)
|
48 |
|
49 |
-
# Initialize Gradio client for OCR globally to avoid re-initialization on each request
|
50 |
try:
|
51 |
ocr_client = Client("multimodalart/Florence-2-l4")
|
52 |
except Exception as e:
|
@@ -54,8 +52,7 @@ except Exception as e:
|
|
54 |
ocr_client = None
|
55 |
|
56 |
# --- Pydantic Models ---
|
57 |
-
|
58 |
-
# /v1/chat/completions
|
59 |
class Message(BaseModel):
|
60 |
role: str
|
61 |
content: str
|
@@ -66,7 +63,6 @@ class ChatRequest(BaseModel):
|
|
66 |
stream: Optional[bool] = False
|
67 |
tools: Optional[Any] = None
|
68 |
|
69 |
-
# /v1/images/generations
|
70 |
class ImageGenerationRequest(BaseModel):
|
71 |
prompt: str
|
72 |
aspect_ratio: Optional[str] = "1:1"
|
@@ -74,12 +70,10 @@ class ImageGenerationRequest(BaseModel):
|
|
74 |
user: Optional[str] = None
|
75 |
model: Optional[str] = "default"
|
76 |
|
77 |
-
# /v1/moderations
|
78 |
class ModerationRequest(BaseModel):
|
79 |
input: Union[str, List[str]]
|
80 |
model: Optional[str] = "text-moderation-stable"
|
81 |
|
82 |
-
# /v1/ocr
|
83 |
class OcrRequest(BaseModel):
|
84 |
image_url: Optional[str] = Field(None, description="URL of the image to process.")
|
85 |
image_b64: Optional[str] = Field(None, description="Base64 encoded string of the image to process.")
|
@@ -88,11 +82,9 @@ class OcrRequest(BaseModel):
|
|
88 |
@classmethod
|
89 |
def check_sources(cls, data: Any) -> Any:
|
90 |
if isinstance(data, dict):
|
91 |
-
|
92 |
-
b64 = data.get('image_b64')
|
93 |
-
if not (url or b64):
|
94 |
raise ValueError('Either image_url or image_b64 must be provided.')
|
95 |
-
if
|
96 |
raise ValueError('Provide either image_url or image_b64, not both.')
|
97 |
return data
|
98 |
|
@@ -100,10 +92,8 @@ class OcrResponse(BaseModel):
|
|
100 |
ocr_text: str
|
101 |
raw_response: dict
|
102 |
|
103 |
-
|
104 |
-
# --- Helper Function for Random ID Generation ---
|
105 |
def generate_random_id(prefix: str, length: int = 29) -> str:
|
106 |
-
"""Generates a cryptographically secure, random alphanumeric ID."""
|
107 |
population = string.ascii_letters + string.digits
|
108 |
random_part = "".join(secrets.choice(population) for _ in range(length))
|
109 |
return f"{prefix}{random_part}"
|
@@ -115,6 +105,7 @@ async def list_models():
|
|
115 |
"""Lists the available models."""
|
116 |
return {"object": "list", "data": AVAILABLE_MODELS}
|
117 |
|
|
|
118 |
@app.post("/v1/chat/completions", tags=["Chat"])
|
119 |
async def chat_completion(request: ChatRequest):
|
120 |
"""Handles chat completion requests, supporting streaming and non-streaming."""
|
@@ -128,7 +119,6 @@ async def chat_completion(request: ChatRequest):
|
|
128 |
'user-agent': 'Mozilla/5.0',
|
129 |
}
|
130 |
|
131 |
-
# Handle tool prompting
|
132 |
if request.tools:
|
133 |
tool_prompt = f"""You have access to the following tools. To call a tool, please respond with JSON for a tool call within <tool_call></tool_call> XML tags. Respond in the format {{"name": tool name, "parameters": dictionary of argument name and its value}}. Do not use variables.
|
134 |
Tools: {";".join(f"<tool>{tool}</tool>" for tool in request.tools)}
|
@@ -181,7 +171,6 @@ Response Format for tool call:
|
|
181 |
|
182 |
in_tool_call = False
|
183 |
tool_call_buffer = ""
|
184 |
-
# Process text that might come after the tool call in the same chunk
|
185 |
remaining_text = current_buffer.split("</tool_call>", 1)[1]
|
186 |
if remaining_text:
|
187 |
content_piece = remaining_text
|
@@ -191,16 +180,14 @@ Response Format for tool call:
|
|
191 |
if "<tool_call>" in content_piece:
|
192 |
in_tool_call = True
|
193 |
tool_call_buffer += content_piece.split("<tool_call>", 1)[1]
|
194 |
-
# Process text that came before the tool call
|
195 |
text_before = content_piece.split("<tool_call>", 1)[0]
|
196 |
if text_before:
|
197 |
-
# Send the text before the tool call starts
|
198 |
delta = {"content": text_before, "tool_calls": None}
|
199 |
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id,
|
200 |
"choices": [{"index": 0, "delta": delta, "finish_reason": None}], "usage": None}
|
201 |
yield f"data: {json.dumps(chunk)}\n\n"
|
202 |
if "</tool_call>" not in tool_call_buffer:
|
203 |
-
continue
|
204 |
|
205 |
if not in_tool_call:
|
206 |
delta = {"content": content_piece}
|
@@ -217,7 +204,6 @@ Response Format for tool call:
|
|
217 |
except (json.JSONDecodeError, AttributeError): pass
|
218 |
break
|
219 |
|
220 |
-
# Finalize
|
221 |
final_usage = None
|
222 |
if usage_info:
|
223 |
final_usage = {"prompt_tokens": usage_info.get("promptTokens", 0), "completion_tokens": usage_info.get("completionTokens", 0), "total_tokens": usage_info.get("promptTokens", 0) + usage_info.get("completionTokens", 0)}
|
@@ -232,7 +218,7 @@ Response Format for tool call:
|
|
232 |
yield "data: [DONE]\n\n"
|
233 |
|
234 |
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
235 |
-
else:
|
236 |
full_response, usage_info = "", {}
|
237 |
try:
|
238 |
async with httpx.AsyncClient(timeout=120) as client:
|
@@ -300,6 +286,8 @@ async def generate_images(request: ImageGenerationRequest):
|
|
300 |
return JSONResponse(status_code=500, content={"error": "An internal error occurred.", "details": str(e)})
|
301 |
return {"created": int(time.time()), "data": results}
|
302 |
|
|
|
|
|
303 |
@app.post("/v1/ocr", response_model=OcrResponse, tags=["OCR"])
|
304 |
async def perform_ocr(request: OcrRequest):
|
305 |
"""
|
@@ -322,16 +310,40 @@ async def perform_ocr(request: OcrRequest):
|
|
322 |
|
323 |
prediction = ocr_client.predict(image=handle_file(image_path), task_prompt="OCR", api_name="/process_image")
|
324 |
|
325 |
-
if not prediction or not isinstance(prediction, tuple):
|
326 |
-
raise HTTPException(status_code=502, detail="Invalid response from OCR service.")
|
327 |
|
328 |
-
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
except Exception as e:
|
|
|
|
|
|
|
332 |
raise HTTPException(status_code=500, detail=f"An error occurred during OCR processing: {str(e)}")
|
333 |
finally:
|
334 |
-
if temp_file_path:
|
335 |
os.unlink(temp_file_path)
|
336 |
|
337 |
@app.post("/v1/moderations", tags=["Moderation"])
|
|
|
28 |
MODERATION_API_URL = "https://www.chatwithmono.xyz/api/moderation"
|
29 |
|
30 |
# --- Model Definitions ---
|
|
|
31 |
AVAILABLE_MODELS = [
|
32 |
{"id": "gpt-4-turbo", "object": "model", "created": int(time.time()), "owned_by": "system"},
|
33 |
{"id": "gpt-4o", "object": "model", "created": int(time.time()), "owned_by": "system"},
|
|
|
42 |
app = FastAPI(
|
43 |
title="OpenAI Compatible API",
|
44 |
description="An adapter for various services to be compatible with the OpenAI API specification.",
|
45 |
+
version="1.1.1" # Incremented version for the fix
|
46 |
)
|
47 |
|
|
|
48 |
try:
|
49 |
ocr_client = Client("multimodalart/Florence-2-l4")
|
50 |
except Exception as e:
|
|
|
52 |
ocr_client = None
|
53 |
|
54 |
# --- Pydantic Models ---
|
55 |
+
# (Pydantic models are unchanged and remain the same as before)
|
|
|
56 |
class Message(BaseModel):
|
57 |
role: str
|
58 |
content: str
|
|
|
63 |
stream: Optional[bool] = False
|
64 |
tools: Optional[Any] = None
|
65 |
|
|
|
66 |
class ImageGenerationRequest(BaseModel):
|
67 |
prompt: str
|
68 |
aspect_ratio: Optional[str] = "1:1"
|
|
|
70 |
user: Optional[str] = None
|
71 |
model: Optional[str] = "default"
|
72 |
|
|
|
73 |
class ModerationRequest(BaseModel):
|
74 |
input: Union[str, List[str]]
|
75 |
model: Optional[str] = "text-moderation-stable"
|
76 |
|
|
|
77 |
class OcrRequest(BaseModel):
|
78 |
image_url: Optional[str] = Field(None, description="URL of the image to process.")
|
79 |
image_b64: Optional[str] = Field(None, description="Base64 encoded string of the image to process.")
|
|
|
82 |
@classmethod
|
83 |
def check_sources(cls, data: Any) -> Any:
|
84 |
if isinstance(data, dict):
|
85 |
+
if not (data.get('image_url') or data.get('image_b64')):
|
|
|
|
|
86 |
raise ValueError('Either image_url or image_b64 must be provided.')
|
87 |
+
if data.get('image_url') and data.get('image_b64'):
|
88 |
raise ValueError('Provide either image_url or image_b64, not both.')
|
89 |
return data
|
90 |
|
|
|
92 |
ocr_text: str
|
93 |
raw_response: dict
|
94 |
|
95 |
+
# --- Helper Function ---
|
|
|
96 |
def generate_random_id(prefix: str, length: int = 29) -> str:
|
|
|
97 |
population = string.ascii_letters + string.digits
|
98 |
random_part = "".join(secrets.choice(population) for _ in range(length))
|
99 |
return f"{prefix}{random_part}"
|
|
|
105 |
"""Lists the available models."""
|
106 |
return {"object": "list", "data": AVAILABLE_MODELS}
|
107 |
|
108 |
+
# (Chat, Image Generation, and Moderation endpoints are unchanged)
|
109 |
@app.post("/v1/chat/completions", tags=["Chat"])
|
110 |
async def chat_completion(request: ChatRequest):
|
111 |
"""Handles chat completion requests, supporting streaming and non-streaming."""
|
|
|
119 |
'user-agent': 'Mozilla/5.0',
|
120 |
}
|
121 |
|
|
|
122 |
if request.tools:
|
123 |
tool_prompt = f"""You have access to the following tools. To call a tool, please respond with JSON for a tool call within <tool_call></tool_call> XML tags. Respond in the format {{"name": tool name, "parameters": dictionary of argument name and its value}}. Do not use variables.
|
124 |
Tools: {";".join(f"<tool>{tool}</tool>" for tool in request.tools)}
|
|
|
171 |
|
172 |
in_tool_call = False
|
173 |
tool_call_buffer = ""
|
|
|
174 |
remaining_text = current_buffer.split("</tool_call>", 1)[1]
|
175 |
if remaining_text:
|
176 |
content_piece = remaining_text
|
|
|
180 |
if "<tool_call>" in content_piece:
|
181 |
in_tool_call = True
|
182 |
tool_call_buffer += content_piece.split("<tool_call>", 1)[1]
|
|
|
183 |
text_before = content_piece.split("<tool_call>", 1)[0]
|
184 |
if text_before:
|
|
|
185 |
delta = {"content": text_before, "tool_calls": None}
|
186 |
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id,
|
187 |
"choices": [{"index": 0, "delta": delta, "finish_reason": None}], "usage": None}
|
188 |
yield f"data: {json.dumps(chunk)}\n\n"
|
189 |
if "</tool_call>" not in tool_call_buffer:
|
190 |
+
continue
|
191 |
|
192 |
if not in_tool_call:
|
193 |
delta = {"content": content_piece}
|
|
|
204 |
except (json.JSONDecodeError, AttributeError): pass
|
205 |
break
|
206 |
|
|
|
207 |
final_usage = None
|
208 |
if usage_info:
|
209 |
final_usage = {"prompt_tokens": usage_info.get("promptTokens", 0), "completion_tokens": usage_info.get("completionTokens", 0), "total_tokens": usage_info.get("promptTokens", 0) + usage_info.get("completionTokens", 0)}
|
|
|
218 |
yield "data: [DONE]\n\n"
|
219 |
|
220 |
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
221 |
+
else:
|
222 |
full_response, usage_info = "", {}
|
223 |
try:
|
224 |
async with httpx.AsyncClient(timeout=120) as client:
|
|
|
286 |
return JSONResponse(status_code=500, content={"error": "An internal error occurred.", "details": str(e)})
|
287 |
return {"created": int(time.time()), "data": results}
|
288 |
|
289 |
+
|
290 |
+
# === FIXED OCR Endpoint ===
|
291 |
@app.post("/v1/ocr", response_model=OcrResponse, tags=["OCR"])
|
292 |
async def perform_ocr(request: OcrRequest):
|
293 |
"""
|
|
|
310 |
|
311 |
prediction = ocr_client.predict(image=handle_file(image_path), task_prompt="OCR", api_name="/process_image")
|
312 |
|
313 |
+
if not prediction or not isinstance(prediction, tuple) or len(prediction) == 0:
|
314 |
+
raise HTTPException(status_code=502, detail="Invalid or empty response from OCR service.")
|
315 |
|
316 |
+
raw_output = prediction[0]
|
317 |
+
raw_result_dict = {}
|
318 |
+
|
319 |
+
# --- START: FIX ---
|
320 |
+
# The Gradio client returns a JSON string, not a dict. We must parse it.
|
321 |
+
if isinstance(raw_output, str):
|
322 |
+
try:
|
323 |
+
raw_result_dict = json.loads(raw_output)
|
324 |
+
except json.JSONDecodeError:
|
325 |
+
raise HTTPException(status_code=502, detail="Failed to parse JSON response from OCR service.")
|
326 |
+
elif isinstance(raw_output, dict):
|
327 |
+
# If it's already a dict, use it directly
|
328 |
+
raw_result_dict = raw_output
|
329 |
+
else:
|
330 |
+
raise HTTPException(status_code=502, detail=f"Unexpected data type from OCR service: {type(raw_output)}")
|
331 |
+
# --- END: FIX ---
|
332 |
+
|
333 |
+
ocr_text = raw_result_dict.get("OCR", "")
|
334 |
+
# Fallback in case the OCR key is missing but there's other data
|
335 |
+
if not ocr_text:
|
336 |
+
ocr_text = str(raw_result_dict)
|
337 |
+
|
338 |
+
return OcrResponse(ocr_text=ocr_text, raw_response=raw_result_dict)
|
339 |
+
|
340 |
except Exception as e:
|
341 |
+
# Catch the specific HTTPException and re-raise it, otherwise wrap other exceptions
|
342 |
+
if isinstance(e, HTTPException):
|
343 |
+
raise e
|
344 |
raise HTTPException(status_code=500, detail=f"An error occurred during OCR processing: {str(e)}")
|
345 |
finally:
|
346 |
+
if temp_file_path and os.path.exists(temp_file_path):
|
347 |
os.unlink(temp_file_path)
|
348 |
|
349 |
@app.post("/v1/moderations", tags=["Moderation"])
|