File size: 1,978 Bytes
7d5d5aa 6816a14 2661513 1bdf561 7d5d5aa 2661513 1bdf561 7d5d5aa 2661513 7d5d5aa 6816a14 2661513 6816a14 2661513 6816a14 7d5d5aa 6816a14 2661513 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import io
# ASGI application; the route handlers in this module register on it.
app = FastAPI()

# Jinja2 environment that renders pages from the "templates" directory.
templates = Jinja2Templates(directory="templates")

# Hugging Face checkpoint shared by the processor and the model.
_MODEL_ID = 'tjoab/latex_finetuned'

# Loaded once at import time so every request reuses the same instances
# instead of reloading the weights per upload.
processor = TrOCRProcessor.from_pretrained(_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(_MODEL_ID)
@app.get("/", response_class=HTMLResponse)
async def form_page(request: Request):
    """Serve ``form.html`` for a plain GET, with ``result`` set to None.

    Args:
        request: Incoming request; the template context must include it.

    Returns:
        The rendered ``form.html`` template response.
    """
    template_context = {"request": request, "result": None}
    return templates.TemplateResponse("form.html", template_context)
def _render_form(request: Request, result):
    """Render ``form.html`` with *result* passed to the template as ``result``."""
    return templates.TemplateResponse("form.html", {"request": request, "result": result})


def _predict_latex(image: Image.Image) -> str:
    """Run the OCR model on a prepared image and return the first decoded string.

    This call is synchronous and compute-heavy; callers in async code should
    run it off the event loop.
    """
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # max_length caps the number of tokens generated for a single prediction.
    pred_ids = model.generate(pixel_values, max_length=128)
    return processor.batch_decode(pred_ids, skip_special_tokens=True)[0]


@app.post("/", response_class=HTMLResponse)
async def handle_upload(request: Request, file: UploadFile = File(...)):
    """Run LaTeX OCR on an uploaded PNG/JPEG image and render the prediction.

    Args:
        request: Incoming request; passed through to the template.
        file: Uploaded file from the multipart form field ``file``.

    Returns:
        ``form.html`` rendered with ``result`` set to the decoded LaTeX string,
        or to an error message ("Invalid file type" / "Invalid image file")
        when the upload is rejected.
    """
    # Local import keeps this handler's extra dependency next to its only use.
    from fastapi.concurrency import run_in_threadpool

    # content_type comes from the client, so this is only a first filter;
    # the real validation is whether Pillow can decode the bytes below.
    if file.content_type not in ["image/png", "image/jpeg"]:
        return _render_form(request, "Invalid file type")

    contents = await file.read()
    try:
        # prepare_image's convert() forces a full decode, so truncated or
        # corrupt pixel data is caught here as well as unrecognised headers.
        image = prepare_image(Image.open(io.BytesIO(contents)))
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError (unrecognised format) is a subclass of OSError.
        return _render_form(request, "Invalid image file")

    # model.generate blocks; running it in FastAPI's threadpool keeps the
    # event loop free to serve other requests while inference runs.
    latex = await run_in_threadpool(_predict_latex, image)
    return _render_form(request, latex)
def prepare_image(image: Image.Image) -> Image.Image:
    """Return an RGB version of *image*, flattening transparency onto white.

    Images with an alpha channel (modes ``RGBA`` and ``LA``) and palette images
    that carry a ``transparency`` entry are composited over an opaque white
    background, so transparent regions come out white. Every other mode is
    converted with a plain ``convert("RGB")``.

    Args:
        image: Any image opened by Pillow.

    Returns:
        A new image in mode ``RGB`` with the same size as *image*.
    """
    has_transparency = image.mode in ('RGBA', 'LA') or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if not has_transparency:
        return image.convert('RGB')
    # Image.alpha_composite requires both operands to be RGBA, so the white
    # backdrop is created as fully opaque RGBA rather than RGB.
    background = Image.new('RGBA', image.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, image.convert('RGBA')).convert('RGB')
|