AIMaster7 committed on
Commit
4f72e24
·
verified ·
1 Parent(s): a302e21

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -28
main.py CHANGED
@@ -28,7 +28,6 @@ IMAGE_GEN_API_URL = "https://www.chatwithmono.xyz/api/image"
28
  MODERATION_API_URL = "https://www.chatwithmono.xyz/api/moderation"
29
 
30
  # --- Model Definitions ---
31
- # Added florence-2-ocr for the new endpoint
32
  AVAILABLE_MODELS = [
33
  {"id": "gpt-4-turbo", "object": "model", "created": int(time.time()), "owned_by": "system"},
34
  {"id": "gpt-4o", "object": "model", "created": int(time.time()), "owned_by": "system"},
@@ -43,10 +42,9 @@ MODEL_ALIASES = {}
43
  app = FastAPI(
44
  title="OpenAI Compatible API",
45
  description="An adapter for various services to be compatible with the OpenAI API specification.",
46
- version="1.1.0"
47
  )
48
 
49
- # Initialize Gradio client for OCR globally to avoid re-initialization on each request
50
  try:
51
  ocr_client = Client("multimodalart/Florence-2-l4")
52
  except Exception as e:
@@ -54,8 +52,7 @@ except Exception as e:
54
  ocr_client = None
55
 
56
  # --- Pydantic Models ---
57
-
58
- # /v1/chat/completions
59
  class Message(BaseModel):
60
  role: str
61
  content: str
@@ -66,7 +63,6 @@ class ChatRequest(BaseModel):
66
  stream: Optional[bool] = False
67
  tools: Optional[Any] = None
68
 
69
- # /v1/images/generations
70
  class ImageGenerationRequest(BaseModel):
71
  prompt: str
72
  aspect_ratio: Optional[str] = "1:1"
@@ -74,12 +70,10 @@ class ImageGenerationRequest(BaseModel):
74
  user: Optional[str] = None
75
  model: Optional[str] = "default"
76
 
77
- # /v1/moderations
78
  class ModerationRequest(BaseModel):
79
  input: Union[str, List[str]]
80
  model: Optional[str] = "text-moderation-stable"
81
 
82
- # /v1/ocr
83
  class OcrRequest(BaseModel):
84
  image_url: Optional[str] = Field(None, description="URL of the image to process.")
85
  image_b64: Optional[str] = Field(None, description="Base64 encoded string of the image to process.")
@@ -88,11 +82,9 @@ class OcrRequest(BaseModel):
88
  @classmethod
89
  def check_sources(cls, data: Any) -> Any:
90
  if isinstance(data, dict):
91
- url = data.get('image_url')
92
- b64 = data.get('image_b64')
93
- if not (url or b64):
94
  raise ValueError('Either image_url or image_b64 must be provided.')
95
- if url and b64:
96
  raise ValueError('Provide either image_url or image_b64, not both.')
97
  return data
98
 
@@ -100,10 +92,8 @@ class OcrResponse(BaseModel):
100
  ocr_text: str
101
  raw_response: dict
102
 
103
-
104
- # --- Helper Function for Random ID Generation ---
105
  def generate_random_id(prefix: str, length: int = 29) -> str:
106
- """Generates a cryptographically secure, random alphanumeric ID."""
107
  population = string.ascii_letters + string.digits
108
  random_part = "".join(secrets.choice(population) for _ in range(length))
109
  return f"{prefix}{random_part}"
@@ -115,6 +105,7 @@ async def list_models():
115
  """Lists the available models."""
116
  return {"object": "list", "data": AVAILABLE_MODELS}
117
 
 
118
  @app.post("/v1/chat/completions", tags=["Chat"])
119
  async def chat_completion(request: ChatRequest):
120
  """Handles chat completion requests, supporting streaming and non-streaming."""
@@ -128,7 +119,6 @@ async def chat_completion(request: ChatRequest):
128
  'user-agent': 'Mozilla/5.0',
129
  }
130
 
131
- # Handle tool prompting
132
  if request.tools:
133
  tool_prompt = f"""You have access to the following tools. To call a tool, please respond with JSON for a tool call within <tool_call></tool_call> XML tags. Respond in the format {{"name": tool name, "parameters": dictionary of argument name and its value}}. Do not use variables.
134
  Tools: {";".join(f"<tool>{tool}</tool>" for tool in request.tools)}
@@ -181,7 +171,6 @@ Response Format for tool call:
181
 
182
  in_tool_call = False
183
  tool_call_buffer = ""
184
- # Process text that might come after the tool call in the same chunk
185
  remaining_text = current_buffer.split("</tool_call>", 1)[1]
186
  if remaining_text:
187
  content_piece = remaining_text
@@ -191,16 +180,14 @@ Response Format for tool call:
191
  if "<tool_call>" in content_piece:
192
  in_tool_call = True
193
  tool_call_buffer += content_piece.split("<tool_call>", 1)[1]
194
- # Process text that came before the tool call
195
  text_before = content_piece.split("<tool_call>", 1)[0]
196
  if text_before:
197
- # Send the text before the tool call starts
198
  delta = {"content": text_before, "tool_calls": None}
199
  chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id,
200
  "choices": [{"index": 0, "delta": delta, "finish_reason": None}], "usage": None}
201
  yield f"data: {json.dumps(chunk)}\n\n"
202
  if "</tool_call>" not in tool_call_buffer:
203
- continue # Wait for the closing tag
204
 
205
  if not in_tool_call:
206
  delta = {"content": content_piece}
@@ -217,7 +204,6 @@ Response Format for tool call:
217
  except (json.JSONDecodeError, AttributeError): pass
218
  break
219
 
220
- # Finalize
221
  final_usage = None
222
  if usage_info:
223
  final_usage = {"prompt_tokens": usage_info.get("promptTokens", 0), "completion_tokens": usage_info.get("completionTokens", 0), "total_tokens": usage_info.get("promptTokens", 0) + usage_info.get("completionTokens", 0)}
@@ -232,7 +218,7 @@ Response Format for tool call:
232
  yield "data: [DONE]\n\n"
233
 
234
  return StreamingResponse(event_stream(), media_type="text/event-stream")
235
- else: # Non-streaming
236
  full_response, usage_info = "", {}
237
  try:
238
  async with httpx.AsyncClient(timeout=120) as client:
@@ -300,6 +286,8 @@ async def generate_images(request: ImageGenerationRequest):
300
  return JSONResponse(status_code=500, content={"error": "An internal error occurred.", "details": str(e)})
301
  return {"created": int(time.time()), "data": results}
302
 
 
 
303
  @app.post("/v1/ocr", response_model=OcrResponse, tags=["OCR"])
304
  async def perform_ocr(request: OcrRequest):
305
  """
@@ -322,16 +310,40 @@ async def perform_ocr(request: OcrRequest):
322
 
323
  prediction = ocr_client.predict(image=handle_file(image_path), task_prompt="OCR", api_name="/process_image")
324
 
325
- if not prediction or not isinstance(prediction, tuple):
326
- raise HTTPException(status_code=502, detail="Invalid response from OCR service.")
327
 
328
- raw_result = prediction[0]
329
- ocr_text = raw_result.get("OCR", "")
330
- return OcrResponse(ocr_text=ocr_text, raw_response=raw_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  except Exception as e:
 
 
 
332
  raise HTTPException(status_code=500, detail=f"An error occurred during OCR processing: {str(e)}")
333
  finally:
334
- if temp_file_path:
335
  os.unlink(temp_file_path)
336
 
337
  @app.post("/v1/moderations", tags=["Moderation"])
 
28
  MODERATION_API_URL = "https://www.chatwithmono.xyz/api/moderation"
29
 
30
  # --- Model Definitions ---
 
31
  AVAILABLE_MODELS = [
32
  {"id": "gpt-4-turbo", "object": "model", "created": int(time.time()), "owned_by": "system"},
33
  {"id": "gpt-4o", "object": "model", "created": int(time.time()), "owned_by": "system"},
 
42
  app = FastAPI(
43
  title="OpenAI Compatible API",
44
  description="An adapter for various services to be compatible with the OpenAI API specification.",
45
+ version="1.1.1" # Incremented version for the fix
46
  )
47
 
 
48
  try:
49
  ocr_client = Client("multimodalart/Florence-2-l4")
50
  except Exception as e:
 
52
  ocr_client = None
53
 
54
  # --- Pydantic Models ---
55
+ # (Pydantic models are unchanged and remain the same as before)
 
56
  class Message(BaseModel):
57
  role: str
58
  content: str
 
63
  stream: Optional[bool] = False
64
  tools: Optional[Any] = None
65
 
 
66
  class ImageGenerationRequest(BaseModel):
67
  prompt: str
68
  aspect_ratio: Optional[str] = "1:1"
 
70
  user: Optional[str] = None
71
  model: Optional[str] = "default"
72
 
 
73
  class ModerationRequest(BaseModel):
74
  input: Union[str, List[str]]
75
  model: Optional[str] = "text-moderation-stable"
76
 
 
77
  class OcrRequest(BaseModel):
78
  image_url: Optional[str] = Field(None, description="URL of the image to process.")
79
  image_b64: Optional[str] = Field(None, description="Base64 encoded string of the image to process.")
 
82
  @classmethod
83
  def check_sources(cls, data: Any) -> Any:
84
  if isinstance(data, dict):
85
+ if not (data.get('image_url') or data.get('image_b64')):
 
 
86
  raise ValueError('Either image_url or image_b64 must be provided.')
87
+ if data.get('image_url') and data.get('image_b64'):
88
  raise ValueError('Provide either image_url or image_b64, not both.')
89
  return data
90
 
 
92
  ocr_text: str
93
  raw_response: dict
94
 
95
+ # --- Helper Function ---
 
96
  def generate_random_id(prefix: str, length: int = 29) -> str:
 
97
  population = string.ascii_letters + string.digits
98
  random_part = "".join(secrets.choice(population) for _ in range(length))
99
  return f"{prefix}{random_part}"
 
105
  """Lists the available models."""
106
  return {"object": "list", "data": AVAILABLE_MODELS}
107
 
108
+ # (Chat, Image Generation, and Moderation endpoints are unchanged)
109
  @app.post("/v1/chat/completions", tags=["Chat"])
110
  async def chat_completion(request: ChatRequest):
111
  """Handles chat completion requests, supporting streaming and non-streaming."""
 
119
  'user-agent': 'Mozilla/5.0',
120
  }
121
 
 
122
  if request.tools:
123
  tool_prompt = f"""You have access to the following tools. To call a tool, please respond with JSON for a tool call within <tool_call></tool_call> XML tags. Respond in the format {{"name": tool name, "parameters": dictionary of argument name and its value}}. Do not use variables.
124
  Tools: {";".join(f"<tool>{tool}</tool>" for tool in request.tools)}
 
171
 
172
  in_tool_call = False
173
  tool_call_buffer = ""
 
174
  remaining_text = current_buffer.split("</tool_call>", 1)[1]
175
  if remaining_text:
176
  content_piece = remaining_text
 
180
  if "<tool_call>" in content_piece:
181
  in_tool_call = True
182
  tool_call_buffer += content_piece.split("<tool_call>", 1)[1]
 
183
  text_before = content_piece.split("<tool_call>", 1)[0]
184
  if text_before:
 
185
  delta = {"content": text_before, "tool_calls": None}
186
  chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id,
187
  "choices": [{"index": 0, "delta": delta, "finish_reason": None}], "usage": None}
188
  yield f"data: {json.dumps(chunk)}\n\n"
189
  if "</tool_call>" not in tool_call_buffer:
190
+ continue
191
 
192
  if not in_tool_call:
193
  delta = {"content": content_piece}
 
204
  except (json.JSONDecodeError, AttributeError): pass
205
  break
206
 
 
207
  final_usage = None
208
  if usage_info:
209
  final_usage = {"prompt_tokens": usage_info.get("promptTokens", 0), "completion_tokens": usage_info.get("completionTokens", 0), "total_tokens": usage_info.get("promptTokens", 0) + usage_info.get("completionTokens", 0)}
 
218
  yield "data: [DONE]\n\n"
219
 
220
  return StreamingResponse(event_stream(), media_type="text/event-stream")
221
+ else:
222
  full_response, usage_info = "", {}
223
  try:
224
  async with httpx.AsyncClient(timeout=120) as client:
 
286
  return JSONResponse(status_code=500, content={"error": "An internal error occurred.", "details": str(e)})
287
  return {"created": int(time.time()), "data": results}
288
 
289
+
290
+ # === FIXED OCR Endpoint ===
291
  @app.post("/v1/ocr", response_model=OcrResponse, tags=["OCR"])
292
  async def perform_ocr(request: OcrRequest):
293
  """
 
310
 
311
  prediction = ocr_client.predict(image=handle_file(image_path), task_prompt="OCR", api_name="/process_image")
312
 
313
+ if not prediction or not isinstance(prediction, tuple) or len(prediction) == 0:
314
+ raise HTTPException(status_code=502, detail="Invalid or empty response from OCR service.")
315
 
316
+ raw_output = prediction[0]
317
+ raw_result_dict = {}
318
+
319
+ # --- START: FIX ---
320
+ # The Gradio client returns a JSON string, not a dict. We must parse it.
321
+ if isinstance(raw_output, str):
322
+ try:
323
+ raw_result_dict = json.loads(raw_output)
324
+ except json.JSONDecodeError:
325
+ raise HTTPException(status_code=502, detail="Failed to parse JSON response from OCR service.")
326
+ elif isinstance(raw_output, dict):
327
+ # If it's already a dict, use it directly
328
+ raw_result_dict = raw_output
329
+ else:
330
+ raise HTTPException(status_code=502, detail=f"Unexpected data type from OCR service: {type(raw_output)}")
331
+ # --- END: FIX ---
332
+
333
+ ocr_text = raw_result_dict.get("OCR", "")
334
+ # Fallback in case the OCR key is missing but there's other data
335
+ if not ocr_text:
336
+ ocr_text = str(raw_result_dict)
337
+
338
+ return OcrResponse(ocr_text=ocr_text, raw_response=raw_result_dict)
339
+
340
  except Exception as e:
341
+ # Catch the specific HTTPException and re-raise it, otherwise wrap other exceptions
342
+ if isinstance(e, HTTPException):
343
+ raise e
344
  raise HTTPException(status_code=500, detail=f"An error occurred during OCR processing: {str(e)}")
345
  finally:
346
+ if temp_file_path and os.path.exists(temp_file_path):
347
  os.unlink(temp_file_path)
348
 
349
  @app.post("/v1/moderations", tags=["Moderation"])