|
import io |
|
import os |
|
import tempfile |
|
import time |
|
import uuid |
|
|
|
import cv2 |
|
import gradio as gr |
|
import pymupdf |
|
import spaces |
|
import torch |
|
from loguru import logger |
|
from PIL import Image |
|
from transformers import AutoProcessor, VisionEncoderDecoderModel |
|
|
|
|
|
|
|
|
|
|
|
try:
    from utils.utils import prepare_image, parse_layout_string, process_coordinates
except ImportError:
    # Keep the module importable without the helper module. With these
    # stand-ins parse_layout_string yields no elements, so every page comes
    # back with an empty element list.
    logger.error("Could not import from 'utils.utils'. Please ensure utils.py is in the correct path.")

    def prepare_image(image):
        """Stand-in: hand the image back untouched, with no dimension info."""
        return image, None

    def parse_layout_string(s):
        """Stand-in: report that the layout contains no elements."""
        return []

    def process_coordinates(bbox, img, dims, prev_box):
        """Stand-in: eight zero coordinates and no previous box."""
        return 0, 0, 0, 0, 0, 0, 0, 0, None
|
|
|
|
|
|
|
|
|
# Shared model state, populated lazily by initialize_model() and read by
# model_inference(); all three start out unset.
model = processor = tokenizer = None
|
|
|
|
|
@spaces.GPU
def initialize_model():
    """Load the Dolphin model, processor and tokenizer into module globals.

    Does nothing if all three globals are already populated. Everything is
    loaded into locals first and the globals are assigned only after every
    step succeeded, so a failure part-way through never leaves a
    half-initialized state (e.g. ``model`` set but ``tokenizer`` still
    ``None``) that later calls would mistake for a loaded model.

    On CUDA the model is moved to the GPU and converted to half precision;
    otherwise it stays in its default dtype on the CPU.

    Raises:
        Exception: any error raised while loading is logged and re-raised.
    """
    global model, processor, tokenizer

    if model is not None and processor is not None and tokenizer is not None:
        return

    logger.info("Loading DOLPHIN model for PDF to JSON conversion...")
    model_id = "ByteDance/Dolphin"

    try:
        new_processor = AutoProcessor.from_pretrained(model_id)
        new_model = VisionEncoderDecoderModel.from_pretrained(model_id)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        new_model.to(device)
        if device == "cuda":
            # Half precision on GPU; inputs must be cast to the model's dtype.
            new_model = new_model.half()
        new_model.eval()

        new_tokenizer = new_processor.tokenizer
    except Exception as e:
        logger.error(f"Fatal error during model initialization: {e}")
        raise

    # Publish atomically (from the caller's point of view) only after success.
    processor, model, tokenizer = new_processor, new_model, new_tokenizer
    logger.info(f"Model loaded successfully on {device}")
|
|
|
|
|
@spaces.GPU
def model_inference(prompt, image):
    """Run the Dolphin model on a single image or a batch of images.

    Args:
        prompt: Prompt text, or a list with one prompt per image. A single
            string is reused for every image. Prompts in one batch are
            tokenized without padding, so they are expected to tokenize to
            the same length (true when the same prompt is repeated).
        image: One image, or a list of images for batch processing.

    Returns:
        The decoded answer string for a single image, or a list of answer
        strings in the same order as ``image`` when ``image`` is a list
        (an empty list for an empty batch).
    """
    global model, processor, tokenizer

    is_batch = isinstance(image, list)
    images = image if is_batch else [image]
    if not images:
        # Nothing to decode; the processor cannot build pixel values from an
        # empty batch.
        return []

    if model is None:
        logger.warning("Model not initialized. Initializing now...")
        initialize_model()

    prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)

    device = model.device

    # Cast pixels to the model's own parameter dtype (float16 after .half()
    # on GPU). Deriving it from `model.device == "cuda"` does not work:
    # model.device is a torch.device such as cuda:0 and never equals the
    # string "cuda", which would send float32 pixels into a float16 model.
    batch_inputs = processor(images, return_tensors="pt", padding=True)
    batch_pixel_values = batch_inputs.pixel_values.to(device, dtype=model.dtype)

    # Decoder prefix "<s>{prompt} <Answer/>"; the answer is generated after it.
    prompts_with_task = [f"<s>{p} <Answer/>" for p in prompts]
    batch_prompt_inputs = tokenizer(
        prompts_with_task, add_special_tokens=False, return_tensors="pt"
    )
    batch_prompt_ids = batch_prompt_inputs.input_ids.to(device)
    batch_attention_mask = batch_prompt_inputs.attention_mask.to(device)

    outputs = model.generate(
        pixel_values=batch_pixel_values,
        decoder_input_ids=batch_prompt_ids,
        decoder_attention_mask=batch_attention_mask,
        max_length=4096,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[tokenizer.unk_token_id]],  # never emit the unknown token
        return_dict_in_generate=True,
    )

    # Keep special tokens while decoding so the echoed prefix can be removed
    # verbatim, then strip padding and end-of-sequence markers.
    sequences = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
    results = [
        seq.replace(prompts_with_task[i], "").replace("<pad>", "").replace("</s>", "").strip()
        for i, seq in enumerate(sequences)
    ]

    return results if is_batch else results[0]
|
|
|
|
|
@spaces.GPU
def process_element_batch(elements, prompt, max_batch_size=16):
    """Recognize same-type elements, sending at most ``max_batch_size`` per call.

    Every element dict must carry ``crop``, ``label``, ``bbox`` and
    ``reading_order``. The result holds one dict per element, in input order,
    with ``label``, ``bbox``, the stripped recognized ``text`` and
    ``reading_order``.
    """
    recognized = []
    for start in range(0, len(elements), max_batch_size):
        chunk = elements[start:start + max_batch_size]
        texts = model_inference([prompt] * len(chunk), [item["crop"] for item in chunk])
        recognized.extend(
            {
                "label": item["label"],
                "bbox": item["bbox"],
                "text": text.strip(),
                "reading_order": item["reading_order"],
            }
            for item, text in zip(chunk, texts)
        )
    return recognized
|
|
|
|
|
def convert_all_pdf_pages_to_images(file_path, target_size=896):
    """Render every page of a PDF to its own temporary PNG file.

    Each page is scaled uniformly so that its longer side is ``target_size``
    pixels.

    Args:
        file_path: Path to the PDF; anything whose name does not end in
            ``.pdf`` (case-insensitive) is rejected.
        target_size: Pixel length of the longer side of each rendered page.

    Returns:
        Temporary PNG paths in page order; the caller owns and must delete
        them. Returns an empty list for a non-PDF path or when any page fails
        to render (files already written are deleted first).
    """
    if not file_path or not file_path.lower().endswith('.pdf'):
        logger.warning("Not a PDF file. No pages to convert.")
        return []

    image_paths = []
    try:
        # The context manager closes the document even when rendering fails.
        with pymupdf.open(file_path) as doc:
            for page_num, page in enumerate(doc):
                scale = target_size / max(page.rect.width, page.rect.height)
                pix = page.get_pixmap(matrix=pymupdf.Matrix(scale, scale))

                # Create the file, close our handle, then record the path
                # before writing: the file is never reopened by name while
                # still held open (unsupported on Windows), and a failed
                # write is still cleaned up below.
                fd, tmp_path = tempfile.mkstemp(suffix=f"_page_{page_num+1}.png")
                os.close(fd)
                image_paths.append(tmp_path)
                # Format is inferred from the .png suffix; no PIL round-trip.
                pix.save(tmp_path)
    except Exception as e:
        logger.error(f"Error converting PDF pages to images: {e}")
        for path in image_paths:
            cleanup_temp_file(path)
        return []

    return image_paths
|
|
|
|
|
def process_elements(layout_results, padded_image, dims):
    """Crop every detected layout element out of the page and recognize it.

    Args:
        layout_results: Raw layout output from the model, parsed here with
            ``parse_layout_string`` into ``(bbox, label)`` pairs.
        padded_image: Page image from ``prepare_image``; crops are taken with
            ``[y1:y2, x1:x2]`` slicing and converted from BGR to RGB.
        dims: Dimension info from ``prepare_image``, passed through to
            ``process_coordinates``.

    Returns:
        Element dicts (``label``, ``bbox``, ``reading_order``, ``text``)
        sorted by ``reading_order``. Figures get the placeholder text
        ``"[FIGURE]"``; tables and all other labels are recognized by the
        model. Elements whose crop is 3 px or less in either dimension, or
        that raise during cropping, are skipped (errors are logged).
    """
    parsed = parse_layout_string(layout_results)
    texts, tables, figures = [], [], []
    next_order = 0
    prev_box = None

    for bbox, label in parsed:
        try:
            (x1, y1, x2, y2,
             ox1, oy1, ox2, oy2,
             prev_box) = process_coordinates(bbox, padded_image, dims, prev_box)
            region = padded_image[y1:y2, x1:x2]

            # Skip empty or tiny crops; they carry no recognizable content.
            if region.size <= 0 or region.shape[0] <= 3 or region.shape[1] <= 3:
                continue

            entry = {
                "crop": Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)),
                "label": label,
                "bbox": [ox1, oy1, ox2, oy2],
                "reading_order": next_order,
            }
            if label == "fig":
                figures.append({**entry, "text": "[FIGURE]"})
            elif label == "tab":
                tables.append(entry)
            else:
                texts.append(entry)
            next_order += 1
        except Exception as e:
            logger.error(f"Error processing element with label {label}: {str(e)}")

    results = list(figures)
    batches = (
        (texts, "Read text in the image."),
        (tables, "Parse the table in the image."),
    )
    for group, instruction in batches:
        if group:
            results.extend(process_element_batch(group, instruction))

    results.sort(key=lambda item: item.get("reading_order", 0))

    # Image crops are not JSON-serializable; drop them from the output.
    for item in results:
        item.pop('crop', None)

    return results
|
|
|
|
|
def process_page(image_path):
    """Extract structured content from one rendered page image.

    Runs reading-order/layout analysis on the whole page, then crops and
    recognizes each detected element.

    Args:
        image_path: Path to the page image file.

    Returns:
        List of element dicts as produced by ``process_elements``.
    """
    # Load into memory and release the file handle right away: the page
    # image is a temp file that gets deleted afterwards, which fails on
    # Windows while a handle is still open.
    with Image.open(image_path) as source:
        pil_image = source.convert("RGB")

    layout_output = model_inference("Parse the reading order of this document.", pil_image)

    padded_image, dims = prepare_image(pil_image)
    return process_elements(layout_output, padded_image, dims)
|
|
|
|
|
def cleanup_temp_file(file_path):
    """Best-effort deletion of a temporary file.

    Does nothing for an empty/``None`` path or a file that is already gone.
    Any other failure is logged as a warning instead of being raised.

    Args:
        file_path: Path of the file to remove; may be ``None``.
    """
    if not file_path:
        return
    try:
        # Unlink directly (EAFP) instead of exists()-then-unlink(): no
        # check-then-act race and one less filesystem call.
        os.unlink(file_path)
    except FileNotFoundError:
        pass  # Already removed; nothing to clean up.
    except Exception as e:
        logger.warning(f"Failed to cleanup temp file {file_path}: {e}")
|
|
|
|
|
@spaces.GPU(duration=120)
def pdf_to_json_converter(pdf_file):
    """Convert an uploaded PDF into structured JSON (the Gradio click handler).

    Every page is rendered to a temporary image and processed. The temporary
    page images are always removed, whether processing succeeds or fails.

    Args:
        pdf_file: The uploaded file: a path (``str`` / ``os.PathLike``), or a
            file-like object exposing ``.name``; ``None`` if nothing was
            uploaded.

    Returns:
        On success ``{"document_info": {...}, "pages": [...]}`` with one entry
        per page. On failure a dict with an ``"error"`` message; errors are
        reported in the returned dict rather than raised to the UI.
    """
    if pdf_file is None:
        return {"error": "No file uploaded. Please upload a PDF file."}

    # Monotonic clock: elapsed time is immune to wall-clock adjustments.
    start_time = time.perf_counter()

    # Gradio hands over a path string (older versions: an object with
    # `.name`); direct callers may also pass a plain str or PathLike.
    if isinstance(pdf_file, (str, os.PathLike)):
        file_path = os.fspath(pdf_file)
    else:
        file_path = pdf_file.name
    file_name = os.path.basename(file_path)
    temp_files_created = []

    try:
        logger.info(f"Starting processing for document: {file_name}")

        image_paths = convert_all_pdf_pages_to_images(file_path)
        if not image_paths:
            raise RuntimeError("Failed to convert PDF to images. The file might be corrupted or not a valid PDF.")
        temp_files_created.extend(image_paths)

        all_pages_data = []
        for page_number, image_path in enumerate(image_paths, start=1):
            logger.info(f"Processing page {page_number}/{len(image_paths)}")
            all_pages_data.append({
                "page": page_number,
                "elements": process_page(image_path),
            })

        processing_time = time.perf_counter() - start_time
        logger.info(f"Document processed successfully in {processing_time:.2f}s")

        return {
            "document_info": {
                "file_name": file_name,
                "total_pages": len(image_paths),
                "processing_time_seconds": round(processing_time, 2),
            },
            "pages": all_pages_data,
        }

    except Exception as e:
        # Keep the traceback in the logs, but report the failure as JSON.
        logger.exception(f"An error occurred during document processing: {str(e)}")
        return {"error": str(e), "file_name": file_name}

    finally:
        logger.info("Cleaning up temporary files...")
        for temp_file in temp_files_created:
            cleanup_temp_file(temp_file)
|
|
|
|
|
|
|
def build_gradio_interface():
    """Create the single-page Gradio UI for PDF-to-JSON conversion.

    Layout: an intro header; a row with the PDF upload and the convert button
    on the left (scale 1) and the JSON viewer on the right (scale 2); then a
    Clear button that resets both the upload and the output.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(title="PDF to JSON Converter") as demo:
        gr.Markdown(
            """
            # PDF to JSON Converter
            Upload a multi-page PDF to extract its content into a structured JSON format using the Dolphin model.
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                pdf_input = gr.File(label="Upload PDF File", file_types=[".pdf"])
                submit_btn = gr.Button("Convert to JSON", variant="primary")
            with gr.Column(scale=2):
                json_output = gr.JSON(label="JSON Output", scale=2)

        # "Convert" runs the whole conversion pipeline on the uploaded file.
        submit_btn.click(
            fn=pdf_to_json_converter,
            inputs=[pdf_input],
            outputs=[json_output],
        )

        # ClearButton registers its own reset handler; no reference is kept.
        gr.ClearButton(components=[pdf_input, json_output], value="Clear")

    return demo
|
|
|
|
|
|
|
if __name__ == "__main__":
    logger.info("Starting Gradio application...")
    try:
        # Load the model up front so the first request does not pay for it.
        initialize_model()

        app_ui = build_gradio_interface()
        app_ui.launch()
    except Exception as main_exception:
        # Log with traceback and exit non-zero so whatever runs the app sees
        # the startup failure instead of a clean exit (status 0).
        logger.exception(f"Failed to start the application: {main_exception}")
        raise SystemExit(1) from main_exception