# grademe / app.py
# vverma
# HTML form
# 7d5d5aa
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import io
# Application instance; the route decorators below register handlers on it.
app = FastAPI()
# Template engine: templates are looked up in the "templates" directory.
templates = Jinja2Templates(directory="templates")
# Load the processor and model once at import time so every request reuses
# the same instances instead of reloading them per upload. Both are loaded
# from the same 'tjoab/latex_finetuned' identifier.
processor = TrOCRProcessor.from_pretrained('tjoab/latex_finetuned')
model = VisionEncoderDecoderModel.from_pretrained('tjoab/latex_finetuned')
@app.get("/", response_class=HTMLResponse)
async def form_page(request: Request):
    """Serve the upload form with no recognition result to show.

    Args:
        request: Incoming request, exposed to the template as "request".

    Returns:
        The "form.html" template response with "result" set to None.
    """
    context = {"request": request, "result": None}
    return templates.TemplateResponse("form.html", context)
@app.post("/", response_class=HTMLResponse)
async def handle_upload(request: Request, file: UploadFile = File(...)):
    """Run LaTeX recognition on an uploaded image and render the result.

    Args:
        request: Incoming request, exposed to the template as "request".
        file: Uploaded file. Only the "image/png" and "image/jpeg" content
            types are accepted.

    Returns:
        The "form.html" template response. "result" holds the first decoded
        prediction on success, or an error message when the upload is
        rejected (wrong content type, or bytes that cannot be decoded as an
        image).
    """
    def render(result):
        # Single place that builds the template response for every outcome.
        return templates.TemplateResponse(
            "form.html", {"request": request, "result": result}
        )

    # The content type is declared by the client, so this is only a first
    # filter; the decode step below is what actually validates the bytes.
    if file.content_type not in ("image/png", "image/jpeg"):
        return render("Invalid file type")

    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open only reads the header; force the full decode here so
        # truncated or corrupt data is reported as a bad upload instead of
        # surfacing later as an unhandled server error.
        image.load()
    except (OSError, Image.DecompressionBombError):
        # OSError covers PIL.UnidentifiedImageError and truncated-file errors.
        return render("Invalid or unreadable image file")

    image = prepare_image(image)
    # NOTE(review): the inference calls below run synchronously inside this
    # async handler, so other requests wait while the model generates.
    # Consider offloading to a worker thread if concurrency matters.
    inputs = processor(images=image, return_tensors="pt").pixel_values
    pred_ids = model.generate(inputs, max_length=128)
    latex_preds = processor.batch_decode(pred_ids, skip_special_tokens=True)
    return render(latex_preds[0])
def prepare_image(image: Image.Image) -> Image.Image:
    """Return an RGB copy of *image*, flattening transparency onto white.

    Images without transparency are simply converted to RGB. Images with an
    alpha channel (modes 'RGBA' and 'LA', or palette images carrying a
    'transparency' entry in ``image.info``) are composited over an opaque
    white background first, so transparent regions become white rather than
    whatever colour happens to sit under the alpha channel.

    Args:
        image: Source image in any mode that can be converted to RGB/RGBA.

    Returns:
        A new image in mode 'RGB' with the same size as the input.
    """
    has_alpha = image.mode in ('RGBA', 'LA') or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if not has_alpha:
        return image.convert('RGB')

    rgba = image.convert('RGBA')
    # alpha_composite requires both operands to be RGBA and of equal size,
    # so the white background is created as RGBA as well.
    background = Image.new('RGBA', rgba.size, 'white')
    return Image.alpha_composite(background, rgba).convert('RGB')