ZennyKenny committed on
Commit
bc2d97b
·
verified ·
1 Parent(s): 149440a

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +373 -0
utils.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This code is copied from https://github.com/allenai/olmocr
3
+ Under the Apache 2.0 license.
4
+ All credit goes to the original authors.
5
+ """
6
+ from dataclasses import dataclass
7
+ import re
8
+ import tempfile
9
+ from PIL import Image
10
+ import subprocess
11
+ import base64
12
+ from typing import List, Literal
13
+ import random
14
+ import ftfy
15
+ from pypdf.generic import RectangleObject
16
+ from pypdf import PdfReader
17
+
18
@dataclass(frozen=True)
class Element:
    """Common base class for the page elements defined in this module.

    It declares no fields of its own; ``TextElement`` and ``ImageElement``
    derive from it.
    """
21
+
22
+
23
@dataclass(frozen=True)
class BoundingBox:
    """Immutable axis-aligned rectangle given by two corner points."""

    # First corner (x0, y0) and second corner (x1, y1).
    x0: float
    y0: float
    x1: float
    y1: float

    @staticmethod
    def from_rectangle(rect: "RectangleObject") -> "BoundingBox":
        """Build a ``BoundingBox`` from the first four entries of *rect*.

        :param rect: An indexable object whose items 0..3 are x0, y0, x1, y1
            (e.g. a pypdf ``RectangleObject``).
        :return: A new ``BoundingBox`` whose coordinates are plain floats.
        """
        # Coerce to builtin floats so the stored values match the field
        # annotations regardless of the numeric wrapper type in *rect*.
        return BoundingBox(float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3]))
33
+
34
+
35
@dataclass(frozen=True)
class TextElement(Element):
    """Immutable record of one piece of extracted text and its position."""

    # The extracted text.
    text: str
    # Position of the text; in this file it is filled from entries 4 and 5
    # of the combined text/current transformation matrix.
    x: float
    y: float
40
+
41
+
42
@dataclass(frozen=True)
class ImageElement(Element):
    """Immutable record of one image placed on a page."""

    # Name of the image (the XObject name when built by ``_pdf_report``).
    name: str
    # Rectangle the image occupies on the page.
    bbox: BoundingBox
46
+
47
+
48
@dataclass(frozen=True)
class PageReport:
    """Elements collected from a single PDF page.

    The dataclass is frozen, so the attributes cannot be rebound, but the two
    lists themselves remain ordinary mutable lists.
    """

    # The page's MediaBox rectangle.
    mediabox: BoundingBox
    # Text pieces found on the page.
    text_elements: List[TextElement]
    # Images found on the page.
    image_elements: List[ImageElement]
53
+
54
def image_to_pdf(image_path):
    """Convert an image file into a temporary single-image PDF.

    The image is converted to RGB when its mode is anything else, then saved
    as a PDF into a freshly created temporary file with a ``.pdf`` suffix.
    The temporary file is created with ``delete=False``, so it is the
    caller's job to remove it when done.

    :param image_path: Path (or any value accepted by ``Image.open``) of the
        image to convert.
    :return: Path of the generated PDF, or ``None`` if anything went wrong.
    """
    import os  # local import: only needed for cleanup on failure

    pdf_path = None
    try:
        # The context manager closes the underlying image file on exit.
        with Image.open(image_path) as img:
            rgb_img = img if img.mode == "RGB" else img.convert("RGB")
            # Reserve a file name; the handle is closed right away so the
            # image can be written to the path afterwards.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                pdf_path = tmp.name
            rgb_img.save(pdf_path, "PDF")
        return pdf_path
    except Exception:
        # Best-effort conversion: callers get None instead of an exception.
        # Do not leave a half-written temporary file behind.
        if pdf_path is not None:
            try:
                os.remove(pdf_path)
            except OSError:
                pass
        return None
69
+
70
def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[float, float]:
    """
    Get the MediaBox width and height for a specific page of a PDF file using the pdfinfo command.

    :param local_pdf_path: Path to the PDF file
    :param page_num: The page number passed to pdfinfo as both first (-f) and last (-l) page
    :return: A ``(width, height)`` tuple, computed as the absolute differences of the
        MediaBox x-coordinates and y-coordinates reported by pdfinfo
    :raises ValueError: If pdfinfo exits with a non-zero status, if the MediaBox values
        cannot be parsed as floats, or if no MediaBox line is present in its output
    """
    # Restrict pdfinfo to the single requested page and ask for box information.
    command = ["pdfinfo", "-f", str(page_num), "-l", str(page_num), "-box", "-enc", "UTF-8", local_pdf_path]

    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    if result.returncode != 0:
        raise ValueError(f"Error running pdfinfo: {result.stderr}")

    # The first line mentioning "MediaBox" carries four numbers after the colon.
    for line in result.stdout.splitlines():
        if "MediaBox" in line:
            media_box = [float(value) for value in line.split(":")[1].strip().split()]
            return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1])

    raise ValueError("MediaBox not found in the PDF info.")
98
+
99
def render_pdf_to_base64png(local_pdf_path: str, page_num: int, target_longest_image_dim: int = 2048) -> str:
    """Render one PDF page to PNG with ``pdftoppm`` and return it base64-encoded.

    :param local_pdf_path: Path to the PDF file.
    :param page_num: Page number passed to pdftoppm as both first and last page.
    :param target_longest_image_dim: Desired size, in pixels, of the longer
        side of the rendered image.
    :return: The bytes pdftoppm wrote to stdout, base64-encoded as a UTF-8 string.
    :raises AssertionError: If pdftoppm exits with a non-zero status; the
        exception carries pdftoppm's stderr.
    :raises subprocess.TimeoutExpired: If pdftoppm runs longer than 120 seconds.
    """
    longest_dim = max(get_pdf_media_box_width_height(local_pdf_path, page_num))

    # -r is a resolution in DPI. Assuming the MediaBox is measured in PDF
    # points (1/72 inch), the longest side is longest_dim / 72 inches, so this
    # resolution makes that side come out at target_longest_image_dim pixels.
    resolution = target_longest_image_dim * 72 / longest_dim

    pdftoppm_result = subprocess.run(
        [
            "pdftoppm",
            "-png",
            "-f",
            str(page_num),
            "-l",
            str(page_num),
            "-r",
            str(resolution),
            local_pdf_path,
        ],
        timeout=120,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if pdftoppm_result.returncode != 0:
        raise AssertionError(pdftoppm_result.stderr)

    return base64.b64encode(pdftoppm_result.stdout).decode("utf-8")
121
+
122
+
123
def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
    """Render a ``PageReport`` as a compact, line-oriented text summary.

    The output starts with a ``Page dimensions`` header line, followed by one
    line per (merged) image and one line per non-blank text element. When
    everything fits within ``max_length`` characters, all lines are emitted in
    collection order (images first, then text). Otherwise a subset is chosen:
    the elements with extreme coordinates are always kept (even if that alone
    exceeds the budget), the rest are added in random order until the next one
    would not fit, and the chosen lines are emitted sorted by position.

    :param report: The page report to linearize.
    :param max_length: Character budget for the result. Below 20, only the
        header line is returned.
    :return: The linearized text.
    """
    header = f"Page dimensions: {report.mediabox.x1:.1f}x{report.mediabox.y1:.1f}\n"

    if max_length < 20:
        return header

    images = _merge_image_elements(report.image_elements)
    # Whitespace-only text elements are dropped entirely.
    texts = [e for e in report.text_elements if e.text.strip()]

    # One (element, rendered line, (x, y) sort position) entry per element.
    entries: list[tuple[Element, str, tuple[float, float]]] = []
    for image in images:
        box = image.bbox
        line = f"[Image {box.x0:.0f}x{box.y0:.0f} to {box.x1:.0f}x{box.y1:.0f}]\n"
        entries.append((image, line, (box.x0, box.y0)))
    for text in texts:
        line = f"[{text.x:.0f}x{text.y:.0f}]{_cleanup_element_text(text.text)}\n"
        entries.append((text, line, (text.x, text.y)))

    total_length = len(header) + sum(len(line) for _, line, _ in entries)
    if total_length <= max_length:
        return header + "".join(line for _, line, _ in entries)

    # Over budget: always keep the elements at the extremes of the page.
    edge_elements: set[Element] = set()
    if images:
        edge_elements.update(
            (
                min(images, key=lambda e: e.bbox.x0),
                max(images, key=lambda e: e.bbox.x1),
                min(images, key=lambda e: e.bbox.y0),
                max(images, key=lambda e: e.bbox.y1),
            )
        )
    if texts:
        edge_elements.update(
            (
                min(texts, key=lambda e: e.x),
                max(texts, key=lambda e: e.x),
                min(texts, key=lambda e: e.y),
                max(texts, key=lambda e: e.y),
            )
        )

    # Membership in edge_elements is by value; selection is tracked by
    # identity so each element object is added at most once.
    selected = []
    selected_ids = set()
    for entry in entries:
        elem = entry[0]
        if elem in edge_elements and id(elem) not in selected_ids:
            selected.append(entry)
            selected_ids.add(id(elem))

    current_length = len(header) + sum(len(line) for _, line, _ in selected)

    # Fill the remaining budget with a random sample of the other elements.
    remaining = [entry for entry in entries if id(entry[0]) not in selected_ids]
    random.shuffle(remaining)
    for entry in remaining:
        if current_length + len(entry[1]) > max_length:
            break
        selected.append(entry)
        current_length += len(entry[1])

    # Emit in (x, y) order so the output does not depend on the shuffle order.
    selected.sort(key=lambda entry: entry[2])
    return header + "".join(line for _, line, _ in selected)
225
+
226
+
227
+ def _cap_split_string(text: str, max_length: int) -> str:
228
+ if len(text) <= max_length:
229
+ return text
230
+
231
+ head_length = max_length // 2 - 3
232
+ tail_length = head_length
233
+
234
+ head = text[:head_length].rsplit(" ", 1)[0] or text[:head_length]
235
+ tail = text[-tail_length:].split(" ", 1)[-1] or text[-tail_length:]
236
+
237
+ return f"{head} ... {tail}"
238
+
239
+
240
# Built once at import time instead of on every call.
_MAX_TEXT_ELEMENT_LENGTH = 250
# Single-character escapes: square brackets delimit entries in the linearized
# report, and newline/carriage-return/tab would break its one-line-per-element
# layout.
_TEXT_ESCAPE_TABLE = str.maketrans({"[": "\\[", "]": "\\]", "\n": "\\n", "\r": "\\r", "\t": "\\t"})


def _cleanup_element_text(element_text: str) -> str:
    """Normalize, escape and length-cap the text of one element.

    The text is passed through ``ftfy.fix_text`` and stripped, then ``[``,
    ``]``, newline, carriage return and tab are replaced by their
    backslash-escaped spellings, and finally the result is capped to 250
    characters with ``_cap_split_string``.

    :param element_text: Raw text of the element.
    :return: The cleaned-up text.
    """
    element_text = ftfy.fix_text(element_text).strip()
    element_text = element_text.translate(_TEXT_ESCAPE_TABLE)
    return _cap_split_string(element_text, _MAX_TEXT_ELEMENT_LENGTH)
251
+
252
def _merge_image_elements(images: List[ImageElement], tolerance: float = 0.5) -> List[ImageElement]:
    """Merge images whose bounding boxes overlap or nearly touch.

    Two images are linked when both the horizontal and the vertical gap
    between their boxes are at most ``tolerance``. Links are transitive
    (tracked with a union-find structure), and every resulting group is
    replaced by a single ``ImageElement`` whose box is the union of the
    group's boxes and whose name joins the member names with ``+``.

    :param images: The images to merge.
    :param tolerance: Largest gap, per axis, at which two boxes still count
        as overlapping.
    :return: One ``ImageElement`` per group; empty when ``images`` is empty.
    """
    n = len(images)
    parent = list(range(n))  # union-find parent pointers; roots point to themselves

    def find(i: int) -> int:
        # Locate the root, then point every node on the path directly at it.
        root = i
        while parent[root] != root:
            root = parent[root]
        while parent[i] != i:
            next_i = parent[i]
            parent[i] = root
            i = next_i
        return root

    def union(i: int, j: int) -> None:
        root_i, root_j = find(i), find(j)
        if root_i != root_j:
            parent[root_i] = root_j

    def bboxes_overlap(b1: BoundingBox, b2: BoundingBox) -> bool:
        # Gap between the boxes along each axis; 0 when they overlap on it.
        h_dist = max(0, max(b1.x0, b2.x0) - min(b1.x1, b2.x1))
        v_dist = max(0, max(b1.y0, b2.y0) - min(b1.y1, b2.y1))
        return h_dist <= tolerance and v_dist <= tolerance

    # Link every pair of images that are close enough.
    for i in range(n):
        for j in range(i + 1, n):
            if bboxes_overlap(images[i].bbox, images[j].bbox):
                union(i, j)

    # Collect member indices per root, in index order.
    groups: dict[int, list[int]] = {}
    for i in range(n):
        groups.setdefault(find(i), []).append(i)

    merged_images = []
    for indices in groups.values():
        boxes = [images[idx].bbox for idx in indices]
        merged_bbox = BoundingBox(
            x0=min(b.x0 for b in boxes),
            y0=min(b.y0 for b in boxes),
            x1=max(b.x1 for b in boxes),
            y1=max(b.y1 for b in boxes),
        )
        merged_name = "+".join(images[idx].name for idx in indices)
        merged_images.append(ImageElement(name=merged_name, bbox=merged_bbox))

    return merged_images
316
+
317
+ def _transform_point(x, y, m):
318
+ x_new = m[0] * x + m[2] * y + m[4]
319
+ y_new = m[1] * x + m[3] * y + m[5]
320
+ return x_new, y_new
321
+
322
+ def _mult(m: List[float], n: List[float]) -> List[float]:
323
+ return [
324
+ m[0] * n[0] + m[1] * n[2],
325
+ m[0] * n[1] + m[1] * n[3],
326
+ m[2] * n[0] + m[3] * n[2],
327
+ m[2] * n[1] + m[3] * n[3],
328
+ m[4] * n[0] + m[5] * n[2] + n[4],
329
+ m[4] * n[1] + m[5] * n[3] + n[5],
330
+ ]
331
+
332
def _pdf_report(local_pdf_path: str, page_num: int) -> PageReport:
    """Collect the text and image elements of one PDF page with pypdf.

    The page's text is walked with ``extract_text`` visitors: every text
    chunk becomes a ``TextElement`` and every ``Do`` operator that paints an
    image XObject becomes an ``ImageElement``.

    :param local_pdf_path: Path to the PDF file.
    :param page_num: 1-based page number (index ``page_num - 1`` is read).
    :return: A ``PageReport`` with the page's MediaBox and collected elements.
    """
    reader = PdfReader(local_pdf_path)
    page = reader.pages[page_num - 1]
    resources = page.get("/Resources", {})
    xobjects = resources.get("/XObject", {})
    text_elements: List[TextElement] = []
    image_elements: List[ImageElement] = []

    def visitor_body(text, cm, tm, font_dict, font_size):
        # Combine the text matrix with the current transformation matrix;
        # entries 4 and 5 of the product are used as the text position.
        txt2user = _mult(tm, cm)
        text_elements.append(TextElement(text, txt2user[4], txt2user[5]))

    def visitor_op(op, args, cm, tm):
        if op != b"Do":
            return
        xobject_name = args[0]
        xobject = xobjects.get(xobject_name)
        # .get() so an XObject without a /Subtype entry is skipped rather
        # than raising KeyError in the middle of extraction.
        if xobject and xobject.get("/Subtype") == "/Image":
            # The image is placed according to the CTM: map the unit-square
            # corners (0, 0) and (1, 1) and take their bounding box.
            # NOTE(review): only two corners are transformed; for a rotated or
            # sheared CTM this may not cover the full extent - confirm intent.
            x0, y0 = _transform_point(0, 0, cm)
            x1, y1 = _transform_point(1, 1, cm)
            bbox = BoundingBox(min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))
            image_elements.append(ImageElement(xobject_name, bbox))

    # Only the visitor side effects are needed; the returned text is ignored.
    page.extract_text(visitor_text=visitor_body, visitor_operand_before=visitor_op)

    return PageReport(
        mediabox=BoundingBox.from_rectangle(page.mediabox),
        text_elements=text_elements,
        image_elements=image_elements,
    )
363
+
364
def get_anchor_text(
    local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
) -> str:
    """Return the anchor text for one page of a PDF.

    Only the ``"pdfreport"`` engine is implemented in this module: it builds
    a page report and linearizes it to at most roughly ``target_length``
    characters.

    :param local_pdf_path: Path to the PDF file.
    :param page: 1-based page number.
    :param pdf_engine: Name of the extraction engine to use.
    :param target_length: Passed to the linearizer as its ``max_length``.
    :return: The linearized page report.
    :raises AssertionError: If ``page`` is not positive.
    :raises NotImplementedError: If ``pdf_engine`` is anything but ``"pdfreport"``.
    """
    # Raised explicitly (rather than via `assert`) so the check survives
    # `python -O`; otherwise page 0 would index pages[-1], the last page.
    if not page > 0:
        raise AssertionError("Pages are 1-indexed in pdf-land")

    if pdf_engine == "pdfreport":
        return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)

    raise NotImplementedError("Unknown engine")