File size: 1,898 Bytes
e50f30c
ec7d971
623c9e7
ec7d971
623c9e7
e50f30c
4a01533
 
ec7d971
 
 
191e2cd
aebafcf
e50f30c
 
aebafcf
e50f30c
ec7d971
aebafcf
ec7d971
aebafcf
e50f30c
adc05de
aebafcf
 
 
 
 
 
 
 
 
adc05de
aebafcf
ec7d971
 
623c9e7
aebafcf
1ee9cdc
ec7d971
 
 
 
e50f30c
ec7d971
623c9e7
 
 
4a01533
aebafcf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import gradio as gr
from PIL import Image

# Pretrained checkpoint identifier handed to both from_pretrained calls below.
model_name = "google/matcha-base"

# Pick CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Pix2Struct weights, place them on the chosen device, and switch the
# module to evaluation mode.
model = (
    Pix2StructForConditionalGeneration.from_pretrained(model_name)
    .to(device)
    .eval()
)

# Matching processor, used to build model inputs and to decode generated ids.
processor = Pix2StructProcessor.from_pretrained(model_name)

def solve_math_problem(image):
    """Run the Pix2Struct model on an uploaded image and return its text output.

    Args:
        image: The image to read, as supplied by the Gradio image input
            (configured with ``type="pil"``), or ``None`` when nothing was
            uploaded.

    Returns:
        The decoded model output with special tokens skipped, or a short
        message asking for an upload when ``image`` is ``None``.
    """
    # The form can be submitted without an image, in which case the callback
    # receives None; answer with a readable message instead of letting the
    # processor raise on a missing image.
    if image is None:
        return "Please upload an image of a math problem."

    # Preprocess the image together with the text prompt.
    inputs = processor(images=image, text="Solve the math problem:", return_tensors="pt")
    # Move every tensor to the device the model lives on.
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Beam search without sampling is deterministic. `temperature` is only
    # honoured when do_sample=True, so it is deliberately not passed here
    # (recent transformers versions warn when it is set but unused).
    with torch.no_grad():
        predictions = model.generate(
            **inputs,
            max_new_tokens=150,  # Increase this if longer answers are needed
            num_beams=5,         # Beam search for more stable outputs
            early_stopping=True,
        )

    # Decode the best sequence to a string, skipping special tokens.
    return processor.decode(predictions[0], skip_special_tokens=True)

# Input and output widgets for the web UI, built up front so the Interface
# call below stays short.
_image_input = gr.Image(type="pil", label="Upload Handwritten Math Problem")
_solution_output = gr.Textbox(label="Solution")

# Wire the solver function to the widgets in a single-function Gradio app.
demo = gr.Interface(
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem and the model will attempt to solve it.",
    theme="soft",
    fn=solve_math_problem,
    inputs=_image_input,
    outputs=_solution_output,
)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()