|
from fastapi import FastAPI, File, UploadFile, Form, Request |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.templating import Jinja2Templates |
|
from fastapi.staticfiles import StaticFiles |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
from PIL import Image |
|
import io |
|
|
|
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Template loader pointed at the "templates" directory.
templates = Jinja2Templates(directory="templates")

# Single source of truth for the checkpoint so the processor and the model
# can never be loaded from two different ids by accident.
MODEL_ID = 'tjoab/latex_finetuned'

# Both objects are created once, when this module is imported, and are then
# reused by every request handler.
processor = TrOCRProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
|
|
|
@app.get("/", response_class=HTMLResponse)
async def form_page(request: Request):
    """Render the upload form with an empty result slot.

    Args:
        request: Incoming request, passed through to the template context.

    Returns:
        The rendered ``form.html`` template response.
    """
    context = {"request": request, "result": None}
    return templates.TemplateResponse("form.html", context)
|
|
|
def _predict_latex(image: Image.Image) -> str:
    """Run the OCR model on one prepared image and return the decoded text.

    This performs blocking model inference, so it is meant to be called from
    a worker thread rather than directly on the event loop.
    """
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    pred_ids = model.generate(pixel_values, max_length=128)
    # A single image goes in, so the first decoded string is the prediction.
    return processor.batch_decode(pred_ids, skip_special_tokens=True)[0]


@app.post("/", response_class=HTMLResponse)
async def handle_upload(request: Request, file: UploadFile = File(...)):
    """Handle an uploaded image and render the form with the recognised LaTeX.

    Args:
        request: Incoming request, passed through to the template context.
        file: Uploaded file; only ``image/png`` and ``image/jpeg`` content
            types are accepted.

    Returns:
        The rendered ``form.html`` response whose ``result`` is either the
        predicted LaTeX string or a short error message.
    """
    # Local import: the module-level imports do not bring this name into scope.
    from fastapi.concurrency import run_in_threadpool

    def render(result):
        return templates.TemplateResponse("form.html", {"request": request, "result": result})

    if file.content_type not in ("image/png", "image/jpeg"):
        return render("Invalid file type")

    contents = await file.read()

    # The declared content type is client-supplied, so the payload may still
    # be empty, truncated or not an image at all. Pillow reports those cases
    # as OSError (UnidentifiedImageError is a subclass). load() forces the
    # full decode now, because Image.open() alone only reads the header.
    try:
        image = Image.open(io.BytesIO(contents))
        image.load()
    except OSError:
        return render("Could not read the uploaded image")

    image = prepare_image(image)

    # Inference is blocking; running it in a worker thread keeps the event
    # loop free to serve other requests in the meantime.
    latex = await run_in_threadpool(_predict_latex, image)
    return render(latex)
|
|
|
def prepare_image(image: Image.Image) -> Image.Image:
    """Return an RGB version of *image*, flattening transparency onto white.

    Images without transparency are simply converted to RGB. Images that
    carry an alpha channel (modes ``RGBA`` and ``LA``, or palette images with
    a ``transparency`` entry) are composited over an opaque white background
    first, so transparent regions become white.

    Args:
        image: Source image in any mode Pillow can convert.

    Returns:
        A new image in mode ``RGB`` with the same size as the input.
    """
    has_alpha = image.mode in ('RGBA', 'LA') or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if not has_alpha:
        return image.convert('RGB')

    rgba = image.convert('RGBA')
    # Image.alpha_composite requires BOTH operands to be RGBA, so the white
    # backdrop is created as RGBA and the result is reduced to RGB afterwards.
    background = Image.new('RGBA', rgba.size, 'white')
    return Image.alpha_composite(background, rgba).convert('RGB')
|
|