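"""Gradio demo: compare two document images by their visual layout features.

A layoutparser Detectron2 model detects layout regions (captions, tables,
headers, ...), utils.get_features turns the detections into feature vectors,
and the cosine similarity between those vectors is reported as the
document-similarity score.
"""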
import traceback
import gradio as gr
from utils.get_RGB_image import get_RGB_image, is_online_file, steam_online_file
from pdf2image import convert_from_path, convert_from_bytes
import layoutparser as lp
from PIL import Image
from utils.get_features import get_features
from imagehash import average_hash
from sklearn.metrics.pairwise import cosine_similarity
from utils.visualize_bboxes_on_image import visualize_bboxes_on_image
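# Detection label ids mapped to layout category names (the 11 DocLayNet classes).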
label_map = {0: 'Caption', 1: 'Footnote', 2: 'Formula', 3: 'List-item', 4: 'Page-footer', 5: 'Page-header', 6: 'Picture', 7: 'Section-header', 8: 'Table', 9: 'Text', 10: 'Title'}
label_names = list(label_map.values())
color_map = {'Caption': '#acc2d9', 'Footnote': '#56ae57', 'Formula': '#b2996e', 'List-item': '#a8ff04', 'Page-footer': '#69d84f', 'Page-header': '#894585', 'Picture': '#70b23f', 'Section-header': '#d4ffff', 'Table': '#65ab7c', 'Text': '#952e8f', 'Title': '#fcfc81'}
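# Features from the most recent run for each image slot, keyed by perceptual
# (average) hash, so repeated comparisons of unchanged images skip inference.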
cache = {
    'document_image_1_hash': None,
    'document_image_2_hash': None,
    'document_image_1_features': None,
    'document_image_2_features': None,
}
pre_message_style = 'overflow:auto;border:2px solid pink;padding:4px;border-radius:4px;'
visualize_bboxes_on_image_kwargs = {
    'label_text_color': 'white',
    'label_rectangle_color': 'black',
    'label_text_size': 12,
    'label_text_padding': 3,
    'label_rectangle_left_margin': 0,
    'label_rectangle_top_margin': 0
}
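# Feature-vector variants produced by get_features; the 'reduced_*' variants
# pair with the 'reduced_predicted_*' keys selected in similarity_fn below.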
vectors_types = ['vectors', 'weighted_vectors', 'reduced_vectors', 'reduced_weighted_vectors']
annotation_key = 'is_annotated_document_image'
annotation_original_image_key = 'original_image'
def annotate_document_image(document_image: Image.Image, original_document_image: Image.Image):
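    """Flag an image as an annotated copy and stash the original in its info dict."""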
    document_image.info.update({
        annotation_key: True,
        annotation_original_image_key: original_document_image
    })
    return document_image
def get_original_document_image(document_image: Image.Image):
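    """Return the original image stored by annotate_document_image, if present."""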
    if document_image.info.get(annotation_key):
        return document_image.info.get(annotation_original_image_key)
    return document_image
def similarity_fn(model: lp.Detectron2LayoutModel, document_image_1: Image.Image, document_image_2: Image.Image, vectors_type: str):
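    """Score the layout similarity of two document images.

    Returns updated Gradio components: an HTML message with the cosine
    similarity (or an error traceback), both images re-rendered with their
    predicted bounding boxes, and the vectors-type dropdown's visibility.
    """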
    message = None
    annotations = {
        'predicted_bboxes': 'predicted_bboxes' if vectors_type in ['vectors', 'weighted_vectors'] else 'reduced_predicted_bboxes',
        'predicted_scores': 'predicted_scores' if vectors_type in ['vectors', 'weighted_vectors'] else 'reduced_predicted_scores',
        'predicted_labels': 'predicted_labels' if vectors_type in ['vectors', 'weighted_vectors'] else 'reduced_predicted_labels',
    }
    show_vectors_type = False
    try:
        if document_image_1 is None or document_image_2 is None:
            message = f'<pre style="{pre_message_style}">Please load both documents to compare.</pre>'
        else:
            # Recover the raw images in case previously annotated copies were passed back in.
            document_image_1 = get_original_document_image(document_image_1)
            document_image_2 = get_original_document_image(document_image_2)
            document_image_1_hash = str(average_hash(document_image_1))
            document_image_2_hash = str(average_hash(document_image_2))
            # Reuse cached features when an image is unchanged since the last run.
            if document_image_1_hash == cache['document_image_1_hash']:
                document_image_1_features = cache['document_image_1_features']
            else:
                document_image_1_features = get_features(document_image_1, model, label_names)
                cache['document_image_1_hash'] = document_image_1_hash
                cache['document_image_1_features'] = document_image_1_features
            if document_image_2_hash == cache['document_image_2_hash']:
                document_image_2_features = cache['document_image_2_features']
            else:
                document_image_2_features = get_features(document_image_2, model, label_names)
                cache['document_image_2_hash'] = document_image_2_hash
                cache['document_image_2_features'] = document_image_2_features
            [[similarity]] = cosine_similarity(
                [cache['document_image_1_features'][vectors_type]],
                [cache['document_image_2_features'][vectors_type]])
            message = f'<pre style="{pre_message_style}">Similarity between the two documents is: {round(similarity, 4)}</pre>'
            annotated_document_image_1 = visualize_bboxes_on_image(
                image = document_image_1,
                bboxes = cache['document_image_1_features'][annotations['predicted_bboxes']],
                titles = [f'{label}, score:{round(score, 2)}' for label, score in zip(
                    cache['document_image_1_features'][annotations['predicted_labels']],
                    cache['document_image_1_features'][annotations['predicted_scores']])],
                **visualize_bboxes_on_image_kwargs)
            annotated_document_image_2 = visualize_bboxes_on_image(
                image = document_image_2,
                bboxes = cache['document_image_2_features'][annotations['predicted_bboxes']],
                titles = [f'{label}, score:{round(score, 2)}' for label, score in zip(
                    cache['document_image_2_features'][annotations['predicted_labels']],
                    cache['document_image_2_features'][annotations['predicted_scores']])],
                **visualize_bboxes_on_image_kwargs)
            show_vectors_type = True
            document_image_1 = annotate_document_image(annotated_document_image_1, document_image_1)
            document_image_2 = annotate_document_image(annotated_document_image_2, document_image_2)
    except Exception:
        message = f'<pre style="{pre_message_style}">{traceback.format_exc()}</pre>'
    return [
        gr.HTML(message, visible=True),
        document_image_1,
        document_image_2,
        gr.Dropdown(visible=show_vectors_type)
    ]
def load_image(filename, page = 0):
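    """Load a document as an RGB PIL image.

    Tries PDF conversion first (remote bytes or a local path), then falls
    back to reading the file as a plain image.
    """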
    try:
        image = None
        try:
            if is_online_file(filename):
                image = get_RGB_image(convert_from_bytes(steam_online_file(filename))[page])
            else:
                image = get_RGB_image(convert_from_path(filename)[page])
        except Exception:
            # Not a PDF (or conversion failed): treat the file as a plain image.
            image = get_RGB_image(filename)
        return [
            gr.Image(value=image, visible=True),
            None
        ]
    except Exception:
        error = traceback.format_exc()
        return [None, gr.HTML(value=error, visible=True)]
def preview_url(url, page = 0):
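    """Load an image from a URL, switching to the image tab on success and
    staying on the URL tab on failure."""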
    [image, error] = load_image(url, page = page)
    if image:
        return [gr.Tabs(selected=0), image, error]
    else:
        return [gr.Tabs(selected=1), image, error]
def document_view(document_number: int):
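    """Build the upload/URL tabs for one document and wire their events;
    returns the gr.Image component that holds the loaded page."""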
    gr.HTML(value=f'<h4>Load the {"first" if document_number == 1 else "second"} PDF or Document Image</h4>', elem_classes=['center'])
    with gr.Tabs() as document_tabs:
        with gr.Tab("From Image", id=0):
            document = gr.Image(type="pil", label=f"Document {document_number}", visible=False)
            document_error_message = gr.HTML(label="Error Message", visible=False)
            document_preview = gr.UploadButton(
                "Upload PDF or Document Image",
                file_types=["image", ".pdf"],
                file_count="single")
        with gr.Tab("From URL", id=1):
            document_url = gr.Textbox(
                label=f"Document {document_number} URL",
                info="Paste a link/URL to a PDF or document image",
                placeholder="https://datasets-server.huggingface.co/.../image.jpg")
            document_url_error_message = gr.HTML(label="Error Message", visible=False)
            document_url_preview = gr.Button(value="Preview", variant="primary")
    document_preview.upload(
        fn = lambda file: load_image(file.name),
        inputs = [document_preview],
        outputs = [document, document_error_message])
    document_url_preview.click(
        fn = preview_url,
        inputs = [document_url],
        outputs = [document_tabs, document, document_url_error_message])
    return document
def app(*, model_path, config_path, debug = False):
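    """Build and launch the Blocks UI around a Detectron2 layout model."""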
    model: lp.Detectron2LayoutModel = lp.Detectron2LayoutModel(
        config_path = config_path,
        model_path = model_path,
        label_map = label_map)
    title = 'Document Similarity Search Using Visual Layout Features'
    description = f"<h2>{title}</h2>"
    css = '''
    img { max-height: 86vh !important; }
    .center { display: flex; flex: 1 1 auto; align-items: center; align-content: center; justify-content: center; justify-items: center; }
    .hr { width: 100%; display: block; padding: 0; margin: 0; background: gray; height: 4px; border: none; }
    '''
    with gr.Blocks(title=title, css=css) as app:
        with gr.Row():
            gr.HTML(value=description, elem_classes=['center'])
        with gr.Row(equal_height = False):
            with gr.Column():
                document_1_image = document_view(1)
            with gr.Column():
                document_2_image = document_view(2)
        gr.HTML('<hr/>', elem_classes=['hr'])
        with gr.Row(elem_classes=['center']):
            with gr.Column():
                submit = gr.Button(value="Get Similarity", variant="primary")
            with gr.Column():
                vectors_type = gr.Dropdown(
                    choices = vectors_types,
                    value = vectors_types[0],
                    visible = False,
                    label = "Vectors Type",
                    info = "Select the vectors type to use for the similarity calculation")
        similarity_output = gr.HTML(label="Similarity Score", visible=False)
        reset = gr.Button(value="Reset", variant="secondary")
        kwargs = {
            'fn': lambda document_1_image, document_2_image, vectors_type: similarity_fn(
                model,
                document_1_image,
                document_2_image,
                vectors_type),
            'inputs': [document_1_image, document_2_image, vectors_type],
            'outputs': [similarity_output, document_1_image, document_2_image, vectors_type]
        }
        submit.click(**kwargs)
        vectors_type.change(**kwargs)
    return app.launch(debug=debug)
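# Minimal usage sketch (the paths below are placeholders, not files shipped
# with this repo); a Detectron2 layout model checkpoint and its config are
# expected:
#
#     app(model_path='path/to/model_final.pth',
#         config_path='path/to/config.yaml',
#         debug=True)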