File size: 4,466 Bytes
30875d3
c2d58b3
 
 
 
f1483b9
c2d58b3
 
 
 
948d2eb
 
 
 
 
 
c2d58b3
 
 
 
 
f1483b9
 
 
 
c2d58b3
f1483b9
 
 
 
c2d58b3
 
 
 
30875d3
 
 
 
 
7894c3d
30875d3
 
 
 
 
 
 
 
 
 
f1483b9
1002c61
30875d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1483b9
 
c2d58b3
 
 
 
 
 
 
 
 
 
 
 
30875d3
f1483b9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from fastapi import APIRouter, File, UploadFile, Form, HTTPException
from typing import Optional
from PIL import Image
import urllib.request
from io import BytesIO
from config import settings
import utils
import os
import json
from routers.donut_inference import process_document_donut
import logging
import io

# Set up logging
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process — presumably intentional for this service entry point; confirm
# if this module is ever imported as a library.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Router mounted by the main FastAPI app; exposes /inference and /statistics.
router = APIRouter()

def count_values(obj):
    """Recursively count the leaf (non-container) values in a nested structure.

    Dicts and lists are traversed; any other value (str, int, None, ...)
    counts as a single leaf. Used to report how many fields an inference
    result extracted when logging stats.

    :param obj: arbitrarily nested dict/list structure (e.g. parsed JSON)
    :return: number of leaf values in ``obj``
    """
    if isinstance(obj, dict):
        return sum(count_values(value) for value in obj.values())
    if isinstance(obj, list):
        return sum(count_values(item) for item in obj)
    # Scalar leaf — one extracted value.
    return 1

def _infer_with_donut(image, shipper_id_str, file_name):
    """Run Donut inference on *image*, log stats for *file_name*, return the result."""
    result, processing_time = process_document_donut(image, shipper_id_str)
    utils.log_stats(
        settings.inference_stats_file,
        [processing_time, count_values(result), file_name, settings.model],
    )
    logger.info(f"Processing time: {processing_time:.2f} seconds with model: {settings.model}")
    return result


@router.post("/inference")
async def run_inference(
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    model_in_use: str = Form('donut'),
    shipper_id: Optional[int] = Form(None)
):
    """Run document inference on an uploaded JPG or an image URL.

    Exactly one of ``file`` / ``image_url`` should be supplied; if both are
    given, the uploaded file wins. ``shipper_id`` selects which model variant
    to use (falls back to ``"default_shipper"`` when omitted).

    Returns the model's extraction result on success, or a dict with an
    ``"error"`` / ``"info"`` key on failure (HTTP status stays 200 — kept for
    backward compatibility with existing callers).
    """
    # Validate input
    if not file and not image_url:
        return {"info": "No input provided"}

    # Log the shipper_id that was received
    logger.info(f"Received inference request with shipper_id: {shipper_id}")

    # Convert shipper_id to string if provided (config.py expects a string)
    shipper_id_str = str(shipper_id) if shipper_id is not None else "default_shipper"
    logger.info(f"Using shipper_id: {shipper_id_str} for model selection")

    try:
        if file:
            # Ensure the uploaded file is a JPG image
            if file.content_type not in ["image/jpeg", "image/jpg"]:
                logger.warning(f"Invalid file type: {file.content_type}")
                return {"error": "Invalid file type. Only JPG images are allowed."}

            if model_in_use != 'donut':
                logger.warning(f"Unsupported model: {model_in_use}")
                return {"error": f"Unsupported model: {model_in_use}"}

            logger.info(f"Processing file: {file.filename}")
            image = Image.open(BytesIO(await file.read()))
            return _infer_with_donut(image, shipper_id_str, file.filename)

        # image_url branch
        logger.info(f"Processing image from URL: {image_url}")
        # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
        try:
            if model_in_use != 'donut':
                logger.warning(f"Unsupported model: {model_in_use}")
                return {"error": f"Unsupported model: {model_in_use}"}

            # Timeout added so a dead/slow remote host cannot hang the worker.
            with urllib.request.urlopen(image_url, timeout=30) as url:
                image = Image.open(BytesIO(url.read()))

            # parse file name from url
            file_name = image_url.split("/")[-1]
            return _infer_with_donut(image, shipper_id_str, file_name)
        except Exception as e:
            logger.error(f"Error processing image URL: {str(e)}")
            return {"error": f"Error processing image URL: {str(e)}"}

    except Exception as e:
        logger.error(f"Error during inference: {str(e)}")
        return {"error": f"Inference failed: {str(e)}"}

@router.get("/statistics")
async def get_statistics():
    """Return the logged inference statistics as parsed JSON.

    Reads ``settings.inference_stats_file``; returns an empty list when the
    file is missing or does not contain valid JSON.
    """
    file_path = settings.inference_stats_file
    try:
        # EAFP: opening directly avoids the exists()/open() TOCTOU race
        # present in a separate os.path.exists() check.
        with open(file_path, 'r') as stats_file:
            return json.load(stats_file)
    except (FileNotFoundError, json.JSONDecodeError):
        return []