File size: 2,495 Bytes
86e21aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import streamlit as st
from transformers import VisionEncoderDecoderModel, AutoTokenizer
from texteller.models.ocr_model.utils.inference import inference as latex_inference
from texteller.models.ocr_model.utils.to_katex import to_katex
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io

# Page-level setup. Historically Streamlit required set_page_config to be the
# first st.* command the script executes, so it stays at the very top.
APP_TITLE = "TeXTeller Demo – LaTeX Code Prediction from Images"
st.set_page_config(layout="wide")
st.title(APP_TITLE)

# st.cache_resource memoizes the return value, so the heavy model download /
# construction happens once and the same objects are reused across reruns.
@st.cache_resource
def load_model():
    """Load the TeXTeller OCR model and its tokenizer.

    Both objects are built with ``from_pretrained`` from the
    ``OleehyO/TexTeller`` checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple: a ``VisionEncoderDecoderModel`` and
        the matching tokenizer from ``AutoTokenizer``.
    """
    repo_id = "OleehyO/TexTeller"
    return (
        VisionEncoderDecoderModel.from_pretrained(repo_id),
        AutoTokenizer.from_pretrained(repo_id),
    )

model, tokenizer = load_model()

# Utility function to render LaTeX as an image
def latex2image(latex_expression, image_size_in=(3, 0.5), fontsize=16, dpi=200):
    """Render a math-mode LaTeX string to a PIL image with matplotlib mathtext.

    Args:
        latex_expression: LaTeX math *without* surrounding ``$``; it is wrapped
            in ``$...$`` so matplotlib's mathtext parser renders it. Note that
            mathtext supports only a subset of LaTeX.
        image_size_in: Figure size ``(width, height)`` in inches.
        fontsize: Font size of the rendered expression, in points.
        dpi: Figure resolution in dots per inch.

    Returns:
        A PIL ``Image`` backed by an in-memory PNG, cropped tightly around the
        text (``bbox_inches="tight"``) with 0.1 in of padding.

    Raises:
        Whatever matplotlib raises while drawing (presumably ``ValueError`` for
        expressions mathtext cannot parse). The figure is closed either way.
    """
    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=f"${latex_expression}$",
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize
        )
        buf = io.BytesIO()
        # Save *this* figure explicitly. plt.savefig() targets pyplot's global
        # "current figure", which concurrent Streamlit sessions (separate
        # threads) share, so it could save another session's figure.
        fig.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0.1)
    finally:
        # Mathtext errors surface during savefig; always release the figure so
        # pyplot does not keep it alive for the lifetime of the server process.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# Flatten an uploaded image to 3-channel RGB before handing it to the model.
def _to_rgb_on_white(img):
    """Return *img* converted to RGB, compositing any transparency onto white.

    ``np.array`` of an RGBA image has 4 channels and of a grayscale ("L")
    image has none, whereas the TeXTeller inference helper presumably expects
    uint8 RGB arrays of shape (H, W, 3) — TODO confirm against texteller.
    A plain ``convert("RGB")`` just drops alpha and keeps whatever colour lies
    under fully transparent pixels (often black), which would hide a dark
    formula on a transparent background, so alpha is composited onto white.
    """
    has_alpha = img.mode in ("RGBA", "LA") or (
        img.mode == "P" and "transparency" in img.info
    )
    if not has_alpha:
        return img.convert("RGB")
    rgba = img.convert("RGBA")
    background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, rgba).convert("RGB")


# Upload box for the user to provide an input image
uploaded_file = st.file_uploader("Upload a math image (JPG, PNG)...", type=["jpg", "jpeg", "png"])

# If an image is uploaded, process it
if uploaded_file:
    # Display three columns: original image, predicted LaTeX, rendered LaTeX
    col1, col2, col3 = st.columns(3)

    # PIL decodes lazily, so corrupt/truncated uploads fail inside convert();
    # keep both steps under the handler instead of crashing the script.
    try:
        image = Image.open(uploaded_file)
        rgb_image = _to_rgb_on_white(image)
    except (OSError, Image.DecompressionBombError) as e:  # UnidentifiedImageError is an OSError
        st.error(f"Could not read the uploaded image: {e}")
        st.stop()

    with col1:
        st.image(image, caption="Original Image", use_container_width=True)

    # Perform prediction (only the model call runs under the spinner)
    predicted_latex = None
    with st.spinner("Running OCR model..."):
        try:
            res = latex_inference(model, tokenizer, [np.array(rgb_image)])
            predicted_latex = to_katex(res[0])
        except Exception as e:  # UI boundary: report any model failure to the user
            st.error(f"Error during prediction: {e}")

    if predicted_latex is not None:
        # Show the predicted LaTeX string. A non-empty label avoids Streamlit's
        # empty-label accessibility warning; it is hidden because the markdown
        # header above already labels the box.
        with col2:
            st.markdown("**Predicted LaTeX code:**")
            st.text_area(
                label="Predicted LaTeX code",
                value=predicted_latex,
                height=80,
                label_visibility="collapsed",
            )

        # Convert LaTeX string to an image and display. Mathtext supports only
        # a subset of LaTeX, so rendering can fail even when the prediction
        # succeeded; report that separately instead of as a prediction error.
        with col3:
            try:
                pred_image = latex2image(predicted_latex)
            except Exception as e:
                st.warning(f"Could not render the predicted LaTeX: {e}")
            else:
                st.image(pred_image, caption="Rendered from Prediction", use_container_width=True)