Nitin00043 commited on
Commit
3cf4149
·
verified ·
1 Parent(s): e6458e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -55
app.py CHANGED
@@ -1,69 +1,42 @@
 
1
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
2
  import gradio as gr
3
  from PIL import Image
4
 
5
- # Use a public model identifier. If you need a private model, remember to authenticate.
6
- model_name = "google/pix2struct-textcaps-base"
7
  model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
8
  processor = Pix2StructProcessor.from_pretrained(model_name)
9
 
10
- def solve_math_problem(image):
11
  try:
12
- # Ensure the image is in RGB format.
13
- image = image.convert("RGB")
14
-
15
- # Preprocess the image and text. Note that header_text is omitted as it's not used for non-VQA tasks.
16
- inputs = processor(
17
- images=[image],
18
- text="Solve the following math problem:",
19
- return_tensors="pt",
20
- max_patches=2048
21
- )
22
-
23
- # Generate the solution with generation parameters.
24
- predictions = model.generate(
25
- **inputs,
26
- max_new_tokens=200,
27
- early_stopping=True,
28
- num_beams=4,
29
- temperature=0.2
30
- )
31
-
32
- # Decode the problem text and generated solution.
33
- problem_text = processor.decode(
34
- inputs["input_ids"][0],
35
- skip_special_tokens=True,
36
- clean_up_tokenization_spaces=True
37
- )
38
- solution = processor.decode(
39
- predictions[0],
40
- skip_special_tokens=True,
41
- clean_up_tokenization_spaces=True
42
- )
43
-
44
- return f"Problem: {problem_text}\nSolution: {solution}"
45
-
46
  except Exception as e:
47
  return f"Error processing image: {str(e)}"
48
 
49
- # Set up the Gradio interface.
50
- demo = gr.Interface(
51
- fn=solve_math_problem,
52
- inputs=gr.Image(
53
- type="pil",
54
- label="Upload Handwritten Math Problem",
55
- image_mode="RGB" # This forces the input to be RGB.
56
- ),
57
- outputs=gr.Textbox(label="Solution", show_copy_button=True),
58
- title="Handwritten Math Problem Solver",
59
- description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
60
- examples=[
61
- ["example_addition.png"],
62
- ["example_algebra.jpg"]
63
- ],
64
- theme="soft",
65
- allow_flagging="never"
66
  )
67
 
68
  if __name__ == "__main__":
69
- demo.launch()
 
1
+ import torch
2
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
3
  import gradio as gr
4
  from PIL import Image
5
 
6
+ # Load model and processor
7
+ model_name = "google/pix2struct-docvqa-large"
8
  model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
9
  processor = Pix2StructProcessor.from_pretrained(model_name)
10
 
11
+ def process_image(image_path):
12
  try:
13
+ # Load the image
14
+ image = Image.open(image_path).convert("RGB")
15
+
16
+ # Prepare the input
17
+ inputs = processor(images=image, text="What does this image say?", return_tensors="pt")
18
+
19
+ # Generate prediction
20
+ output = model.generate(**inputs)
21
+
22
+ # Decode the output
23
+ solution = processor.decode(output[0], skip_special_tokens=True)
24
+ return solution
25
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  except Exception as e:
27
  return f"Error processing image: {str(e)}"
28
 
29
+ def predict(image):
30
+ """Handles image input for Gradio."""
31
+ return process_image(image)
32
+
33
+ # Gradio app
34
+ iface = gr.Interface(
35
+ fn=predict,
36
+ inputs=gr.Image(type="filepath"),
37
+ outputs="text",
38
+ title="Image Text Solution"
 
 
 
 
 
 
 
39
  )
40
 
41
  if __name__ == "__main__":
42
+ iface.launch()