File size: 2,272 Bytes
7fe122b
3dc1b9e
 
 
7fe122b
 
 
 
a1c289f
3dc1b9e
 
7fe122b
a1c289f
 
7fe122b
 
 
 
 
 
 
3dc1b9e
7fe122b
 
 
 
 
 
 
 
3dc1b9e
7fe122b
 
 
 
 
 
 
 
 
 
 
3dc1b9e
7fe122b
3dc1b9e
 
7fe122b
3dc1b9e
7fe122b
3dc1b9e
 
 
 
 
7fe122b
3dc1b9e
7fe122b
3dc1b9e
7fe122b
3dc1b9e
11d8425
3dc1b9e
 
7fe122b
a1c289f
3dc1b9e
 
 
5093ea9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import logging
import os

import gradio as gr
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Load the tokenizer/image processor and the model weights for a public
# checkpoint. A private checkpoint would additionally need Hugging Face
# authentication before from_pretrained can fetch it.
model_name = "google/pix2struct-textcaps-base"
processor = Pix2StructProcessor.from_pretrained(model_name)
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)

def solve_math_problem(image):
    """Run the Pix2Struct model on an uploaded image and format the result.

    NOTE(review): the loaded checkpoint is a Pix2Struct textcaps model; confirm
    it is actually suited to solving math problems rather than captioning.

    Args:
        image: The uploaded image (a PIL image when called from the Gradio
            interface, which uses ``type="pil"``), or ``None`` when the user
            submits without uploading anything.

    Returns:
        ``"Problem: <prompt>\\nSolution: <generated text>"`` on success, or
        ``"Error processing image: <reason>"`` on failure. Errors are returned
        rather than raised so the UI can display them.
    """
    if image is None:
        # Gradio passes None for an empty input; report it clearly instead of
        # surfacing an AttributeError from ``None.convert``.
        return "Error processing image: no image was provided"

    prompt = "Solve the following math problem:"
    try:
        # Ensure the image is in RGB format.
        image = image.convert("RGB")

        # For non-VQA checkpoints the processor tokenizes ``text`` as a decoder
        # prompt (``decoder_input_ids``), not as encoder ``input_ids``.
        # add_special_tokens=False keeps EOS from being appended to the prompt,
        # which would otherwise mark the text as finished before generation.
        inputs = processor(
            images=[image],
            text=prompt,
            return_tensors="pt",
            max_patches=2048,
            add_special_tokens=False,
        )

        # Deterministic beam search. ``temperature`` only affects sampling
        # (do_sample=True), so it is not passed.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )

        # The text prompt the model was conditioned on. Fall back to the raw
        # prompt if the processor did not produce decoder ids (e.g. a VQA
        # checkpoint, where the text is rendered into the image instead).
        prompt_ids = inputs.get("decoder_input_ids")
        if prompt_ids is not None:
            problem_text = processor.decode(
                prompt_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        else:
            problem_text = prompt

        generated = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # generate() returns the decoder prompt followed by its continuation;
        # drop the echoed prompt so only newly generated text is reported.
        if generated.startswith(problem_text):
            generated = generated[len(problem_text):]
        solution = generated.strip()

        return f"Problem: {problem_text}\nSolution: {solution}"

    except Exception as e:
        # UI boundary: keep the full traceback in the logs for debugging, but
        # hand the user a readable message instead of crashing the handler.
        logging.getLogger(__name__).exception("Failed to process image")
        return f"Error processing image: {str(e)}"

# Candidate example images, resolved relative to the working directory as
# Gradio does. Only files that exist are offered: Gradio may try to load
# example files while building the interface, so a missing one can prevent
# the app from starting.
_EXAMPLE_FILES = ["example_addition.png", "example_algebra.jpg"]
_available_examples = [[path] for path in _EXAMPLE_FILES if os.path.isfile(path)]

# Set up the Gradio interface.
demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(
        type="pil",
        label="Upload Handwritten Math Problem",
        image_mode="RGB",  # This forces the input to be RGB.
    ),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
    # None (no examples section) when none of the example files are present.
    examples=_available_examples or None,
    theme="soft",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()