# Spaces: Running
import streamlit as st | |
from transformers import VisionEncoderDecoderModel, AutoTokenizer | |
from texteller.models.ocr_model.utils.inference import inference as latex_inference | |
from texteller.models.ocr_model.utils.to_katex import to_katex | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import io | |
# Page chrome: wide layout first (set_page_config is issued before any
# other Streamlit call in this script), then the demo heading.
PAGE_TITLE = "TeXTeller Demo – LaTeX Code Prediction from Images"

st.set_page_config(layout="wide")
st.title(PAGE_TITLE)
# Load the TeXTeller model and tokenizer only once
@st.cache_resource
def load_model():
    """Load the TeXTeller OCR model and its tokenizer.

    Streamlit re-executes this script on every user interaction, so the
    loader is wrapped in ``st.cache_resource``: the checkpoint is loaded on
    the first run and the same objects are handed back on later reruns
    instead of being re-created each time.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` comes from
        ``VisionEncoderDecoderModel.from_pretrained`` and ``tokenizer`` from
        ``AutoTokenizer.from_pretrained``, both for the
        ``"OleehyO/TexTeller"`` checkpoint.
    """
    checkpoint = "OleehyO/TexTeller"
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


model, tokenizer = load_model()
# Utility function to render LaTeX as an image
def latex2image(latex_expression, image_size_in=(3, 0.5), fontsize=16, dpi=200):
    """Render a LaTeX math expression to a PIL image via matplotlib.

    The expression is wrapped in ``$...$`` and drawn centred on an otherwise
    empty figure, which is then saved as a tightly-cropped PNG into an
    in-memory buffer.

    Args:
        latex_expression: Math expression source, without the surrounding
            ``$`` delimiters.
        image_size_in: Figure size ``(width, height)`` in inches.
        fontsize: Font size used for the rendered expression.
        dpi: Resolution of the figure.

    Returns:
        PIL.Image.Image: The rendered expression, opened from the PNG buffer.

    Raises:
        Exception: Whatever matplotlib raises while drawing or saving the
            figure (for example when it cannot parse the expression) is
            propagated to the caller.
    """
    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=f"${latex_expression}$",
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        buf = io.BytesIO()
        # Save this specific figure rather than pyplot's "current" figure.
        fig.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0.1)
    finally:
        # Always release the figure, even when rendering fails; otherwise
        # every failed render would leave a figure registered with pyplot.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Upload box for the user to provide an input image
uploaded_file = st.file_uploader("Upload a math image (JPG, PNG)...", type=["jpg", "jpeg", "png"])

# If an image is uploaded, process it
if uploaded_file:
    # Display three columns: original image, predicted LaTeX, rendered LaTeX
    col1, col2, col3 = st.columns(3)

    # Load image using PIL
    image = Image.open(uploaded_file)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)

    # Perform prediction
    with st.spinner("Running OCR model..."):
        try:
            # Uploads may be RGBA, palette or grayscale; normalise them so the
            # array handed to the model always has shape (H, W, 3).
            # Transparent images are composited onto white first, because a
            # plain convert("RGB") just drops alpha and can turn a transparent
            # background into the same colour as the formula.
            if image.mode in ("RGBA", "LA") or "transparency" in image.info:
                rgba = image.convert("RGBA")
                model_input = Image.new("RGB", rgba.size, (255, 255, 255))
                model_input.paste(rgba, mask=rgba.getchannel("A"))
            else:
                model_input = image.convert("RGB")

            res = latex_inference(model, tokenizer, [np.array(model_input)])
            predicted_latex = to_katex(res[0])

            # Show the predicted LaTeX string
            with col2:
                st.markdown("**Predicted LaTeX code:**")
                # A non-empty label keeps the widget accessible; it is hidden
                # because the markdown heading above already titles the box.
                st.text_area(
                    label="Predicted LaTeX code",
                    value=predicted_latex,
                    height=80,
                    label_visibility="collapsed",
                )

            # Convert LaTeX string to an image and display
            with col3:
                pred_image = latex2image(predicted_latex)
                st.image(pred_image, caption="Rendered from Prediction", use_container_width=True)
        except Exception as e:
            # Top-level UI boundary: report any failure in the page instead of
            # letting the script crash with a traceback.
            st.error(f"Error during prediction: {e}")