"""FastAPI app: upload an image and render the model's decoded text (LaTeX) in a form."""

import io

from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image

app = FastAPI()

# Setup template engine
templates = Jinja2Templates(directory="templates")

# Load model and processor once, at import time, so requests reuse them.
processor = TrOCRProcessor.from_pretrained('tjoab/latex_finetuned')
model = VisionEncoderDecoderModel.from_pretrained('tjoab/latex_finetuned')

# Upload content types the POST handler accepts.
ALLOWED_CONTENT_TYPES = frozenset({"image/png", "image/jpeg"})


def _render(request: Request, result):
    """Render ``form.html`` with ``result`` (``None`` shows the empty form)."""
    return templates.TemplateResponse(
        "form.html", {"request": request, "result": result}
    )


def prepare_image(image: Image.Image) -> Image.Image:
    """Return ``image`` as an RGB image, flattening transparency onto white.

    Args:
        image: Any PIL image.

    Returns:
        A new image in mode ``RGB``. Images carrying transparency (modes
        ``RGBA``/``LA``, or palette images with a ``transparency`` entry)
        are composited over a white background; all others are converted
        directly.
    """
    has_alpha = image.mode in ('RGBA', 'LA') or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if not has_alpha:
        return image.convert('RGB')

    rgba = image.convert('RGBA')
    # alpha_composite needs both operands in RGBA mode, so the white
    # background is created as RGBA and the result is reduced to RGB after.
    background = Image.new('RGBA', rgba.size, 'white')
    return Image.alpha_composite(background, rgba).convert('RGB')


@app.get("/", response_class=HTMLResponse)
async def form_page(request: Request):
    """Serve the empty upload form."""
    return _render(request, None)


@app.post("/", response_class=HTMLResponse)
async def handle_upload(request: Request, file: UploadFile = File(...)):
    """Run the model on an uploaded PNG/JPEG and render the decoded text.

    Args:
        request: Incoming request, passed through to the template.
        file: The uploaded file; only ``image/png`` and ``image/jpeg``
            content types are accepted.

    Returns:
        The rendered ``form.html`` with ``result`` set to the first decoded
        prediction, or to an error message when the upload is rejected.
    """
    if file.content_type not in ALLOWED_CONTENT_TYPES:
        return _render(request, "Invalid file type")

    contents = await file.read()
    try:
        # Image.open is lazy; the conversion inside prepare_image forces the
        # actual decode, so both are guarded together. PIL's
        # UnidentifiedImageError is an OSError subclass.
        image = prepare_image(Image.open(io.BytesIO(contents)))
    except OSError:
        return _render(request, "Could not read image")

    # NOTE(review): generate() is called synchronously inside this async
    # handler, so it occupies the event loop for the duration of inference.
    inputs = processor(images=image, return_tensors="pt").pixel_values
    pred_ids = model.generate(inputs, max_length=128)
    latex_preds = processor.batch_decode(pred_ids, skip_special_tokens=True)
    return _render(request, latex_preds[0])