Spaces:
Running
on
Zero
Running
on
Zero

davanstrien
HF Staff
Refactor XML parsing functions for improved readability and consistency
5639776
import gradio as gr | |
from PIL import Image | |
import xml.etree.ElementTree as ET | |
import os | |
import torch | |
from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline | |
import spaces | |
# --- Global Model and Processor ---
# The model is loaded eagerly when this module is imported. If anything goes
# wrong, HF_PIPE stays None and the reason is kept in MODEL_LOAD_ERROR_MSG so
# that predict() can report it per request instead of the app dying at import.
MODEL_ID = "reducto/RolmOCR"
HF_PROCESSOR = None
HF_MODEL = None
HF_PIPE = None
MODEL_LOAD_ERROR_MSG = None
try:
    HF_PROCESSOR = AutoProcessor.from_pretrained(MODEL_ID)
    HF_MODEL = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
    )
    HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
except Exception as e:  # top-level boundary: record and log, do not swallow silently
    MODEL_LOAD_ERROR_MSG = f"Failed to load OCR model '{MODEL_ID}': {e}"
    print(MODEL_LOAD_ERROR_MSG)
# --- Helper Functions --- | |
def get_xml_namespace(xml_file_path):
    """
    Detect the XML namespace and the OCR format of a file.

    The format is decided from the root element tag: a tag containing
    "PcGts" is PAGE, a tag containing "alto" (case-insensitive) is ALTO.
    Documents whose root element carries no namespace are recognised too;
    for those the returned namespace prefix is the empty string.

    Args:
        xml_file_path: Path of the XML file to inspect.

    Returns:
        A ``(namespace_prefix, format)`` tuple. ``namespace_prefix`` is in
        ElementTree form ("{uri}") or "" when there is none; ``format`` is
        "PAGE", "ALTO" or "UNKNOWN". Malformed XML and unrecognised roots
        both yield ``("", "UNKNOWN")``.

    Raises:
        OSError: If the file cannot be opened (only ET.ParseError is handled).
    """
    try:
        root = ET.parse(xml_file_path).getroot()
    except ET.ParseError:
        print(f"Error parsing XML to find namespace: {xml_file_path}")
        return "", "UNKNOWN"

    tag = root.tag
    # ElementTree spells a namespaced tag as "{uri}local"; keep the "{uri}"
    # part so callers can prepend it to element names in their searches.
    ns = tag.split("}")[0] + "}" if "}" in tag else ""
    if "PcGts" in tag:
        return ns, "PAGE"
    if "alto" in tag.lower():
        return ns, "ALTO"
    return "", "UNKNOWN"
def parse_page_xml_for_text(xml_file_path):
    """
    Parses a PAGE XML file to extract text content.

    For every TextLine the line-level TextEquiv/Unicode text is used when it
    is present and non-empty; otherwise the TextEquiv/Unicode texts of the
    line's Word children are joined with single spaces. Lines that yield no
    text are skipped.

    The file is parsed once and the namespace prefix is taken from the root
    element of that parse, so elements are found whether or not the document
    declares a namespace.

    Args:
        xml_file_path: Path of the PAGE XML file.

    Returns:
        - full_text (str): All extracted lines joined with newlines, or a
          message starting with "Error" / "An unexpected error" on failure.
    """
    if not xml_file_path or not os.path.exists(xml_file_path):
        return "Error: XML file not provided or does not exist."
    try:
        root = ET.parse(xml_file_path).getroot()
        # ElementTree tags look like "{uri}Name"; reuse the "{uri}" part as
        # the prefix for every element name searched below.
        ns_prefix = root.tag.split("}")[0] + "}" if "}" in root.tag else ""
        unicode_path = f"{ns_prefix}TextEquiv/{ns_prefix}Unicode"
        word_tag = f"{ns_prefix}Word"

        full_text_lines = []
        for text_line in root.findall(f".//{ns_prefix}TextLine"):
            # Prefer the text recorded on the line itself.
            line_equiv = text_line.find(unicode_path)
            if line_equiv is not None and line_equiv.text:
                full_text_lines.append(line_equiv.text)
                continue
            # Otherwise rebuild the line from its words.
            word_nodes = (word.find(unicode_path) for word in text_line.findall(word_tag))
            line_text_parts = [
                node.text for node in word_nodes if node is not None and node.text
            ]
            if line_text_parts:
                full_text_lines.append(" ".join(line_text_parts))
        return "\n".join(full_text_lines)
    except ET.ParseError as e:
        return f"Error parsing XML: {e}"
    except Exception as e:
        return f"An unexpected error occurred during XML parsing: {e}"
def parse_alto_xml_for_text(xml_file_path):
    """
    Parses an ALTO XML file to extract text content.

    For every TextLine the non-empty CONTENT attributes of its direct String
    children are joined with single spaces; other children (for example SP)
    are ignored. Lines that yield no text are skipped.

    The file is parsed once and the namespace prefix is taken from the root
    element of that parse, so elements are found whether or not the document
    declares a namespace.

    Args:
        xml_file_path: Path of the ALTO XML file.

    Returns:
        - full_text (str): All extracted lines joined with newlines, or a
          message starting with "Error" / "An unexpected error" on failure.
    """
    if not xml_file_path or not os.path.exists(xml_file_path):
        return "Error: XML file not provided or does not exist."
    try:
        root = ET.parse(xml_file_path).getroot()
        # ElementTree tags look like "{uri}Name"; reuse the "{uri}" part as
        # the prefix for every element name searched below.
        ns_prefix = root.tag.split("}")[0] + "}" if "}" in root.tag else ""
        string_tag = f"{ns_prefix}String"

        full_text_lines = []
        for text_line in root.findall(f".//{ns_prefix}TextLine"):
            contents = (node.get("CONTENT") for node in text_line.findall(string_tag))
            line_text_parts = [text for text in contents if text]
            if line_text_parts:
                full_text_lines.append(" ".join(line_text_parts))
        return "\n".join(full_text_lines)
    except ET.ParseError as e:
        return f"Error parsing XML: {e}"
    except Exception as e:
        return f"An unexpected error occurred during XML parsing: {e}"
def parse_xml_for_text(xml_file_path):
    """
    Extract text from an XML file, choosing the parser from the detected format.

    The format reported by get_xml_namespace() selects the PAGE or the ALTO
    parser. A missing path, an unsupported format, or any exception raised
    along the way is reported as an error string rather than raised.
    """
    path_is_usable = bool(xml_file_path) and os.path.exists(xml_file_path)
    if not path_is_usable:
        return "Error: XML file not provided or does not exist."

    try:
        _ns, detected_format = get_xml_namespace(xml_file_path)
        # One parser per supported format; anything else is rejected below.
        parsers = {
            "PAGE": parse_page_xml_for_text,
            "ALTO": parse_alto_xml_for_text,
        }
        parser = parsers.get(detected_format)
        if parser is None:
            return "Error: Unsupported XML format. Expected ALTO or PAGE XML."
        return parser(xml_file_path)
    except Exception as e:
        return f"Error determining XML format: {e!s}"
def predict(pil_image):
    """
    Run the module-level OCR pipeline on one image.

    Sends a single user message holding the image and a fixed transcription
    instruction to HF_PIPE and returns the pipeline's raw output.

    Raises:
        RuntimeError: If HF_PIPE is None. The message is MODEL_LOAD_ERROR_MSG
            when that is set, otherwise a generic fallback.
    """
    # NOTE(review): `spaces` is imported by this file but no function here is
    # decorated with it — confirm whether this function needs `@spaces.GPU`.
    if HF_PIPE is None:
        raise RuntimeError(
            MODEL_LOAD_ERROR_MSG or "OCR model could not be initialized."
        )

    # Build the chat-style payload: one user turn with an image part and a text part.
    image_part = {"type": "image", "image": pil_image}
    instruction_part = {
        "type": "text",
        "text": "Return the plain text representation of this document as if you were reading it naturally.\n",
    }
    messages = [{"role": "user", "content": [image_part, instruction_part]}]

    return HF_PIPE(messages, max_new_tokens=8096)
# Marker returned by _extract_generated_text when no text could be located.
# A dedicated object is used so that any value the model returned (including
# an empty string or None) is still passed through unchanged.
_NO_OCR_TEXT = object()


def _extract_generated_text(generated_content):
    """Pull the OCR text out of a pipeline ``generated_text`` value, or return _NO_OCR_TEXT."""
    if isinstance(generated_content, str):
        return generated_content
    if not (isinstance(generated_content, list) and generated_content):
        return _NO_OCR_TEXT

    # Chat-style output: take the content of the last assistant message.
    assistant_message = next(
        (
            msg["content"]
            for msg in reversed(generated_content)
            if isinstance(msg, dict)
            and msg.get("role") == "assistant"
            and "content" in msg
        ),
        None,
    )
    if assistant_message:
        return assistant_message

    # Fallback when no usable assistant message exists but entries carry
    # content: prefer the second entry (assumed to be the assistant's reply),
    # else the first.
    first = generated_content[0]
    if not (isinstance(first, dict) and "content" in first):
        return _NO_OCR_TEXT
    if len(generated_content) > 1:
        second = generated_content[1]
        if isinstance(second, dict) and "content" in second:
            return second["content"]
    return first["content"]


def run_hf_ocr(image_path):
    """
    Runs OCR on the provided image using the Hugging Face model (via predict function).

    Args:
        image_path: Path of the image file, or None.

    Returns:
        The text extracted from the model output. On any problem a
        human-readable message is returned instead of raising: a fixed
        message when no image is given, the RuntimeError text raised by
        predict(), or an "Error ..." string for unexpected output shapes and
        other exceptions.
    """
    if image_path is None:
        return "No image provided for OCR."
    try:
        # convert() returns a new image, so the source file can be closed
        # as soon as the RGB copy exists.
        with Image.open(image_path) as source_image:
            pil_image = source_image.convert("RGB")
        ocr_results = predict(pil_image)

        has_generated_text = (
            isinstance(ocr_results, list)
            and ocr_results
            and "generated_text" in ocr_results[0]
        )
        if not has_generated_text:
            print(f"Unexpected OCR output structure from HF model: {ocr_results}")
            return "Error: OCR model did not return expected output. Check console."

        extracted = _extract_generated_text(ocr_results[0]["generated_text"])
        if extracted is _NO_OCR_TEXT:
            print(f"Unexpected OCR output structure from HF model: {ocr_results}")
            return "Error: Could not parse OCR model output. Check console."
        return extracted
    except RuntimeError as e:  # Catch model loading/initialization errors from predict
        return str(e)
    except Exception as e:
        print(f"Error during Hugging Face OCR processing: {e}")
        return f"Error during Hugging Face OCR: {str(e)}"
# --- Gradio Interface Function --- | |
def process_files(image_path, xml_path):
    """
    Gradio callback: prepare the image preview, the XML text and the OCR text.

    The image (if any) is opened for display and passed to run_hf_ocr(); the
    XML file (if any) is passed to parse_xml_for_text(). Missing inputs are
    replaced by short instructional messages.

    Returns:
        A tuple ``(image_or_None, xml_text, ocr_text)`` in the order expected
        by the click handler's outputs.
    """
    preview = None
    if not image_path:
        ocr_text = "Please upload an image to perform OCR."
    else:
        try:
            preview = Image.open(image_path).convert("RGB")
            ocr_text = run_hf_ocr(image_path)
        except Exception as e:
            preview = None  # Do not show a preview when loading or OCR failed
            ocr_text = f"Error loading image or running HF OCR: {e}"

    xml_text = parse_xml_for_text(xml_path) if xml_path else "No XML file uploaded."

    # XML without an image: there is nothing to show and a different hint is used.
    if xml_path and not image_path:
        preview = None
        ocr_text = "Upload an image to perform OCR."

    return preview, xml_text, ocr_text
# --- Create Gradio App --- | |
# Builds the Gradio UI: two file inputs and a submit button, an image preview
# next to the OCR and XML text outputs, and a static ALTO example at the bottom.
# NOTE(review): the nesting of rows/columns and the indentation inside the
# example XML string were reconstructed from a flattened listing — confirm
# against the original file.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# OCR Viewer and Extractor")
    gr.Markdown(
        "Upload an image to perform OCR using a Hugging Face model. "
        "Optionally, upload its corresponding ALTO or PAGE XML file to compare the extracted text."
    )
    # Inputs: both uploads are handed to the callback as file paths.
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.File(
                label="Upload Image (PNG, JPG, etc.)", type="filepath"
            )
            xml_input = gr.File(
                label="Upload XML File (Optional, ALTO or PAGE format)", type="filepath"
            )
            submit_button = gr.Button("Process Image and XML", variant="primary")
    # Outputs: image preview on one side, OCR text and XML text on the other.
    with gr.Row():
        with gr.Column(scale=1):
            output_image_display = gr.Image(
                label="Uploaded Image", type="pil", interactive=False
            )
        with gr.Column(scale=1):
            hf_ocr_output_textbox = gr.Markdown(
                label="OCR Output (Hugging Face Model)",
                show_copy_button=True,
            )
            xml_output_textbox = gr.Textbox(
                label="Text from XML",
                lines=15,
                interactive=False,
                show_copy_button=True,
            )
    # The outputs list must stay in the same order as process_files' return tuple:
    # (image, xml text, ocr text).
    submit_button.click(
        fn=process_files,
        inputs=[image_input, xml_input],
        outputs=[output_image_display, xml_output_textbox, hf_ocr_output_textbox],
    )
    gr.Markdown("---")
    gr.Markdown("### Example ALTO XML Snippet (for `String` element extraction):")
    # Read-only sample showing the TextLine/String structure with CONTENT attributes.
    gr.Code(
        value=(
            """<alto xmlns="http://www.loc.gov/standards/alto/v3/alto.xsd">
<Description>...</Description>
<Styles>...</Styles>
<Layout>
<Page ID="Page13" PHYSICAL_IMG_NR="13" WIDTH="2394" HEIGHT="3612">
<PrintSpace>
<TextLine WIDTH="684" HEIGHT="108" ID="p13_t1" HPOS="465" VPOS="196">
<String ID="p13_w1" CONTENT="Introduction" HPOS="465" VPOS="196" WIDTH="684" HEIGHT="108" STYLEREFS="font0"/>
</TextLine>
<TextLine WIDTH="1798" HEIGHT="51" ID="p13_t2" HPOS="492" VPOS="523">
<String ID="p13_w2" CONTENT="Britain" HPOS="492" VPOS="523" WIDTH="166" HEIGHT="51" STYLEREFS="font1"/>
<SP WIDTH="24" VPOS="523" HPOS="658"/>
<String ID="p13_w3" CONTENT="1981" HPOS="682" VPOS="523" WIDTH="117" HEIGHT="51" STYLEREFS="font1"/>
<!-- ... more String and SP elements ... -->
</TextLine>
<!-- ... more TextLine elements ... -->
</PrintSpace>
</Page>
</Layout>
</alto>"""
        ),
        interactive=False,
    )
if __name__ == "__main__":
    # Print the startup notices, one per line, then start the Gradio server.
    startup_notices = (
        "Attempting to launch Gradio demo...",
        "If the Hugging Face model is large, initial startup might take some time due to model download/loading (on first OCR attempt).",
    )
    for notice in startup_notices:
        print(notice)
    demo.launch()