# Hugging Face Spaces status: Runtime error
import logging
import os

import gradio as gr
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Hub checkpoint used by the app. It is public; a private checkpoint would
# require authenticating with the Hub before loading.
model_name = "google/pix2struct-textcaps-base"

# Both objects are created once, when the module is imported, and are then
# reused by solve_math_problem for every request.
processor = Pix2StructProcessor.from_pretrained(model_name)
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
def solve_math_problem(image):
    """Run the Pix2Struct model on an uploaded image and format its output.

    The image is converted to patches by the module-level ``processor``. A
    fixed text prompt is given to ``model`` as the beginning of the text to
    generate (Pix2Struct conditional captioning). The decoded continuation is
    then returned for display.

    Args:
        image: A PIL image from the Gradio ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        ``"Problem: <prompt>\\nSolution: <generated text>"`` on success, or a
        human-readable message when there is no image or processing fails.
        Errors are reported through the return value and never raised, so
        the UI always has text to show.
    """
    if image is None:
        return "Please upload an image of a math problem."

    prompt = "Solve the following math problem:"
    try:
        # Ensure the image is in RGB format.
        image = image.convert("RGB")
        # For non-VQA checkpoints (such as textcaps), Pix2StructProcessor
        # encodes `text` as decoder_input_ids, a prefix for the generated
        # text. It is NOT exposed as `input_ids`. add_special_tokens=False
        # follows the Transformers conditional-captioning example, so no
        # end-of-sequence token is appended to that prefix. header_text is
        # omitted because it only applies to VQA checkpoints.
        inputs = processor(
            images=[image],
            text=prompt,
            return_tensors="pt",
            max_patches=2048,
            add_special_tokens=False,
        )
        # Beam search without do_sample is deterministic, so sampling options
        # such as `temperature` would be ignored; only beam settings are set.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )
        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    except Exception as e:
        # UI boundary: keep the traceback in the logs, show a short message.
        logging.getLogger(__name__).exception("Failed to process image")
        return f"Error processing image: {str(e)}"

    # The generated sequence begins with the decoder prompt. Strip it so the
    # "Solution" line contains only the newly generated text.
    solution = solution.strip()
    if solution.startswith(prompt):
        solution = solution[len(prompt):].strip()
    return f"Problem: {prompt}\nSolution: {solution}"
# Example inputs expected next to the app. Gradio reads example files from
# disk, and a missing one can make startup fail, so only files that exist are
# offered. If none exist, the interface is built without examples.
_EXAMPLE_FILES = ["example_addition.png", "example_algebra.jpg"]
_examples = [[path] for path in _EXAMPLE_FILES if os.path.isfile(path)]

# Set up the Gradio interface: one image in, one copyable text box out.
demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(
        type="pil",
        label="Upload Handwritten Math Problem",
        image_mode="RGB",  # This forces the input to be RGB.
    ),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
    examples=_examples or None,
    theme="soft",
    allow_flagging="never",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()