texteller-demo / app.py
bpdev75's picture
Create app.py
86e21aa verified
raw
history blame
2.5 kB
import streamlit as st
from transformers import VisionEncoderDecoderModel, AutoTokenizer
from texteller.models.ocr_model.utils.inference import inference as latex_inference
from texteller.models.ocr_model.utils.to_katex import to_katex
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io
# Page-level settings. The wide layout gives the three result columns
# created further down more horizontal room.
PAGE_LAYOUT = "wide"
APP_TITLE = "TeXTeller Demo – LaTeX Code Prediction from Images"

st.set_page_config(layout=PAGE_LAYOUT)
st.title(APP_TITLE)
@st.cache_resource
def load_model():
    """Load the TeXTeller OCR model and its tokenizer.

    Decorated with ``st.cache_resource`` so the checkpoint is loaded once and
    reused across Streamlit reruns instead of on every interaction.

    Returns:
        A ``(model, tokenizer)`` tuple built from the ``OleehyO/TexTeller``
        checkpoint.
    """
    repo_id = "OleehyO/TexTeller"
    # Tuple elements evaluate left to right: model first, then tokenizer.
    return (
        VisionEncoderDecoderModel.from_pretrained(repo_id),
        AutoTokenizer.from_pretrained(repo_id),
    )


# Module-level handles used by the prediction code below.
model, tokenizer = load_model()
def latex2image(latex_expression, image_size_in=(3, 0.5), fontsize=16, dpi=200):
    """Render a math-mode LaTeX expression to a PIL image via matplotlib mathtext.

    Args:
        latex_expression: LaTeX without surrounding ``$`` delimiters; it is
            wrapped in ``$...$`` before rendering.
        image_size_in: Figure size in inches as ``(width, height)``.
        fontsize: Font size of the rendered expression.
        dpi: Figure resolution in dots per inch.

    Returns:
        A ``PIL.Image.Image`` decoded from an in-memory PNG.

    Raises:
        Whatever matplotlib raises when it cannot parse or draw the expression
        (mathtext implements only a subset of TeX). The figure is closed
        before the exception propagates.
    """
    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=f"${latex_expression}$",
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        buf = io.BytesIO()
        # Save *this* figure explicitly. plt.savefig() targets pyplot's global
        # "current figure", which is not guaranteed to be `fig`. Mathtext
        # parse errors surface here, during drawing.
        fig.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0.1)
    finally:
        # Always release the figure, even when rendering fails, so failed
        # renders don't accumulate open figures in the long-running server.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def _flatten_to_rgb(img):
    """Return *img* as a 3-channel RGB image, compositing transparency onto white.

    Uploads are often RGBA, palette ("P") or grayscale PNGs, which
    ``np.array`` would turn into 4- or 1-channel arrays. The OCR model is
    assumed to expect RGB input (NOTE(review): confirm against TexTeller's
    ``inference``). Transparent pixels are flattened onto white rather than
    dropped, because their hidden RGB values are often black and would hide
    dark formula strokes.
    """
    if "A" in img.mode or "transparency" in img.info:
        rgba = img.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        return Image.alpha_composite(background, rgba).convert("RGB")
    return img.convert("RGB")


# Upload box for the user to provide an input image
uploaded_file = st.file_uploader("Upload a math image (JPG, PNG)...", type=["jpg", "jpeg", "png"])

# If an image is uploaded, process it
if uploaded_file:
    # Display three columns: original image, predicted LaTeX, rendered LaTeX
    col1, col2, col3 = st.columns(3)

    # Decode the upload. A file with an allowed extension can still be corrupt,
    # truncated or oversized. Report that instead of showing a raw traceback.
    # (UnidentifiedImageError is a subclass of OSError.)
    try:
        image = _flatten_to_rgb(Image.open(uploaded_file))
    except (OSError, Image.DecompressionBombError) as e:
        image = None
        st.error(f"Could not read the uploaded image: {e}")

    if image is not None:
        with col1:
            st.image(image, caption="Original Image", use_container_width=True)

        # Perform prediction. Only the model call sits in this try, so
        # rendering problems below are not mislabeled as prediction errors.
        predicted_latex = None
        with st.spinner("Running OCR model..."):
            try:
                res = latex_inference(model, tokenizer, [np.array(image)])
                predicted_latex = to_katex(res[0])
            except Exception as e:
                # UI boundary: show any model failure to the user.
                st.error(f"Error during prediction: {e}")

        if predicted_latex is not None:
            # Show the predicted LaTeX string
            with col2:
                st.markdown("**Predicted LaTeX code:**")
                # A non-empty label avoids Streamlit's accessibility warning.
                # It is hidden because the markdown heading above serves as
                # the visible label.
                st.text_area(
                    label="Predicted LaTeX code",
                    value=predicted_latex,
                    height=80,
                    label_visibility="collapsed",
                )

            # Convert LaTeX string to an image and display. Matplotlib mathtext
            # supports only part of KaTeX, so a render failure is a warning:
            # the predicted code in col2 stays usable.
            with col3:
                try:
                    pred_image = latex2image(predicted_latex)
                except Exception as e:
                    st.warning(f"Could not render the predicted LaTeX: {e}")
                else:
                    st.image(pred_image, caption="Rendered from Prediction", use_container_width=True)