File size: 1,978 Bytes
7d5d5aa
 
 
 
6816a14
 
2661513
1bdf561
 
 
7d5d5aa
 
 
 
2661513
 
1bdf561
7d5d5aa
 
 
 
 
 
2661513
7d5d5aa
6816a14
2661513
 
 
6816a14
2661513
 
6816a14
 
7d5d5aa
 
 
 
 
 
 
6816a14
2661513
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import io

from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

app = FastAPI()

# HTML templates are loaded from the relative "templates" directory.
templates = Jinja2Templates(directory="templates")

# Hugging Face checkpoint for the LaTeX-finetuned TrOCR model. The processor
# and model are created once at import time and reused by every request.
MODEL_CHECKPOINT = "tjoab/latex_finetuned"
processor = TrOCRProcessor.from_pretrained(MODEL_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT)

@app.get("/", response_class=HTMLResponse)
async def form_page(request: Request):
    """Render the upload form with no OCR result yet."""
    context = {"request": request, "result": None}
    return templates.TemplateResponse("form.html", context)

@app.post("/", response_class=HTMLResponse)
async def handle_upload(request: Request, file: UploadFile = File(...)):
    """Transcribe an uploaded PNG/JPEG image to LaTeX and render the result.

    The client-declared content type is checked first. Because it is only a
    hint, the bytes are also decoded defensively: payloads Pillow cannot read
    are reported in the form instead of surfacing as a server error.
    """
    if file.content_type not in ["image/png", "image/jpeg"]:
        return _render_form(request, "Invalid file type")

    contents = await file.read()
    try:
        # Image.open is lazy (it only identifies the file). The convert() calls
        # inside prepare_image force the full decode, so both must sit inside
        # the try to catch corrupt or truncated payloads.
        with Image.open(io.BytesIO(contents)) as raw:
            image = prepare_image(raw)
    except (OSError, Image.DecompressionBombError):
        # OSError also covers PIL.UnidentifiedImageError and truncated data.
        return _render_form(request, "Invalid image file")

    # Inference is synchronous and compute-heavy. Running it in the threadpool
    # keeps the event loop free to serve other requests in the meantime.
    latex = await run_in_threadpool(_predict_latex, image)
    return _render_form(request, latex)


def _render_form(request: Request, result):
    """Render form.html with the given result value."""
    return templates.TemplateResponse("form.html", {"request": request, "result": result})


def _predict_latex(image: Image.Image) -> str:
    """Return the model's LaTeX transcription of a single RGB image."""
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    pred_ids = model.generate(pixel_values, max_length=128)
    # batch_decode returns one string per sequence; there is exactly one here.
    return processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

def prepare_image(image: Image.Image) -> Image.Image:
    """Return *image* as RGB, flattening any transparency onto white.

    Two kinds of image are converted to RGBA and alpha-composited over an
    opaque white background, so transparent pixels end up white instead of
    whatever colour is stored underneath them:

    - images with an alpha band (e.g. RGBA, LA, PA);
    - L/RGB/P images carrying a ``transparency`` entry in ``image.info``.

    Every other image is simply converted to RGB.
    """
    has_alpha = "A" in image.getbands() or (
        image.mode in ("L", "RGB", "P") and "transparency" in image.info
    )
    if not has_alpha:
        return image.convert("RGB")

    rgba = image.convert("RGBA")
    # Image.alpha_composite requires both operands to be RGBA and equal size.
    background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, rgba).convert("RGB")