from fastapi import APIRouter, File, UploadFile, Form, HTTPException
from typing import Optional
from PIL import Image
import urllib.request
from urllib.parse import urlparse
from io import BytesIO
from config import settings
import utils
import os
import json
from routers.donut_inference import process_document_donut
import logging
import io

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

# Upload content types accepted by /inference (JPG only).
_ALLOWED_CONTENT_TYPES = ("image/jpeg", "image/jpg")
# Only network schemes may be fetched; urlopen would otherwise also honour
# file:// and similar, letting a caller make the server read local files.
_ALLOWED_URL_SCHEMES = ("http", "https")
# Upper bound (seconds) on a single image download so a slow host cannot
# hang the request forever.
_URL_TIMEOUT_SECONDS = 30


def count_values(obj):
    """Count the leaf values in a nested structure of dicts and lists.

    A dict contributes the total of its values' counts, a list the total of its
    items' counts, and any other object counts as exactly one leaf. Empty dicts
    and lists therefore count as 0.
    """
    if isinstance(obj, dict):
        return sum(count_values(value) for value in obj.values())
    if isinstance(obj, list):
        return sum(count_values(item) for item in obj)
    return 1


def _unsupported_model_response(model_in_use):
    """Log and build the error payload for a model name this router cannot run."""
    logger.warning(f"Unsupported model: {model_in_use}")
    return {"error": f"Unsupported model: {model_in_use}"}


def _run_donut(image, shipper_id_str, source_name):
    """Run the Donut pipeline on *image* and append a row to the stats file.

    Returns ``(result, processing_time)`` exactly as produced by
    ``process_document_donut``. The row passed to ``utils.log_stats`` is
    ``[processing_time, count_values(result), source_name, settings.model]``.
    """
    result, processing_time = process_document_donut(image, shipper_id_str)
    utils.log_stats(settings.inference_stats_file,
                    [processing_time, count_values(result), source_name, settings.model])
    return result, processing_time


def _fetch_url_bytes(image_url):
    """Download *image_url* and return the response body as bytes.

    Raises:
        ValueError: if the URL scheme is not http or https.
        urllib.error.URLError / OSError: on network failure or timeout.
    """
    scheme = urlparse(image_url).scheme.lower()
    if scheme not in _ALLOWED_URL_SCHEMES:
        raise ValueError(f"Unsupported URL scheme: {scheme or '(none)'}")
    # NOTE(review): hosts on the internal network are still reachable (SSRF);
    # add a host allow-list if this endpoint is exposed to untrusted callers.
    with urllib.request.urlopen(image_url, timeout=_URL_TIMEOUT_SECONDS) as response:
        return response.read()


def _file_name_from_url(image_url):
    """Return the last path segment of *image_url*, without query or fragment.

    The query string is dropped so that tokens in signed URLs are not written
    to the statistics file.
    """
    return urlparse(image_url).path.rsplit("/", 1)[-1]


@router.post("/inference")
async def run_inference(
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    model_in_use: str = Form('donut'),
    shipper_id: Optional[int] = Form(None)
):
    """Run document inference on an uploaded JPG or on an image fetched by URL.

    When both ``file`` and ``image_url`` are supplied, the uploaded file wins.
    Only ``model_in_use == 'donut'`` is supported.

    Returns the model result on success. Failures are reported as a dict
    (``{"info": ...}`` when no input was given, otherwise ``{"error": ...}``)
    rather than as HTTP error statuses.
    """
    # Validate input
    if not file and not image_url:
        return {"info": "No input provided"}

    logger.info(f"Received inference request with shipper_id: {shipper_id}")
    # Convert shipper_id to string if provided (config.py expects a string)
    shipper_id_str = str(shipper_id) if shipper_id is not None else "default_shipper"
    logger.info(f"Using shipper_id: {shipper_id_str} for model selection")

    # NOTE(review): process_document_donut and the URL download are blocking
    # calls made from an async endpoint, so they stall the event loop while they
    # run. Consider a threadpool once the model's thread-safety is confirmed.
    result = []
    try:
        if file:
            if file.content_type not in _ALLOWED_CONTENT_TYPES:
                logger.warning(f"Invalid file type: {file.content_type}")
                return {"error": "Invalid file type. Only JPG images are allowed."}

            logger.info(f"Processing file: {file.filename}")
            image = Image.open(BytesIO(await file.read()))
            if model_in_use != 'donut':
                return _unsupported_model_response(model_in_use)

            result, processing_time = _run_donut(image, shipper_id_str, file.filename)
            logger.info(f"Processing time: {processing_time:.2f} seconds with model: {settings.model}")
        elif image_url:
            logger.info(f"Processing image from URL: {image_url}")
            # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
            try:
                # The download is complete (and the connection closed) before
                # inference starts.
                image = Image.open(BytesIO(_fetch_url_bytes(image_url)))
                if model_in_use != 'donut':
                    return _unsupported_model_response(model_in_use)

                result, processing_time = _run_donut(
                    image, shipper_id_str, _file_name_from_url(image_url))
                logger.info(f"Processing time inference: {processing_time:.2f} seconds with model: {settings.model}")
            except Exception as e:
                logger.exception(f"Error processing image URL: {str(e)}")
                return {"error": f"Error processing image URL: {str(e)}"}
    except Exception as e:
        logger.exception(f"Error during inference: {str(e)}")
        return {"error": f"Inference failed: {str(e)}"}

    return result


@router.get("/statistics")
async def get_statistics():
    """Return the contents of the inference statistics file.

    Reads ``settings.inference_stats_file`` as UTF-8 JSON. Returns an empty
    list when the file does not exist or does not hold valid JSON.
    """
    try:
        with open(settings.inference_stats_file, 'r', encoding='utf-8') as stats_file:
            return json.load(stats_file)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError:
        return []