# Hugging Face Space app file (file size: 2,495 bytes; revision 86e21aa).
import streamlit as st
from transformers import VisionEncoderDecoderModel, AutoTokenizer
from texteller.models.ocr_model.utils.inference import inference as latex_inference
from texteller.models.ocr_model.utils.to_katex import to_katex
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io
# Page chrome: apply the wide layout first, then draw the demo heading.
_PAGE_LAYOUT = "wide"
_PAGE_TITLE = "TeXTeller Demo – LaTeX Code Prediction from Images"

st.set_page_config(layout=_PAGE_LAYOUT)
st.title(_PAGE_TITLE)
# The model and tokenizer are expensive to build, so the loader is wrapped in
# st.cache_resource and its result is reused instead of being rebuilt.
@st.cache_resource
def load_model():
    """Load the TexTeller checkpoint and its tokenizer.

    Both objects are created with ``from_pretrained`` from the same
    checkpoint identifier.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``VisionEncoderDecoderModel``
        first, the ``AutoTokenizer`` second.
    """
    repo_id = "OleehyO/TexTeller"
    return (
        VisionEncoderDecoderModel.from_pretrained(repo_id),
        AutoTokenizer.from_pretrained(repo_id),
    )


# Module-level handles used by the prediction step further down.
model, tokenizer = load_model()
# Utility function to render LaTeX as an image
def latex2image(latex_expression, image_size_in=(3, 0.5), fontsize=16, dpi=200):
    """Render a LaTeX math expression to a PIL image via matplotlib.

    The expression is wrapped in ``$...$`` and drawn centred on an otherwise
    empty figure, which is then saved as PNG into an in-memory buffer.

    Args:
        latex_expression: Math-mode LaTeX source, without surrounding ``$``.
        image_size_in: Figure size ``(width, height)`` in inches. The saved
            image is cropped tightly around the text, so this is only the
            canvas the text is laid out on.
        fontsize: Font size passed to ``Figure.text``.
        dpi: Resolution of the figure.

    Returns:
        A ``PIL.Image.Image`` opened on the PNG bytes.

    Raises:
        Any exception matplotlib raises while drawing or saving (for example
        when it cannot handle the expression) propagates to the caller.
    """
    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=f"${latex_expression}$",
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", so the output cannot come from another figure.
        fig.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0.1)
    finally:
        # Always release the figure, even when drawing/saving fails;
        # otherwise every failed render leaves a figure registered in pyplot.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Upload box for the user to provide an input image
uploaded_file = st.file_uploader(
    "Upload a math image (JPG, PNG)...", type=["jpg", "jpeg", "png"]
)

# Nothing below runs until a file has been provided.
if uploaded_file is not None:
    # Three columns: original image, predicted LaTeX, rendered LaTeX
    col1, col2, col3 = st.columns(3)

    # Load image using PIL
    image = Image.open(uploaded_file)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)

    # Step 1: run the OCR model. Stays None if prediction fails, so the
    # remaining columns are skipped and only the error is shown.
    predicted_latex = None
    with st.spinner("Running OCR model..."):
        try:
            # Uploaded PNGs may be RGBA, palette or grayscale; normalise to a
            # three-channel RGB array before handing it to the model.
            # NOTE(review): assumes the inference helper expects H x W x 3
            # RGB input -- confirm against the TexTeller version in use.
            rgb_pixels = np.array(image.convert("RGB"))
            res = latex_inference(model, tokenizer, [rgb_pixels])
            predicted_latex = to_katex(res[0])
        except Exception as e:
            st.error(f"Error during prediction: {e}")

    if predicted_latex is not None:
        # Step 2: show the predicted LaTeX string
        with col2:
            st.markdown("**Predicted LaTeX code:**")
            # The markdown line above is the visible heading; the widget
            # still gets a real label, hidden, instead of an empty one.
            st.text_area(
                label="Predicted LaTeX code",
                value=predicted_latex,
                height=80,
                label_visibility="collapsed",
            )

        # Step 3: render the prediction. Handled separately from prediction
        # so a rendering failure is not reported as a prediction error.
        with col3:
            try:
                pred_image = latex2image(predicted_latex)
            except Exception as e:
                st.error(f"Could not render the predicted LaTeX: {e}")
            else:
                st.image(
                    pred_image,
                    caption="Rendered from Prediction",
                    use_container_width=True,
                )