# Nanonets-OCR / app.py
import gradio as gr
import base64
import io
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
model_path = "nanonets/Nanonets-OCR-s"
# Load model once at startup
print("Loading Nanonets OCR model...")
model = AutoModelForImageTextToText.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="cpu",  # run on CPU
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained(model_path)
print("Model loaded successfully!")
def ocr_image_gradio(image, max_tokens=4096):
    """Process an image through the Nanonets OCR model for the Gradio interface."""
    if image is None:
        return "Please upload an image."

    prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""

    # Convert to a PIL image if Gradio passed a numpy array
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # slider values may arrive as floats
            do_sample=False,
            repetition_penalty=1.25,
        )

    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return output_text[0]
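
# Example of calling the OCR function directly, outside Gradio
# (the file path is a hypothetical placeholder):
#   from PIL import Image
#   print(ocr_image_gradio(Image.open("sample_page.png"), max_tokens=2048))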
def ocr_base64_image(base64_string, max_tokens=4096):
    """Process a base64-encoded image through the Nanonets OCR model."""
    if not base64_string or base64_string.strip() == "":
        return "Please provide a valid base64 image string."

    try:
        # Strip a data-URL prefix (e.g. "data:image/png;base64,") if present
        if "base64," in base64_string:
            base64_string = base64_string.split("base64,")[1]

        # Decode the base64 payload into a PIL image
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))

        # Reuse the main OCR pipeline
        return ocr_image_gradio(image, max_tokens)
    except Exception as e:
        return f"Error processing base64 image: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="Nanonets OCR Demo") as demo:
    # Styled HTML header with links to the model's resources
    gr.HTML("""
    <div class="title" style="text-align: center">
        <h1>🔍 Nanonets OCR - Document Text Extraction</h1>
        <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
            A state-of-the-art image-to-markdown OCR model for intelligent document processing
        </p>
        <div style="display: flex; justify-content: center; gap: 20px; margin: 15px 0;">
            <a href="https://huggingface.co/nanonets/Nanonets-OCR-s" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
                📚 Hugging Face Model
            </a>
            <a href="https://nanonets.com/research/nanonets-ocr-s/" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
                📝 Release Blog
            </a>
            <a href="https://github.com/NanoNets/docext" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
                💻 GitHub Repository
            </a>
        </div>
    </div>
    """)
    with gr.Tabs() as tabs:
        # Image tab
        with gr.TabItem("Image OCR"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="Upload Document Image", type="pil", height=400
                    )
                    image_max_tokens = gr.Slider(
                        minimum=1024,
                        maximum=8192,
                        value=4096,
                        step=512,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate",
                    )
                    image_extract_btn = gr.Button(
                        "Extract Text", variant="primary", size="lg"
                    )
                with gr.Column(scale=2):
                    image_output_text = gr.Textbox(
                        label="Extracted Text",
                        lines=20,
                        show_copy_button=True,
                        placeholder="Extracted text will appear here...",
                    )

        # Base64 image tab
        with gr.TabItem("Base64 Image OCR"):
            with gr.Row():
                with gr.Column(scale=1):
                    base64_input = gr.Textbox(
                        label="Base64-Encoded Image",
                        lines=10,
                        placeholder="Paste base64-encoded image data...",
                    )
                    base64_max_tokens = gr.Slider(
                        minimum=1024,
                        maximum=8192,
                        value=4096,
                        step=512,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate",
                    )
                    base64_extract_btn = gr.Button(
                        "Extract Text", variant="primary", size="lg"
                    )
                with gr.Column(scale=2):
                    base64_output_text = gr.Textbox(
                        label="Extracted Text",
                        lines=20,
                        show_copy_button=True,
                        placeholder="Extracted text will appear here...",
                    )
    # Event handlers for the Image tab.
    # Note: Gradio assigns fn_index values in registration order, so the
    # button click below is fn_index 0 and the auto-run change handler is 1.
    image_extract_btn.click(
        fn=ocr_image_gradio,
        inputs=[image_input, image_max_tokens],
        outputs=image_output_text,
        show_progress=True,
    )
    image_input.change(
        fn=ocr_image_gradio,
        inputs=[image_input, image_max_tokens],
        outputs=image_output_text,
        show_progress=True,
    )

    # Event handlers for the Base64 tab (fn_index 2)
    base64_extract_btn.click(
        fn=ocr_base64_image,
        inputs=[base64_input, base64_max_tokens],
        outputs=base64_output_text,
        show_progress=True,
    )
    # Model information section
    with gr.Accordion("About Nanonets-OCR-s", open=False):
        gr.Markdown("""
        ## Nanonets-OCR-s

        Nanonets-OCR-s is a powerful, state-of-the-art image-to-markdown OCR model that goes far beyond traditional text extraction.
        It converts documents into structured markdown with intelligent content recognition and semantic tagging (a short illustrative sample appears at the end of this section), making it ideal for downstream processing by large language models (LLMs).

        ### Key Features

        - **LaTeX Equation Recognition**: Automatically converts mathematical equations into properly formatted LaTeX syntax, distinguishing between inline ($...$) and display ($$...$$) equations.
        - **Intelligent Image Description**: Describes images within documents using structured `<img>` tags, making them easy for LLMs to process. It can describe various image types, including logos, charts, and graphs, detailing their content, style, and context.
        - **Signature Detection & Isolation**: Identifies and isolates signatures from the surrounding text, outputting them within `<signature>` tags. This is crucial for processing legal and business documents.
        - **Watermark Extraction**: Detects and extracts watermark text from documents, placing it within `<watermark>` tags.
        - **Smart Checkbox Handling**: Converts form checkboxes and radio buttons into standardized Unicode symbols (☐, ☑, ☒) for consistent and reliable processing.
        - **Complex Table Extraction**: Accurately extracts complex tables from documents and converts them into both markdown and HTML table formats.
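
        ### Example Output

        An illustrative sketch (not actual model output) of how these semantic tags appear in the generated markdown:

        ```
        # Quarterly Report
        <watermark>CONFIDENTIAL</watermark>

        Revenue grew in every quarter, as shown in the chart.
        <img>Bar chart comparing Q1-Q4 revenue</img>

        ☑ Reviewed by finance  ☐ Approved by board

        <page_number>3/10</page_number>
        ```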
""")
    # API usage information
    with gr.Accordion("API Usage", open=True):
        gr.Markdown("""
        ## API Usage

        ### Base64 Image OCR API

        You can call the base64 image OCR endpoint with an HTTP POST request; a Python equivalent is sketched at the end of this section. Note that Gradio assigns `fn_index` values in the order the event handlers are registered, so with the three handlers above the base64 function is `fn_index: 2` (0 is the image button, 1 is the auto-run image-change handler):

        ```
        curl -X POST "http://localhost:7860/api/predict" \\
        -H "Content-Type: application/json" \\
        -d '{
        "fn_index": 2,
        "data": [
        "YOUR_BASE64_STRING_HERE",
        4096
        ]
        }'
        ```

        - `fn_index: 2` maps to the base64 image OCR function
        - The first parameter is the base64-encoded image string
        - The second parameter is the maximum token count

        ### Regular Image Upload API

        ```
        curl -X POST "http://localhost:7860/api/predict" \\
        -H "Content-Type: application/json" \\
        -d '{
        "fn_index": 0,
        "data": [
        "IMAGE_DATA_HERE",
        4096
        ]
        }'
        ```

        - `fn_index: 0` maps to the regular image OCR function
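
        ### Python Client Example

        A minimal sketch using `requests` against the same legacy `/api/predict` endpoint. Newer Gradio versions expose the `gradio_client` package instead, so treat the endpoint path and `fn_index` values as version-dependent assumptions, and the file path as a placeholder:

        ```python
        import base64
        import requests

        # Read a local image and base64-encode it
        with open("document.png", "rb") as f:
            b64 = base64.b64encode(f.read()).decode("utf-8")

        resp = requests.post(
            "http://localhost:7860/api/predict",
            json={"fn_index": 2, "data": [b64, 4096]},
        )
        # Gradio returns the function outputs under the "data" key
        print(resp.json()["data"][0])
        ```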
""")
    # CPU usage notes
    with gr.Accordion("CPU Environment Notes", open=True):
        gr.Markdown("""
        ## CPU Performance Notes

        This app currently runs in a CPU environment (2 cores, 16 GB RAM). Please note:

        - Processing large images may take significantly longer
        - Smaller images are recommended for faster responses
        - If processing takes too long, consider lowering the max token count
        - The model configuration has been tuned for the CPU environment
        """)
if __name__ == "__main__":
    print(f"Device: CPU - available threads: {torch.get_num_threads()}")
    # Cap threads to match the 2 available CPU cores
    torch.set_num_threads(2)
    demo.queue().launch(share=True, server_name="0.0.0.0")