from typing import List

import os

import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from tqdm import tqdm
from inference.models import YOLOWorld

from utils.efficient_sam import load, inference_with_boxes
from utils.video import (
    generate_file_name,
    calculate_end_frame_index,
    create_directory,
    remove_files_older_than
)

MARKDOWN = """
# Zero-Shot Object Detection & Segmentation Demonstration
"""

RESULTS = "results"
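
# Each example row maps to the gr.Examples inputs declared below:
# [source, categories, confidence threshold, IoU threshold,
#  with segmentation, display confidence, class-agnostic NMS]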
IMAGE_EXAMPLES = [
    ['https://media.roboflow.com/dog.jpeg', 'dog, eye, nose, tongue, car', 0.005, 0.1, True, False, False],
    ['https://media.roboflow.com/albert-4x.png', 'hand, hair', 0.005, 0.1, True, False, False],
]
VIDEO_EXAMPLES = [
    ['https://media.roboflow.com/supervision/video-examples/croissant-1280x720.mp4', 'croissant', 0.01, 0.2, False, False, False],
    ['https://media.roboflow.com/supervision/video-examples/suitcases-1280x720.mp4', 'suitcase', 0.1, 0.2, False, False, False],
    ['https://media.roboflow.com/supervision/video-examples/tokyo-walk-1280x720.mp4', 'woman walking', 0.1, 0.2, False, False, False],
    ['https://media.roboflow.com/supervision/video-examples/wooly-mammoth-1280x720.mp4', 'mammoth', 0.01, 0.2, False, False, False],
]
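
# Load the models once at import time: EfficientSAM for box-prompted segmentation
# and YOLO-World for open-vocabulary detection.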
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EFFICIENT_SAM_MODEL = load(device=DEVICE)
YOLO_WORLD_MODEL = YOLOWorld(model_id="yolo_world/l")
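
# Supervision annotators used to draw masks, bounding boxes and labels.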
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
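
# Make sure the directory for annotated results exists before anything is written to it.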
create_directory(directory_path=RESULTS)
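

# Split the comma separated prompt into a clean list of category names.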
def process_categories(categories: str) -> List[str]:
    return [category.strip() for category in categories.split(',')]
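

# Draw masks, bounding boxes and per-detection labels on top of the input image.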
def annotate_image(
    input_image: np.ndarray,
    detections: sv.Detections,
    categories: List[str],
    with_confidence: bool = False,
) -> np.ndarray:
    labels = [
        (
            f"{categories[class_id]}: {confidence:.3f}"
            if with_confidence
            else f"{categories[class_id]}"
        )
        for class_id, confidence in
        zip(detections.class_id, detections.confidence)
    ]
    output_image = MASK_ANNOTATOR.annotate(input_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image
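

# Detect the requested categories in a single image with YOLO-World and,
# optionally, add EfficientSAM instance masks before annotating.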
def process_image(
    input_image: np.ndarray,
    categories: str,
    confidence_threshold: float = 0.3,
    iou_threshold: float = 0.5,
    with_segmentation: bool = True,
    with_confidence: bool = False,
    with_class_agnostic_nms: bool = False,
) -> np.ndarray:
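    # Clear out stale files from the results directory before running inference.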
    remove_files_older_than(RESULTS, 30)

    categories = process_categories(categories)
    YOLO_WORLD_MODEL.set_classes(categories)
    results = YOLO_WORLD_MODEL.infer(input_image, confidence=confidence_threshold)
    detections = sv.Detections.from_inference(results)
    detections = detections.with_nms(
        class_agnostic=with_class_agnostic_nms,
        threshold=iou_threshold
    )
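    # Prompt EfficientSAM with the detected boxes to obtain instance masks.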
    if with_segmentation:
        detections.mask = inference_with_boxes(
            image=input_image,
            xyxy=detections.xyxy,
            model=EFFICIENT_SAM_MODEL,
            device=DEVICE
        )
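    # Convert to BGR for annotation, then back to RGB for the Gradio image output.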
    output_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    output_image = annotate_image(
        input_image=output_image,
        detections=detections,
        categories=categories,
        with_confidence=with_confidence
    )
    return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
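

# Apply the same detect / segment / annotate pipeline frame by frame to a video
# and write the annotated result to an MP4 file in the results directory.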
def process_video(
    input_video: str,
    categories: str,
    confidence_threshold: float = 0.3,
    iou_threshold: float = 0.5,
    with_segmentation: bool = True,
    with_confidence: bool = False,
    with_class_agnostic_nms: bool = False,
    progress=gr.Progress(track_tqdm=True)
) -> str:
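    # Housekeeping and setup: clear old results, set the YOLO-World vocabulary,
    # and read the source video's metadata and frame limit.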
    remove_files_older_than(RESULTS, 30)

    categories = process_categories(categories)
    YOLO_WORLD_MODEL.set_classes(categories)
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )
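    # Stream annotated frames into an MP4 sink configured from the source video's metadata.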
    result_file_name = generate_file_name(extension="mp4")
    result_file_path = os.path.join(RESULTS, result_file_name)
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total), desc="Processing video..."):
            frame = next(frame_generator)
            results = YOLO_WORLD_MODEL.infer(frame, confidence=confidence_threshold)
            detections = sv.Detections.from_inference(results)
            detections = detections.with_nms(
                class_agnostic=with_class_agnostic_nms,
                threshold=iou_threshold
            )
            if with_segmentation:
                detections.mask = inference_with_boxes(
                    image=frame,
                    xyxy=detections.xyxy,
                    model=EFFICIENT_SAM_MODEL,
                    device=DEVICE
                )
            frame = annotate_image(
                input_image=frame,
                detections=detections,
                categories=categories,
                with_confidence=with_confidence
            )
            sink.write_frame(frame)
    return result_file_path
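

# Shared configuration widgets, rendered inside the "Configuration" accordion below.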
confidence_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.3,
    step=0.01,
    label="Confidence Threshold",
    info=(
        "The confidence threshold for the YOLO-World model. Lower the threshold to "
        "reduce false negatives, enhancing the model's sensitivity to detect "
        "sought-after objects. Conversely, increase the threshold to minimize false "
        "positives, preventing the model from identifying objects it shouldn't."
    ))

iou_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="IoU Threshold",
    info=(
        "The Intersection over Union (IoU) threshold for non-maximum suppression. "
        "Decrease the value to lessen the occurrence of overlapping bounding boxes, "
        "making the detection process stricter. On the other hand, increase the value "
        "to allow more overlapping bounding boxes, accommodating a broader range of "
        "detections."
    ))

with_segmentation_component = gr.Checkbox(
    value=True,
    label="With Segmentation",
    info=(
        "Whether to run EfficientSAM for instance segmentation."
    )
)

with_confidence_component = gr.Checkbox(
    value=False,
    label="Display Confidence",
    info=(
        "Whether to display the confidence of the detected objects."
    )
)

with_class_agnostic_nms_component = gr.Checkbox(
    value=False,
    label="Use Class-Agnostic NMS",
    info=(
        "Suppress overlapping bounding boxes across all classes."
    )
)
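

# App layout: a shared configuration accordion plus separate tabs for image and video input.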
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Accordion("Configuration", open=False):
        confidence_threshold_component.render()
        iou_threshold_component.render()
        with gr.Row():
            with_segmentation_component.render()
            with_confidence_component.render()
            with_class_agnostic_nms_component.render()
    with gr.Tab(label="Image"):
        with gr.Row():
            input_image_component = gr.Image(
                type='numpy',
                label='Input Image'
            )
            output_image_component = gr.Image(
                type='numpy',
                label='Output Image'
            )
        with gr.Row():
            image_categories_text_component = gr.Textbox(
                label='Categories',
                placeholder='comma separated list of categories',
                scale=7
            )
            image_submit_button_component = gr.Button(
                value='Submit',
                scale=1,
                variant='primary'
            )
        gr.Examples(
            fn=process_image,
            examples=IMAGE_EXAMPLES,
            inputs=[
                input_image_component,
                image_categories_text_component,
                confidence_threshold_component,
                iou_threshold_component,
                with_segmentation_component,
                with_confidence_component,
                with_class_agnostic_nms_component
            ],
            outputs=output_image_component
        )
    with gr.Tab(label="Video"):
        with gr.Row():
            input_video_component = gr.Video(
                label='Input Video'
            )
            output_video_component = gr.Video(
                label='Output Video'
            )
        with gr.Row():
            video_categories_text_component = gr.Textbox(
                label='Categories',
                placeholder='comma separated list of categories',
                scale=7
            )
            video_submit_button_component = gr.Button(
                value='Submit',
                scale=1,
                variant='primary'
            )
        gr.Examples(
            fn=process_video,
            examples=VIDEO_EXAMPLES,
            inputs=[
                input_video_component,
                video_categories_text_component,
                confidence_threshold_component,
                iou_threshold_component,
                with_segmentation_component,
                with_confidence_component,
                with_class_agnostic_nms_component
            ],
            outputs=output_video_component
        )
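
    # Wire the submit buttons to the processing functions.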
    image_submit_button_component.click(
        fn=process_image,
        inputs=[
            input_image_component,
            image_categories_text_component,
            confidence_threshold_component,
            iou_threshold_component,
            with_segmentation_component,
            with_confidence_component,
            with_class_agnostic_nms_component
        ],
        outputs=output_image_component
    )
    video_submit_button_component.click(
        fn=process_video,
        inputs=[
            input_video_component,
            video_categories_text_component,
            confidence_threshold_component,
            iou_threshold_component,
            with_segmentation_component,
            with_confidence_component,
            with_class_agnostic_nms_component
        ],
        outputs=output_video_component
    )

demo.launch(debug=False, show_error=True, max_threads=1)