Spaces:
Running
Running
Commit
·
95af358
1
Parent(s):
7f263b5
Add Nanonets OCR model integration with Gradio interface
Browse files- app.py +116 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
|
4 |
+
import torch
|
5 |
+
import spaces
|
6 |
+
|
7 |
+
model_path = "nanonets/Nanonets-OCR-s"
|
8 |
+
|
9 |
+
# Load model once at startup
|
10 |
+
print("Loading Nanonets OCR model...")
|
11 |
+
model = AutoModelForImageTextToText.from_pretrained(
|
12 |
+
model_path,
|
13 |
+
torch_dtype="auto",
|
14 |
+
device_map="auto",
|
15 |
+
attn_implementation="flash_attention_2"
|
16 |
+
)
|
17 |
+
model.eval()
|
18 |
+
|
19 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
20 |
+
processor = AutoProcessor.from_pretrained(model_path)
|
21 |
+
print("Model loaded successfully!")
|
22 |
+
|
23 |
+
@spaces.GPU()
|
24 |
+
def ocr_image_gradio(image, max_tokens=4096):
|
25 |
+
"""Process image through Nanonets OCR model for Gradio interface"""
|
26 |
+
if image is None:
|
27 |
+
return "Please upload an image."
|
28 |
+
|
29 |
+
try:
|
30 |
+
prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
|
31 |
+
|
32 |
+
# Convert PIL image if needed
|
33 |
+
if not isinstance(image, Image.Image):
|
34 |
+
image = Image.fromarray(image)
|
35 |
+
|
36 |
+
messages = [
|
37 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
38 |
+
{"role": "user", "content": [
|
39 |
+
{"type": "image", "image": image},
|
40 |
+
{"type": "text", "text": prompt},
|
41 |
+
]},
|
42 |
+
]
|
43 |
+
|
44 |
+
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
45 |
+
inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
|
46 |
+
inputs = inputs.to(model.device)
|
47 |
+
|
48 |
+
with torch.no_grad():
|
49 |
+
output_ids = model.generate(**inputs, max_new_tokens=max_tokens, do_sample=False)
|
50 |
+
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
51 |
+
|
52 |
+
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
53 |
+
return output_text[0]
|
54 |
+
|
55 |
+
except Exception as e:
|
56 |
+
return f"Error processing image: {str(e)}"
|
57 |
+
|
58 |
+
# Create Gradio interface
|
59 |
+
with gr.Blocks(title="Nanonets OCR Demo", theme=gr.themes.Soft()) as demo:
|
60 |
+
gr.Markdown("# 🔍 Nanonets OCR - Document Text Extraction")
|
61 |
+
gr.Markdown("Upload an image of a document to extract text, tables, equations, and more!")
|
62 |
+
|
63 |
+
with gr.Row():
|
64 |
+
with gr.Column(scale=1):
|
65 |
+
image_input = gr.Image(
|
66 |
+
label="Upload Document Image",
|
67 |
+
type="pil",
|
68 |
+
height=400
|
69 |
+
)
|
70 |
+
max_tokens_slider = gr.Slider(
|
71 |
+
minimum=1024,
|
72 |
+
maximum=8192,
|
73 |
+
value=4096,
|
74 |
+
step=512,
|
75 |
+
label="Max Tokens",
|
76 |
+
info="Maximum number of tokens to generate"
|
77 |
+
)
|
78 |
+
extract_btn = gr.Button("Extract Text", variant="primary", size="lg")
|
79 |
+
|
80 |
+
with gr.Column(scale=2):
|
81 |
+
output_text = gr.Textbox(
|
82 |
+
label="Extracted Text",
|
83 |
+
lines=20,
|
84 |
+
max_lines=30,
|
85 |
+
show_copy_button=True,
|
86 |
+
placeholder="Extracted text will appear here..."
|
87 |
+
)
|
88 |
+
|
89 |
+
# Example images section
|
90 |
+
gr.Markdown("## 📄 Try with example images:")
|
91 |
+
gr.Examples(
|
92 |
+
examples=[
|
93 |
+
["examples/sample1.jpg"] if "examples/sample1.jpg" else None,
|
94 |
+
["examples/sample2.png"] if "examples/sample2.png" else None,
|
95 |
+
],
|
96 |
+
inputs=image_input,
|
97 |
+
label="Sample Documents"
|
98 |
+
)
|
99 |
+
|
100 |
+
# Event handlers
|
101 |
+
extract_btn.click(
|
102 |
+
fn=ocr_image_gradio,
|
103 |
+
inputs=[image_input, max_tokens_slider],
|
104 |
+
outputs=output_text,
|
105 |
+
show_progress=True
|
106 |
+
)
|
107 |
+
|
108 |
+
image_input.change(
|
109 |
+
fn=ocr_image_gradio,
|
110 |
+
inputs=[image_input, max_tokens_slider],
|
111 |
+
outputs=output_text,
|
112 |
+
show_progress=True
|
113 |
+
)
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
demo.queue().launch()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
torch
|
3 |
+
spaces
|