davanstrien HF Staff commited on
Commit
e4442f3
·
1 Parent(s): 5f3165f

Add requirements.in and update requirements.txt with dependencies

Browse files
Files changed (3) hide show
  1. app.py +138 -70
  2. requirements.in +6 -0
  3. requirements.txt +238 -3
app.py CHANGED
@@ -1,7 +1,25 @@
1
  import gradio as gr
2
- from PIL import Image # ImageDraw, ImageFont are no longer needed for overlay
3
  import xml.etree.ElementTree as ET
4
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # --- Helper Functions ---
7
 
@@ -34,16 +52,12 @@ def parse_alto_xml_for_text(xml_file_path):
34
  tree = ET.parse(xml_file_path)
35
  root = tree.getroot()
36
 
37
- # Find all TextLine elements
38
  for text_line in root.findall(f'.//{ns_prefix}TextLine'):
39
  line_text_parts = []
40
  for string_element in text_line.findall(f'{ns_prefix}String'):
41
  text = string_element.get('CONTENT')
42
- if text: # Ensure text is not None
43
  line_text_parts.append(text)
44
- # Also consider <SP/> (Space) elements if they contribute to word separation
45
- # and are not implicitly handled by joining CONTENT attributes.
46
- # For now, just joining CONTENT attributes.
47
  if line_text_parts:
48
  full_text_lines.append(" ".join(line_text_parts))
49
 
@@ -54,80 +68,148 @@ def parse_alto_xml_for_text(xml_file_path):
54
  except Exception as e:
55
  return f"An unexpected error occurred during XML parsing: {e}"
56
 
57
- # The draw_ocr_on_image function is no longer needed.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # --- Gradio Interface Function ---
60
 
61
- def process_image_and_xml(image_path, xml_path):
62
  """
63
  Main function for the Gradio interface.
64
- Processes the image and XML to return the image and extracted text.
 
65
  """
66
- if image_path is None: # If no image is uploaded at all
67
- return None, "Please upload an image."
68
-
69
- try:
70
- img_pil = Image.open(image_path).convert("RGB")
71
- except Exception as e:
72
- return None, f"Error loading image: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- if xml_path is None: # If XML is missing, but image is present
75
- return img_pil, "Please upload an OCR XML file."
76
 
77
- # Both image and XML are presumably present
78
- extracted_text = parse_alto_xml_for_text(xml_path)
79
-
80
- return img_pil, extracted_text
81
 
82
 
83
  # --- Create Gradio App ---
84
 
85
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
86
- gr.Markdown("# OCR Viewer (ALTO XML) - Text Extractor")
87
  gr.Markdown(
88
- "Upload an image and its corresponding ALTO OCR XML file. "
89
- "The app will display the image and extract/show the plain text."
90
  )
91
 
92
  with gr.Row():
93
  with gr.Column(scale=1):
94
  image_input = gr.File(label="Upload Image (PNG, JPG, etc.)", type="filepath")
95
- xml_input = gr.File(label="Upload ALTO XML File (.xml)", type="filepath")
96
- # show_overlay_checkbox has been removed
97
- submit_button = gr.Button("Process Files", variant="primary")
98
 
99
  with gr.Row():
100
  with gr.Column(scale=1):
101
- output_image_orig = gr.Image(label="Uploaded Image", type="pil", interactive=False)
102
  with gr.Column(scale=1):
103
- output_text = gr.Textbox(label="Extracted Plain Text", lines=15, interactive=False)
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- # output_image_overlay has been removed
106
-
107
- def update_interface(image_filepath, xml_filepath):
108
- # image_filepath and xml_filepath are now strings (paths) or None
109
-
110
- if image_filepath is None and xml_filepath is None:
111
- return None, "Please upload an image and an XML file."
112
- # process_image_and_xml handles cases where one is None
113
-
114
- img, text = process_image_and_xml(image_filepath, xml_filepath)
115
-
116
- return img, text
117
-
118
  submit_button.click(
119
- fn=update_interface,
120
- inputs=[image_input, xml_input], # show_overlay_checkbox removed
121
- outputs=[output_image_orig, output_text] # output_image_overlay removed
122
  )
123
 
124
- # The .change event for show_overlay_checkbox has been removed
125
-
126
  gr.Markdown("---")
127
  gr.Markdown("### Example ALTO XML Snippet (for `String` element extraction):")
128
  gr.Code(
129
- value="""
130
- <alto xmlns="http://www.loc.gov/standards/alto/v3/alto.xsd">
131
  <Description>...</Description>
132
  <Styles>...</Styles>
133
  <Layout>
@@ -146,28 +228,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
146
  </PrintSpace>
147
  </Page>
148
  </Layout>
149
- </alto>
150
- """,
 
151
  interactive=False
152
  )
153
 
154
-
155
  if __name__ == "__main__":
156
- try:
157
- # Create a dummy image for testing
158
- img_test = Image.new('RGB', (2394, 3612), color = 'lightgray') # Dimensions from example XML
159
- img_test.save("dummy_image.png")
160
- print("Created dummy_image.png for testing.")
161
-
162
- # Ensure the example XML file (189819724.34.xml) exists in the same directory
163
- # or provide the correct path if it's elsewhere.
164
- example_xml_filename = "189819724.34.xml"
165
- if not os.path.exists(example_xml_filename):
166
- print(f"WARNING: Example XML '{example_xml_filename}' not found. Please create it (using the content from the prompt) or upload your own.")
167
-
168
- except ImportError:
169
- print("Pillow not installed, can't create dummy image.")
170
- except Exception as e:
171
- print(f"Error during setup: {e}")
172
-
173
  demo.launch()
 
1
  import gradio as gr
2
+ from PIL import Image
3
  import xml.etree.ElementTree as ET
4
  import os
5
+ import torch
6
+ from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
7
+
8
+ # --- Global Model and Processor Initialization ---
9
+ # Load the OCR model and processor once when the app starts
10
+ try:
11
+ HF_PROCESSOR = AutoProcessor.from_pretrained("reducto/RolmOCR")
12
+ HF_MODEL = AutoModelForImageTextToText.from_pretrained(
13
+ "reducto/RolmOCR",
14
+ torch_dtype=torch.bfloat16,
15
+ # attn_implementation="flash_attention_2", # User had this commented out
16
+ device_map="auto"
17
+ )
18
+ HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
19
+ print("Hugging Face OCR model loaded successfully.")
20
+ except Exception as e:
21
+ print(f"Error loading Hugging Face model: {e}")
22
+ HF_PIPE = None
23
 
24
  # --- Helper Functions ---
25
 
 
52
  tree = ET.parse(xml_file_path)
53
  root = tree.getroot()
54
 
 
55
  for text_line in root.findall(f'.//{ns_prefix}TextLine'):
56
  line_text_parts = []
57
  for string_element in text_line.findall(f'{ns_prefix}String'):
58
  text = string_element.get('CONTENT')
59
+ if text:
60
  line_text_parts.append(text)
 
 
 
61
  if line_text_parts:
62
  full_text_lines.append(" ".join(line_text_parts))
63
 
 
68
  except Exception as e:
69
  return f"An unexpected error occurred during XML parsing: {e}"
70
 
71
+ def run_hf_ocr(image_path):
72
+ """
73
+ Runs OCR on the provided image using the pre-loaded Hugging Face model.
74
+ """
75
+ if HF_PIPE is None:
76
+ return "Hugging Face OCR model not available."
77
+ if image_path is None:
78
+ return "No image provided for OCR."
79
+
80
+ try:
81
+ # Load the image using PIL, as the pipeline expects an image object or path
82
+ pil_image = Image.open(image_path).convert("RGB")
83
+
84
+ # The user's example output for the pipeline call was:
85
+ # [{'generated_text': [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]}]
86
+ # This suggests the pipeline is returning a conversational style output.
87
+ # We will try to call the pipeline with the image and prompt directly.
88
+ ocr_results = HF_PIPE(
89
+ pil_image,
90
+ prompt="Return the plain text representation of this document as if you were reading it naturally.\n"
91
+ # The pipeline should handle formatting this into messages if needed by the model.
92
+ )
93
+
94
+ # Parse the output based on the user's example structure
95
+ if isinstance(ocr_results, list) and ocr_results and 'generated_text' in ocr_results[0]:
96
+ generated_content = ocr_results[0]['generated_text']
97
+
98
+ # Check if generated_content itself is the direct text (some pipelines do this)
99
+ if isinstance(generated_content, str):
100
+ return generated_content
101
+
102
+ # Check for the conversational structure
103
+ # [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]
104
+ if isinstance(generated_content, list) and generated_content:
105
+ # The assistant's response is typically the last message in the list
106
+ # or specifically the one with role 'assistant'.
107
+ assistant_message = None
108
+ for msg in reversed(generated_content): # Check from the end
109
+ if isinstance(msg, dict) and msg.get('role') == 'assistant' and 'content' in msg:
110
+ assistant_message = msg['content']
111
+ break
112
+ if assistant_message:
113
+ return assistant_message
114
+
115
+ # Fallback if parsing the complex structure fails but we got some string
116
+ if isinstance(generated_content, list) and generated_content and isinstance(generated_content[0], dict) and 'content' in generated_content[0]:
117
+ # This is a guess if the structure is simpler than expected.
118
+ # Or if the first part is the user prompt echo and second is assistant.
119
+ if len(generated_content) > 1 and isinstance(generated_content[1], dict) and 'content' in generated_content[1]:
120
+ return generated_content[1]['content'] # Assuming second part is assistant
121
+
122
+ print(f"Unexpected OCR output structure from HF model: {ocr_results}")
123
+ return "Error: Could not parse OCR model output. Please check console for details."
124
+
125
+ else:
126
+ print(f"Unexpected OCR output structure from HF model: {ocr_results}")
127
+ return "Error: OCR model did not return expected output. Please check console for details."
128
+
129
+ except Exception as e:
130
+ print(f"Error during Hugging Face OCR: {e}")
131
+ return f"Error during Hugging Face OCR: {str(e)}"
132
 
133
  # --- Gradio Interface Function ---
134
 
135
+ def process_files(image_path, xml_path):
136
  """
137
  Main function for the Gradio interface.
138
+ Processes the image for display, runs OCR (Hugging Face model),
139
+ and parses ALTO XML if provided.
140
  """
141
+ img_to_display = None
142
+ alto_text_output = "ALTO XML not provided or not processed."
143
+ hf_ocr_text_output = "Image not provided or OCR not run."
144
+
145
+ if image_path:
146
+ try:
147
+ img_to_display = Image.open(image_path).convert("RGB")
148
+ hf_ocr_text_output = run_hf_ocr(image_path)
149
+ except Exception as e:
150
+ img_to_display = None # Clear image if it failed to load
151
+ hf_ocr_text_output = f"Error loading image or running HF OCR: {e}"
152
+ else:
153
+ hf_ocr_text_output = "Please upload an image to perform OCR."
154
+
155
+
156
+ if xml_path:
157
+ alto_text_output = parse_alto_xml_for_text(xml_path)
158
+ else:
159
+ alto_text_output = "No ALTO XML file uploaded."
160
+
161
+ # If only XML is provided without an image
162
+ if not image_path and xml_path:
163
+ img_to_display = None # No image to display
164
+ hf_ocr_text_output = "Upload an image to perform OCR."
165
 
 
 
166
 
167
+ return img_to_display, alto_text_output, hf_ocr_text_output
 
 
 
168
 
169
 
170
  # --- Create Gradio App ---
171
 
172
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
173
+ gr.Markdown("# OCR Viewer and Extractor")
174
  gr.Markdown(
175
+ "Upload an image to perform OCR using a Hugging Face model. "
176
+ "Optionally, upload its corresponding ALTO OCR XML file to compare the extracted text."
177
  )
178
 
179
  with gr.Row():
180
  with gr.Column(scale=1):
181
  image_input = gr.File(label="Upload Image (PNG, JPG, etc.)", type="filepath")
182
+ xml_input = gr.File(label="Upload ALTO XML File (Optional, .xml)", type="filepath")
183
+ submit_button = gr.Button("Process Image and XML", variant="primary")
 
184
 
185
  with gr.Row():
186
  with gr.Column(scale=1):
187
+ output_image_display = gr.Image(label="Uploaded Image", type="pil", interactive=False)
188
  with gr.Column(scale=1):
189
+ hf_ocr_output_textbox = gr.Textbox(
190
+ label="OCR Output (Hugging Face Model)",
191
+ lines=15,
192
+ interactive=False,
193
+ show_copy_button=True
194
+ )
195
+ alto_xml_output_textbox = gr.Textbox(
196
+ label="Text from ALTO XML",
197
+ lines=15,
198
+ interactive=False,
199
+ show_copy_button=True
200
+ )
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  submit_button.click(
203
+ fn=process_files,
204
+ inputs=[image_input, xml_input],
205
+ outputs=[output_image_display, alto_xml_output_textbox, hf_ocr_output_textbox]
206
  )
207
 
 
 
208
  gr.Markdown("---")
209
  gr.Markdown("### Example ALTO XML Snippet (for `String` element extraction):")
210
  gr.Code(
211
+ value=(
212
+ """<alto xmlns="http://www.loc.gov/standards/alto/v3/alto.xsd">
213
  <Description>...</Description>
214
  <Styles>...</Styles>
215
  <Layout>
 
228
  </PrintSpace>
229
  </Page>
230
  </Layout>
231
+ </alto>"""
232
+ ),
233
+ language="xml", # Added language for syntax highlighting
234
  interactive=False
235
  )
236
 
 
237
  if __name__ == "__main__":
238
+ # Removed dummy file creation as it's less relevant for single file focus
239
+ print("Attempting to launch Gradio demo...")
240
+ print("If the Hugging Face model is large, initial startup might take some time due to model download/loading.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  demo.launch()
requirements.in ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ Pillow
3
+ lxml
4
+ torch
5
+ transformers
6
+ spaces
requirements.txt CHANGED
@@ -1,3 +1,238 @@
1
- gradio
2
- Pillow
3
- lxml
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile --python-platform linux --python-version 3.10 requirements.in -o requirements.txt
3
+ aiofiles==24.1.0
4
+ # via gradio
5
+ annotated-types==0.7.0
6
+ # via pydantic
7
+ anyio==4.9.0
8
+ # via
9
+ # gradio
10
+ # httpx
11
+ # starlette
12
+ certifi==2025.4.26
13
+ # via
14
+ # httpcore
15
+ # httpx
16
+ # requests
17
+ charset-normalizer==3.4.2
18
+ # via requests
19
+ click==8.1.8
20
+ # via
21
+ # typer
22
+ # uvicorn
23
+ exceptiongroup==1.3.0
24
+ # via anyio
25
+ fastapi==0.115.12
26
+ # via gradio
27
+ ffmpy==0.5.0
28
+ # via gradio
29
+ filelock==3.18.0
30
+ # via
31
+ # huggingface-hub
32
+ # torch
33
+ # transformers
34
+ fsspec==2025.5.0
35
+ # via
36
+ # gradio-client
37
+ # huggingface-hub
38
+ # torch
39
+ gradio==5.30.0
40
+ # via
41
+ # -r requirements.in
42
+ # spaces
43
+ gradio-client==1.10.1
44
+ # via gradio
45
+ groovy==0.1.2
46
+ # via gradio
47
+ h11==0.16.0
48
+ # via
49
+ # httpcore
50
+ # uvicorn
51
+ httpcore==1.0.9
52
+ # via httpx
53
+ httpx==0.28.1
54
+ # via
55
+ # gradio
56
+ # gradio-client
57
+ # safehttpx
58
+ # spaces
59
+ huggingface-hub==0.31.4
60
+ # via
61
+ # gradio
62
+ # gradio-client
63
+ # tokenizers
64
+ # transformers
65
+ idna==3.10
66
+ # via
67
+ # anyio
68
+ # httpx
69
+ # requests
70
+ jinja2==3.1.6
71
+ # via
72
+ # gradio
73
+ # torch
74
+ lxml==5.4.0
75
+ # via -r requirements.in
76
+ markdown-it-py==3.0.0
77
+ # via rich
78
+ markupsafe==3.0.2
79
+ # via
80
+ # gradio
81
+ # jinja2
82
+ mdurl==0.1.2
83
+ # via markdown-it-py
84
+ mpmath==1.3.0
85
+ # via sympy
86
+ networkx==3.4.2
87
+ # via torch
88
+ numpy==2.2.6
89
+ # via
90
+ # gradio
91
+ # pandas
92
+ # transformers
93
+ nvidia-cublas-cu12==12.4.5.8
94
+ # via
95
+ # nvidia-cudnn-cu12
96
+ # nvidia-cusolver-cu12
97
+ # torch
98
+ nvidia-cuda-cupti-cu12==12.4.127
99
+ # via torch
100
+ nvidia-cuda-nvrtc-cu12==12.4.127
101
+ # via torch
102
+ nvidia-cuda-runtime-cu12==12.4.127
103
+ # via torch
104
+ nvidia-cudnn-cu12==9.1.0.70
105
+ # via torch
106
+ nvidia-cufft-cu12==11.2.1.3
107
+ # via torch
108
+ nvidia-curand-cu12==10.3.5.147
109
+ # via torch
110
+ nvidia-cusolver-cu12==11.6.1.9
111
+ # via torch
112
+ nvidia-cusparse-cu12==12.3.1.170
113
+ # via
114
+ # nvidia-cusolver-cu12
115
+ # torch
116
+ nvidia-cusparselt-cu12==0.6.2
117
+ # via torch
118
+ nvidia-nccl-cu12==2.21.5
119
+ # via torch
120
+ nvidia-nvjitlink-cu12==12.4.127
121
+ # via
122
+ # nvidia-cusolver-cu12
123
+ # nvidia-cusparse-cu12
124
+ # torch
125
+ nvidia-nvtx-cu12==12.4.127
126
+ # via torch
127
+ orjson==3.10.18
128
+ # via gradio
129
+ packaging==25.0
130
+ # via
131
+ # gradio
132
+ # gradio-client
133
+ # huggingface-hub
134
+ # spaces
135
+ # transformers
136
+ pandas==2.2.3
137
+ # via gradio
138
+ pillow==11.2.1
139
+ # via
140
+ # -r requirements.in
141
+ # gradio
142
+ psutil==5.9.8
143
+ # via spaces
144
+ pydantic==2.11.4
145
+ # via
146
+ # fastapi
147
+ # gradio
148
+ # spaces
149
+ pydantic-core==2.33.2
150
+ # via pydantic
151
+ pydub==0.25.1
152
+ # via gradio
153
+ pygments==2.19.1
154
+ # via rich
155
+ python-dateutil==2.9.0.post0
156
+ # via pandas
157
+ python-multipart==0.0.20
158
+ # via gradio
159
+ pytz==2025.2
160
+ # via pandas
161
+ pyyaml==6.0.2
162
+ # via
163
+ # gradio
164
+ # huggingface-hub
165
+ # transformers
166
+ regex==2024.11.6
167
+ # via transformers
168
+ requests==2.32.3
169
+ # via
170
+ # huggingface-hub
171
+ # spaces
172
+ # transformers
173
+ rich==14.0.0
174
+ # via typer
175
+ ruff==0.11.10
176
+ # via gradio
177
+ safehttpx==0.1.6
178
+ # via gradio
179
+ safetensors==0.5.3
180
+ # via transformers
181
+ semantic-version==2.10.0
182
+ # via gradio
183
+ shellingham==1.5.4
184
+ # via typer
185
+ six==1.17.0
186
+ # via python-dateutil
187
+ sniffio==1.3.1
188
+ # via anyio
189
+ spaces==0.36.0
190
+ # via -r requirements.in
191
+ starlette==0.46.2
192
+ # via
193
+ # fastapi
194
+ # gradio
195
+ sympy==1.13.1
196
+ # via torch
197
+ tokenizers==0.21.1
198
+ # via transformers
199
+ tomlkit==0.13.2
200
+ # via gradio
201
+ torch==2.6.0
202
+ # via -r requirements.in
203
+ tqdm==4.67.1
204
+ # via
205
+ # huggingface-hub
206
+ # transformers
207
+ transformers==4.52.2
208
+ # via -r requirements.in
209
+ triton==3.2.0
210
+ # via torch
211
+ typer==0.15.4
212
+ # via gradio
213
+ typing-extensions==4.13.2
214
+ # via
215
+ # anyio
216
+ # exceptiongroup
217
+ # fastapi
218
+ # gradio
219
+ # gradio-client
220
+ # huggingface-hub
221
+ # pydantic
222
+ # pydantic-core
223
+ # rich
224
+ # spaces
225
+ # torch
226
+ # typer
227
+ # typing-inspection
228
+ # uvicorn
229
+ typing-inspection==0.4.1
230
+ # via pydantic
231
+ tzdata==2025.2
232
+ # via pandas
233
+ urllib3==2.4.0
234
+ # via requests
235
+ uvicorn==0.34.2
236
+ # via gradio
237
+ websockets==15.0.1
238
+ # via gradio-client