|
import json
import math
import os
from pathlib import Path

import gradio as gr
from gradio_bbox_annotator import BBoxAnnotator
|
|
|
|
|
# Maximum number of boxes allowed per label on a single image.
CATEGORY_LIMITS = {
    "advertisement": 1,
    "text": 2,
}

# Label names in the same order as the limits above; passed to the annotator
# widget as its selectable categories.
CATEGORIES = list(CATEGORY_LIMITS)
|
|
|
class AnnotationManager:
    """Validate bounding-box annotations and collect them per image file name.

    Accepted annotations are kept in ``self.annotations`` as
    ``{file_name: [{"annotation": [c0, c1, c2, c3], "label": str}, ...]}``.
    The four coordinates are stored in the order they were received.
    NOTE(review): earlier code named them ``y1, y2, x1, x2``; confirm the
    actual order against the BBoxAnnotator component before relying on it.

    Args:
        category_limits: Optional mapping of label name to the maximum number
            of boxes with that label allowed on one image. When omitted, the
            module-level ``CATEGORY_LIMITS`` is used.
    """

    # Status icons are written as escapes so the source stays pure ASCII: the
    # literal glyphs had been corrupted by a re-encoding of the file.
    # NOTE(review): the glyphs were reconstructed from the corrupted bytes
    # (check mark, cross mark, wastebasket) -- confirm they are the intended ones.
    _ICON_OK = "\u2705"
    _ICON_ERROR = "\u274c"
    _ICON_CLEARED = "\U0001f5d1\ufe0f"

    def __init__(self, category_limits=None):
        if category_limits is None:
            category_limits = CATEGORY_LIMITS
        # Copy so later edits to the caller's mapping cannot change validation.
        self.category_limits = dict(category_limits)
        self.annotations = {}

    @staticmethod
    def _is_coordinate(value):
        """Return True for a real, finite number (bools are not coordinates)."""
        if isinstance(value, bool):
            return False
        if isinstance(value, int):
            return True
        # NaN / infinity would be emitted by json.dumps as non-standard JSON.
        return isinstance(value, float) and math.isfinite(value)

    def validate_annotations(self, bbox_data):
        """Check an ``(image_path, annotations)`` payload.

        Args:
            bbox_data: A two-item tuple or list. The first item is the image
                path as a string; the second is a sequence of five-item
                annotations ``(c0, c1, c2, c3, label)``.

        Returns:
            ``(True, "")`` when the payload is acceptable, otherwise
            ``(False, message)`` describing the first problem found.
        """
        if not isinstance(bbox_data, (tuple, list)) or len(bbox_data) != 2:
            return False, "No image or annotations provided"

        image_path, annotations = bbox_data
        if not isinstance(image_path, str):
            return False, "Invalid image format"
        if not annotations:
            return False, "No annotations drawn"
        if not isinstance(annotations, (tuple, list)):
            return False, "Invalid annotation format"

        limits = self.category_limits
        seen = dict.fromkeys(limits, 0)
        for ann in annotations:
            if not isinstance(ann, (tuple, list)) or len(ann) != 5:
                return False, "Invalid annotation format"
            *coords, label = ann
            if not all(self._is_coordinate(coord) for coord in coords):
                return False, "Invalid coordinate values"
            # The str check comes first so an unhashable label is rejected
            # instead of raising on the dict lookup.
            if not label or not isinstance(label, str) or label not in limits:
                return False, f"Invalid or missing label. Must be one of: {', '.join(limits)}"
            seen[label] += 1

        for category, limit in limits.items():
            if seen[category] > limit:
                return False, f"Too many {category} annotations. Maximum allowed: {limit}"

        return True, ""

    def add_annotation(self, bbox_data):
        """Validate and store the annotations for one image.

        Any annotations previously stored under the same file name (the base
        name of the image path) are replaced.

        Returns:
            ``(json_text, status_message)``: the JSON dump of everything stored
            so far and a human-readable success or error line. On a validation
            error nothing is stored.
        """
        is_valid, error_msg = self.validate_annotations(bbox_data)
        if not is_valid:
            return self.get_json_annotations(), f"{self._ICON_ERROR} Error: {error_msg}"

        image_path, annotations = bbox_data
        filename = os.path.basename(image_path)
        self.annotations[filename] = [
            {"annotation": list(ann[:4]), "label": ann[4]} for ann in annotations
        ]

        # One pass over the boxes; every configured category is reported,
        # including those with zero boxes.
        counts = dict.fromkeys(self.category_limits, 0)
        for ann in annotations:
            counts[ann[4]] += 1
        counts_str = ", ".join(f"{count} {cat}" for cat, count in counts.items())
        success_msg = f"{self._ICON_OK} Successfully saved for {filename}: {counts_str}"

        return self.get_json_annotations(), success_msg

    def get_json_annotations(self):
        """Return all stored annotations as a JSON string indented by 2."""
        return json.dumps(self.annotations, indent=2)

    def clear_annotations(self):
        """Drop every stored annotation.

        Returns:
            ``("", status_message)``: an empty string for the JSON field and a
            confirmation line.
        """
        self.annotations = {}
        return "", f"{self._ICON_CLEARED} All annotations cleared"
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for drawing and collecting annotations.

    The page shows usage instructions, a bounding-box annotator next to a
    text area holding the combined JSON, a save button, a clear button and a
    status line. Both buttons write to the JSON text area and the status line.

    Returns:
        The assembled ``gr.Blocks`` app. It is not launched here.
    """
    # NOTE(review): a single manager is created per call and captured by both
    # callbacks, so presumably every browser session served by the returned
    # app shares the same stored annotations -- confirm this is intended.
    annotation_mgr = AnnotationManager()

    # Used for both the button and the instructions so the text the user is
    # told to click always matches the button that exists.
    save_label = "Save Current Image Annotations"

    # Generated from CATEGORY_LIMITS so the displayed limits cannot drift from
    # the limits that validation enforces.
    limit_lines = [
        f"- {category}: Maximum {limit} annotation{'' if limit == 1 else 's'}"
        for category, limit in CATEGORY_LIMITS.items()
    ]
    instructions = "\n".join(
        [
            "# Advertisement and Text Annotation Tool",
            "",
            "**Instructions:**",
            "",
            "1. Upload an image using the upload button in the annotator",
            "2. Draw bounding boxes and select the appropriate label",
            f"3. Click '{save_label}' to add to the collection",
            "4. Repeat for all images",
            "5. Copy the combined JSON when finished",
            "",
            "**Annotation Limits per Image:**",
            "",
            *limit_lines,
        ]
    )

    with gr.Blocks() as demo:
        gr.Markdown(instructions)

        with gr.Row():
            with gr.Column(scale=2):
                bbox_input = BBoxAnnotator(
                    show_label=True,
                    label="Draw Bounding Boxes",
                    show_download_button=True,
                    interactive=True,
                    categories=CATEGORIES,
                )

            with gr.Column(scale=1):
                json_output = gr.TextArea(
                    label="Combined Annotations JSON",
                    interactive=True,
                    lines=15,
                    show_copy_button=True,
                )

        with gr.Row():
            save_btn = gr.Button(save_label, variant="primary")
            clear_btn = gr.Button("Clear All Annotations", variant="secondary")

        status_msg = gr.Markdown(label="Status")

        # Each callback returns (json_text, status_message), in that order.
        save_btn.click(
            fn=annotation_mgr.add_annotation,
            inputs=[bbox_input],
            outputs=[json_output, status_msg],
        )

        clear_btn.click(
            fn=annotation_mgr.clear_annotations,
            inputs=[],
            outputs=[json_output, status_msg],
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Build and serve the UI only when this file is run as a script, so that
    # importing the module has no side effects.
    demo = create_interface()
    demo.launch()