Spaces:
Running
on
Zero
Running
on
Zero

davanstrien
HF Staff
Refactor OCR processing by introducing a GPU-accelerated predict function and updating the run_hf_ocr method to utilize it
864e5c4
import gradio as gr | |
from PIL import Image | |
import xml.etree.ElementTree as ET | |
import os | |
import torch | |
from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline | |
import spaces | |
# --- Global Model and Processor Initialization ---
# The OCR model, its processor and the pipeline wrapping them are created once,
# at import time, and shared by every request. If anything goes wrong HF_PIPE is
# set to None, which run_hf_ocr checks before doing any work.
_OCR_MODEL_ID = "reducto/RolmOCR"
try:
    HF_PROCESSOR = AutoProcessor.from_pretrained(_OCR_MODEL_ID)
    HF_MODEL = AutoModelForImageTextToText.from_pretrained(
        _OCR_MODEL_ID,
        torch_dtype=torch.bfloat16,
        # attn_implementation="flash_attention_2",  # intentionally left disabled
        device_map="auto",
    )
    HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
except Exception as e:
    # Top-level boundary: report the failure and keep the app usable (the UI
    # will show "model not available" instead of crashing at startup).
    print(f"Error loading Hugging Face model: {e}")
    HF_PIPE = None
else:
    print("Hugging Face OCR model loaded successfully.")
# --- Helper Functions ---
def get_alto_namespace(xml_file_path):
    """
    Return the namespace prefix of the XML document's root element.

    The prefix is in ElementTree's Clark notation including both braces
    (e.g. ``'{http://www.loc.gov/standards/alto/v3/alto.xsd}'``), so it can be
    prepended directly to tag names in ``find``/``findall`` paths.

    Args:
        xml_file_path: Path of the XML file to inspect.

    Returns:
        str: The ``'{uri}'`` prefix, or ``''`` when the root tag carries no
        namespace or the file is not well-formed XML (a message is printed
        in the latter case).
    """
    try:
        root_tag = ET.parse(xml_file_path).getroot().tag
    except ET.ParseError:
        print(f"Error parsing XML to find namespace: {xml_file_path}")
        return ''
    # A namespaced tag looks like "{uri}local"; keep everything up to and
    # including the first closing brace.
    namespace, brace, _ = root_tag.partition('}')
    if not brace:
        return ''
    return namespace + brace
def parse_alto_xml_for_text(xml_file_path):
    """
    Extract the plain text of an ALTO XML file.

    Every ``TextLine`` element becomes one output line: the non-empty
    ``CONTENT`` attributes of its direct ``String`` children, joined by single
    spaces. Lines without any such text are skipped. The namespace declared on
    the root element (if any) is applied to the element names searched for.

    Args:
        xml_file_path: Path to the ALTO XML file, or ``None``/empty.

    Returns:
        str: The extracted lines joined with newlines, or an error message
        (``"Error: XML file not provided or does not exist."``,
        ``"Error parsing XML: ..."`` or
        ``"An unexpected error occurred during XML parsing: ..."``).
        Errors are returned rather than raised because the result is shown
        directly in the UI.
    """
    if not xml_file_path or not os.path.exists(xml_file_path):
        return "Error: XML file not provided or does not exist."
    try:
        # Parse the document exactly once and read the namespace from the
        # already-parsed root tag (Clark notation "{uri}local") instead of
        # parsing the whole file a second time just to discover it.
        # NOTE(review): this file is user-uploaded and xml.etree is not
        # hardened against maliciously constructed XML; consider defusedxml
        # if untrusted uploads are a concern.
        root = ET.parse(xml_file_path).getroot()
        namespace, brace, _ = root.tag.partition('}')
        ns_prefix = namespace + brace if brace else ''

        full_text_lines = []
        for text_line in root.findall(f'.//{ns_prefix}TextLine'):
            # Only direct String children carry words; SP/HYP elements are ignored.
            contents = [s.get('CONTENT') for s in text_line.findall(f'{ns_prefix}String')]
            words = [word for word in contents if word]
            if words:
                full_text_lines.append(" ".join(words))
        return "\n".join(full_text_lines)
    except ET.ParseError as e:
        return f"Error parsing XML: {e}"
    except Exception as e:
        return f"An unexpected error occurred during XML parsing: {e}"
def _extract_assistant_text(generated_content):
    """
    Return the OCR text held in the ``generated_text`` value of a pipeline result.

    Recognised shapes:

    * a plain string, returned as-is;
    * a chat transcript, i.e. a list of ``{'role': ..., 'content': ...}``
      messages: the content of the last ``'assistant'`` message is returned
      if it is non-empty. Otherwise, when the first two entries are both
      dicts with a ``'content'`` key (prompt echo followed by the reply),
      the second entry's content is returned.

    Returns ``None`` when no text can be located.
    """
    if isinstance(generated_content, str):
        return generated_content
    if not isinstance(generated_content, list) or not generated_content:
        return None

    # Search from the end: the reply follows the (echoed) user prompt.
    for message in reversed(generated_content):
        if isinstance(message, dict) and message.get('role') == 'assistant' and 'content' in message:
            if message['content']:
                return message['content']
            break  # only the last assistant message is considered

    first = generated_content[0]
    if len(generated_content) > 1 and isinstance(first, dict) and 'content' in first:
        second = generated_content[1]
        if isinstance(second, dict) and 'content' in second:
            return second['content']
    return None


def run_hf_ocr(image_path):
    """
    Run OCR on an image file with the globally loaded Hugging Face pipeline.

    Args:
        image_path: Path to the image file, or ``None``.

    Returns:
        str: The recognised text, or a human-readable status/error message.
        Errors are returned rather than raised because the result is shown
        directly in the Gradio UI.
    """
    if HF_PIPE is None:
        return "Hugging Face OCR model not available."
    if image_path is None:
        return "No image provided for OCR."
    try:
        # Decode inside ``with`` so the source file is closed right away
        # (Pillow may otherwise keep it open, e.g. for multi-frame formats);
        # ``convert`` returns an independent in-memory image.
        with Image.open(image_path) as source_image:
            pil_image = source_image.convert("RGB")

        ocr_results = predict(pil_image)

        # Expected shape: [{'generated_text': <str | list of chat messages>}, ...]
        if not (isinstance(ocr_results, list) and ocr_results and 'generated_text' in ocr_results[0]):
            print(f"Unexpected OCR output structure from HF model: {ocr_results}")
            return "Error: OCR model did not return expected output. Please check console for details."

        text = _extract_assistant_text(ocr_results[0]['generated_text'])
        if text is None:
            print(f"Unexpected OCR output structure from HF model: {ocr_results}")
            return "Error: Could not parse OCR model output. Please check console for details."
        return text
    except Exception as e:
        # UI boundary: surface any failure as text instead of crashing the callback.
        print(f"Error during Hugging Face OCR: {e}")
        return f"Error during Hugging Face OCR: {str(e)}"
@spaces.GPU
def predict(pil_image):
    """
    Run the global OCR pipeline on one image and return its raw output.

    Decorated with ``spaces.GPU`` so that, on a ZeroGPU Space, a GPU is
    attached for the duration of this call; model inference has to happen
    inside such a decorated function there. Callers are expected to have
    checked that ``HF_PIPE`` is not ``None`` (``run_hf_ocr`` does).

    Args:
        pil_image: The ``PIL.Image`` to transcribe.

    Returns:
        The unmodified return value of ``HF_PIPE``; ``run_hf_ocr`` parses it.
    """
    # NOTE(review): no max_new_tokens is passed, so output length is governed
    # by the model's generation config — confirm long pages are not truncated.
    ocr_results = HF_PIPE(
        pil_image,
        # The pipeline is expected to wrap this prompt into the model's chat format.
        prompt="Return the plain text representation of this document as if you were reading it naturally.\n"
    )
    return ocr_results
# --- Gradio Interface Function ---
def process_files(image_path, xml_path):
    """
    Gradio callback: load the uploaded image, OCR it, and extract ALTO text.

    Args:
        image_path: Filepath of the uploaded image, or ``None``/empty if absent.
        xml_path: Filepath of the uploaded ALTO XML file, or ``None``/empty.

    Returns:
        tuple: ``(image, alto_text, ocr_text)``: the RGB ``PIL.Image`` to
        display (``None`` if no image was given or it failed to load), the
        text extracted from the ALTO XML (or a status message), and the OCR
        text (or a status message). The order matches the ``outputs`` list
        wired to the submit button.
    """
    img_to_display = None
    if image_path:
        try:
            # ``with`` closes the uploaded file; ``convert`` returns an
            # in-memory copy that stays valid after the file is closed.
            with Image.open(image_path) as uploaded_image:
                img_to_display = uploaded_image.convert("RGB")
            hf_ocr_text_output = run_hf_ocr(image_path)
        except Exception as e:
            img_to_display = None  # never show a half-loaded image
            hf_ocr_text_output = f"Error loading image or running HF OCR: {e}"
    elif xml_path:
        # Only the XML was supplied.
        hf_ocr_text_output = "Upload an image to perform OCR."
    else:
        hf_ocr_text_output = "Please upload an image to perform OCR."

    if xml_path:
        alto_text_output = parse_alto_xml_for_text(xml_path)
    else:
        alto_text_output = "No ALTO XML file uploaded."
    return img_to_display, alto_text_output, hf_ocr_text_output
# --- Create Gradio App ---
# The UI is built at import time. Component placement follows the nesting of
# the `with` blocks below: an input row (uploads + submit button), an output
# row (image on the left, OCR text and ALTO text on the right), then a static
# reference snippet.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# OCR Viewer and Extractor")
    gr.Markdown(
        "Upload an image to perform OCR using a Hugging Face model. "
        "Optionally, upload its corresponding ALTO OCR XML file to compare the extracted text."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath": process_files receives the uploads as file paths.
            image_input = gr.File(label="Upload Image (PNG, JPG, etc.)", type="filepath")
            xml_input = gr.File(label="Upload ALTO XML File (Optional, .xml)", type="filepath")
            submit_button = gr.Button("Process Image and XML", variant="primary")
    with gr.Row():
        with gr.Column(scale=1):
            # Displays the PIL image returned as the first element of process_files' result.
            output_image_display = gr.Image(label="Uploaded Image", type="pil", interactive=False)
        with gr.Column(scale=1):
            hf_ocr_output_textbox = gr.Textbox(
                label="OCR Output (Hugging Face Model)",
                lines=15,
                interactive=False,
                show_copy_button=True
            )
            alto_xml_output_textbox = gr.Textbox(
                label="Text from ALTO XML",
                lines=15,
                interactive=False,
                show_copy_button=True
            )
    # The outputs order must match process_files' return tuple:
    # (image, alto_text, ocr_text).
    submit_button.click(
        fn=process_files,
        inputs=[image_input, xml_input],
        outputs=[output_image_display, alto_xml_output_textbox, hf_ocr_output_textbox]
    )
    gr.Markdown("---")
    gr.Markdown("### Example ALTO XML Snippet (for `String` element extraction):")
    # Static, read-only example showing the TextLine/String CONTENT structure
    # that parse_alto_xml_for_text reads.
    gr.Code(
        value=(
            """<alto xmlns="http://www.loc.gov/standards/alto/v3/alto.xsd">
<Description>...</Description>
<Styles>...</Styles>
<Layout>
<Page ID="Page13" PHYSICAL_IMG_NR="13" WIDTH="2394" HEIGHT="3612">
<PrintSpace>
<TextLine WIDTH="684" HEIGHT="108" ID="p13_t1" HPOS="465" VPOS="196">
<String ID="p13_w1" CONTENT="Introduction" HPOS="465" VPOS="196" WIDTH="684" HEIGHT="108" STYLEREFS="font0"/>
</TextLine>
<TextLine WIDTH="1798" HEIGHT="51" ID="p13_t2" HPOS="492" VPOS="523">
<String ID="p13_w2" CONTENT="Britain" HPOS="492" VPOS="523" WIDTH="166" HEIGHT="51" STYLEREFS="font1"/>
<SP WIDTH="24" VPOS="523" HPOS="658"/>
<String ID="p13_w3" CONTENT="1981" HPOS="682" VPOS="523" WIDTH="117" HEIGHT="51" STYLEREFS="font1"/>
<!-- ... more String and SP elements ... -->
</TextLine>
<!-- ... more TextLine elements ... -->
</PrintSpace>
</Page>
</Layout>
</alto>"""
        ),
        interactive=False
    )
if __name__ == "__main__":
    # Entry point when this file is executed directly rather than imported.
    # NOTE: the module-level model load has already run by the time these
    # messages print, since it happens at import.
    startup_notes = (
        "Attempting to launch Gradio demo...",
        "If the Hugging Face model is large, initial startup might take some time due to model download/loading.",
    )
    for note in startup_notes:
        print(note)
    demo.launch()