# Source: commit 1cf572f ("interesting findings") by vvmnnnkv.
import sys
# Mock audio modules to avoid installing them
# The stubs are registered in sys.modules *before* the imports below, so any
# later `import audioop` / `import pyaudioop` resolves to these empty objects.
# NOTE(review): presumably needed because a dependency imported below pulls in
# audioop, which is no longer part of the standard library on newer Python
# versions - confirm which import requires it.
sys.modules["audioop"] = type("audioop", (), {"__file__": ""})()
sys.modules["pyaudioop"] = type("pyaudioop", (), {"__file__": ""})()
import torch
import gradio as gr
import supervision as sv
import spaces
from PIL import Image
from transformers import AutoProcessor, Owlv2ForObjectDetection, Owlv2Processor
from transformers.models.owlv2.modeling_owlv2 import Owlv2ImageGuidedObjectDetectionOutput, center_to_corners_format, box_iou
#from transformers.models.owlv2.image_processing_owlv2
# Device used for the model and for every tensor created in this file:
# CUDA when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@spaces.GPU
def init_model(model_id):
    """Load the processor and OWLv2 detection model for ``model_id``.

    The model is switched to eval mode and moved to ``DEVICE``.

    Returns:
        A 5-tuple ``(processor, model, image_size, image_mean, image_std)``.
        ``image_size`` is the tuple of values from the image processor's
        ``size`` mapping; ``image_mean`` and ``image_std`` are tensors on
        ``DEVICE`` viewed as ``(1, 3, 1, 1)``.
    """
    processor = AutoProcessor.from_pretrained(model_id)
    model = Owlv2ForObjectDetection.from_pretrained(model_id)
    model.eval()
    model.to(DEVICE)

    image_processor = processor.image_processor

    def per_channel(values):
        # Shape the per-channel statistics as (1, 3, 1, 1).
        return torch.tensor(values, device=DEVICE).view(1, 3, 1, 1)

    image_size = tuple(image_processor.size.values())
    return (
        processor,
        model,
        image_size,
        per_channel(image_processor.image_mean),
        per_channel(image_processor.image_std),
    )
def _flatten_feature_map(feature_map):
    """Merge the two patch axes of a 4-D feature map into a single axis.

    The map's shape is unpacked as (batch, patch rows, patch columns, hidden
    dim) and the result has shape (batch, rows * columns, hidden dim).
    """
    batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
    return torch.reshape(
        feature_map,
        (batch_size, num_patches_height * num_patches_width, hidden_dim),
    )


def _detect_single_class(processor, model, image_feats, target_pred_boxes,
                         query_embeds, target_image, conf_thresh, iou_thresh):
    """Score target patches against one query embedding and draw the result.

    Runs the class predictor with ``query_embeds``, post-processes the logits
    together with the shared ``target_pred_boxes``, labels every surviving box
    as class 0 ("object") and returns the annotated, square-padded target
    image.
    """
    pred_logits, _ = model.class_predictor(
        image_feats=image_feats, query_embeds=query_embeds
    )
    outputs = Owlv2ImageGuidedObjectDetectionOutput(
        logits=pred_logits,
        target_pred_boxes=target_pred_boxes,
    )
    result = processor.post_process_image_guided_detection(
        outputs=outputs,
        target_sizes=torch.tensor([target_image.size[::-1]]),  # (height, width)
        threshold=conf_thresh,
        nms_threshold=iou_thresh,
    )[0]
    # Prepare for supervision: image-guided results carry no class ids, so
    # every box gets label 0.
    result['labels'] = torch.zeros(len(result['boxes']), dtype=torch.int64)
    return annotate_image(result, {0: "object"}, pad_to_square(target_image))


def _detect_with_text(processor, model, texts, target_image, conf_thresh):
    """Run text-prompted detection and return the annotated target image."""
    inputs = processor(
        images=target_image,
        text=texts,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    result = processor.post_process_grounded_object_detection(
        outputs=outputs,
        target_sizes=torch.tensor([target_image.size[::-1]]),  # (height, width)
        threshold=conf_thresh,
    )[0]
    # Class id -> prompt text, in the order the prompts were given.
    class_names = dict(enumerate(texts))
    return annotate_image(result, class_names, target_image)


def _detect_with_visual_prompt(processor, model, prompt_image, target_image,
                               conf_thresh, iou_thresh):
    """Run image-prompted detection with two query-embedding selections.

    Returns:
        ``(annotated_image_my, annotated_image_hf, annotated_prompt_image)``:
        the target image annotated using the objectness * IoU selection, the
        target image annotated using the embedding chosen by
        ``model.embed_image_query``, and the prompt image showing which box
        each selection picked.
    """
    inputs = processor(
        images=target_image,
        query_images=prompt_image,
        return_tensors="pt",
    ).to(DEVICE)

    # Inference only: no gradients are needed for any of the model calls.
    with torch.no_grad():
        query_feature_map = model.image_embedder(pixel_values=inputs.query_pixel_values)[0]
        feature_map = model.image_embedder(pixel_values=inputs.pixel_values)[0]
        image_feats = _flatten_feature_map(feature_map)
        query_image_feats = _flatten_feature_map(query_feature_map)

        # Selection 1: the embedding picked by the transformers implementation.
        hf_query_embeds, box_indices, pred_boxes = model.embed_image_query(
            query_image_features=query_image_feats,
            query_feature_map=query_feature_map,
        )

        # Selection 2: the prompt patch maximising objectness * IoU with the
        # prompt image area.
        objectnesses = torch.sigmoid(model.objectness_predictor(query_image_feats))
        _, source_class_embeddings = model.class_predictor(query_image_feats)

        # Box covering only the prompt image area, excluding padding, in
        # coordinates normalised by the longer side.
        pw, ph = prompt_image.size
        max_side = max(pw, ph)
        each_query_box = torch.tensor(
            [[0, 0, pw / max_side, ph / max_side]], device=DEVICE
        )
        ious, _ = box_iou(each_query_box, center_to_corners_format(pred_boxes)[0])
        top_obj_idx = torch.argmax(objectnesses * ious, dim=-1)
        my_query_embeds = source_class_embeddings[0][top_obj_idx]

        # Target boxes do not depend on the query, so both variants share them.
        target_pred_boxes = model.box_predictor(image_feats, feature_map)

        annotated_image_my = _detect_single_class(
            processor, model, image_feats, target_pred_boxes,
            my_query_embeds, target_image, conf_thresh, iou_thresh,
        )
        annotated_image_hf = _detect_single_class(
            processor, model, image_feats, target_pred_boxes,
            hf_query_embeds, target_image, conf_thresh, iou_thresh,
        )

        # Render the two selected prompt boxes: index 0 is "my", index 1 "hf".
        selected = [top_obj_idx, box_indices[0]]
        query_pred_boxes = pred_boxes[0, selected].unsqueeze(0)
        # NOTE(review): `objectnesses` is already sigmoid-activated and is
        # passed on as `logits`; the post-processor may rescale it again, so
        # the number drawn may not be the raw objectness - confirm.
        query_logits = torch.reshape(objectnesses[0, selected], (1, 2, 1))
        query_outputs = Owlv2ImageGuidedObjectDetectionOutput(
            logits=query_logits,
            target_pred_boxes=query_pred_boxes,
        )
        # threshold=0.0 / nms_threshold=1.0: keep both boxes even when they
        # overlap or score low.
        query_result = processor.post_process_image_guided_detection(
            outputs=query_outputs,
            target_sizes=torch.tensor([prompt_image.size[::-1]]),
            threshold=0.0,
            nms_threshold=1.0,
        )[0]
        # Integer class ids, consistent with the int64 labels used for the
        # target image (the original built a float tensor here).
        query_result['labels'] = torch.tensor([0, 1], dtype=torch.int64)
        annotated_prompt_image = annotate_image(
            query_result, {0: "my", 1: "hf"}, pad_to_square(prompt_image)
        )

    return annotated_image_my, annotated_image_hf, annotated_prompt_image


@spaces.GPU
def inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type):
    """Detect objects in ``target_image`` from a text or a visual prompt.

    Args:
        prompts: ``{"texts": [...]}`` when ``prompt_type`` is ``"Text"``,
            ``{"images": <PIL image>}`` when it is ``"Visual"``.
        target_image: PIL image to search in.
        model_id: checkpoint name forwarded to ``init_model``.
        conf_thresh: score threshold passed to post-processing.
        iou_thresh: NMS threshold; only used for visual prompts.
        prompt_type: ``"Text"`` or ``"Visual"``.

    Returns:
        ``(annotated_image_my, annotated_image_hf, annotated_prompt_image)``.
        For text prompts only ``annotated_image_hf`` is set; for any other
        unrecognised ``prompt_type`` all three are ``None``.
    """
    # init_model also returns preprocessing constants that are not needed here.
    processor, model, *_ = init_model(model_id)

    if prompt_type == "Text":
        annotated_image_hf = _detect_with_text(
            processor, model, prompts["texts"], target_image, conf_thresh
        )
        return None, annotated_image_hf, None

    if prompt_type == "Visual":
        return _detect_with_visual_prompt(
            processor, model, prompts["images"], target_image,
            conf_thresh, iou_thresh,
        )

    return None, None, None
def annotate_image(result, class_names, image):
    """Draw detection boxes and "<class> <score>" captions on a copy of ``image``.

    Args:
        result: post-processed detection dict accepted by
            ``sv.Detections.from_transformers``.
        class_names: mapping from class id to display name.
        image: PIL image to annotate; it is copied, not modified.

    Returns:
        The annotated copy of ``image``.
    """
    detections = sv.Detections.from_transformers(result, class_names)

    # Line thickness and text scale are derived from the image resolution.
    size_wh = image.size
    box_annotator = sv.BoxAnnotator(
        color_lookup=sv.ColorLookup.INDEX,
        thickness=sv.calculate_optimal_line_thickness(resolution_wh=size_wh),
    )
    label_annotator = sv.LabelAnnotator(
        color_lookup=sv.ColorLookup.INDEX,
        text_scale=sv.calculate_optimal_text_scale(resolution_wh=size_wh),
        smart_position=True,
    )

    captions = []
    for name, score in zip(detections['class_name'], detections.confidence):
        captions.append(f"{name} {score:.2f}")

    canvas = box_annotator.annotate(scene=image.copy(), detections=detections)
    return label_annotator.annotate(
        scene=canvas, detections=detections, labels=captions
    )
def pad_to_square(image, background_color=(128, 128, 128)):
    """Return ``image`` pasted top-left onto a square canvas.

    The canvas side equals the longer side of ``image`` and keeps its mode;
    the area not covered by the image is filled with ``background_color``.
    """
    side = max(image.size)
    canvas = Image.new(image.mode, (side, side), background_color)
    canvas.paste(image, (0, 0))
    return canvas
def app():
    """Build the detection UI in the current Blocks context and wire its events.

    Creates the input column (target image, prompt tabs, model and threshold
    controls), the output column (three annotated images), the tab handlers
    that switch between visual and text prompting, the detect button handler
    and the two example galleries.

    Returns:
        ``(target_image, prompt_image, model_id, conf_thresh, iou_thresh,
        image_examples_list)`` - the input components plus the visual example
        rows, so the caller can pre-populate the inputs.
    """
    default_model = "google/owlv2-base-patch16-ensemble"

    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                target_image = gr.Image(type="pil", label="Target Image", visible=True, interactive=True)
                detect_button = gr.Button(value="Detect Objects")
                # Hidden field holding the active prompt mode; the tab
                # handlers below keep it in sync. Default prompt type: Visual.
                prompt_type = gr.Textbox(value='Visual', visible=False)
                with gr.Tab("Visual") as visual_tab:
                    prompt_image = gr.Image(type="pil", label="Prompt Image", visible=True, interactive=True)
                with gr.Tab("Text") as text_tab:
                    texts = gr.Textbox(label="Input Texts", value='', placeholder='person,bus', visible=True, interactive=True)
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        default_model,
                        "google/owlv2-large-patch14-ensemble",
                    ],
                    value=default_model,
                )
                conf_thresh = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                iou_thresh = gr.Slider(
                    label="NMS Threshold",  # was misspelled "NSM"
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.70,
                )
            with gr.Column():
                with gr.Group() as output_image_hf_gr:
                    gr.Markdown("### Annotated Image (HF default)")
                    output_image_hf = gr.Image(type="numpy", visible=True, show_label=False)
                with gr.Group() as output_image_my_gr:
                    gr.Markdown("### Annotated Image (Objectness × IoU variant)")
                    output_image_my = gr.Image(type="numpy", visible=True, show_label=False)
                with gr.Group() as annotated_prompt_image_gr:
                    gr.Markdown("### Prompt Image with Selected Embeddings and Objectness Score")
                    annotated_prompt_image = gr.Image(type="numpy", visible=True, show_label=False)

        # Switching tabs updates the hidden prompt type and shows/hides the
        # widgets that only make sense for visual prompts. Selecting the text
        # tab also clears the prompt image.
        mode_outputs = [prompt_type, prompt_image, output_image_my_gr, annotated_prompt_image_gr]
        visual_tab.select(
            fn=lambda: ("Visual", gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)),
            inputs=None,
            outputs=mode_outputs,
        )
        text_tab.select(
            fn=lambda: ("Text", gr.update(value=None, visible=False), gr.update(visible=False), gr.update(visible=False)),
            inputs=None,
            outputs=mode_outputs,
        )

        def run_inference(prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type):
            """Validate the UI inputs, build the prompt dict and run ``inference``."""
            if target_image is None:
                raise gr.Error("Please provide a target image.")

            if prompt_type == "Text":
                # Comma-separated labels; blank entries are dropped.
                labels = [text.strip() for text in (texts or "").split(',')]
                labels = [label for label in labels if label]
                if not labels:
                    raise gr.Error("Please enter at least one text prompt.")
                prompts = {"texts": labels}
            elif prompt_type == "Visual":
                if prompt_image is None:
                    raise gr.Error("Please provide a prompt image.")
                prompts = {"images": prompt_image}
            else:
                # Previously this path failed with an unbound `prompts`.
                raise gr.Error(f"Unsupported prompt type: {prompt_type}")

            return inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type)

        detect_button.click(
            fn=run_inference,
            inputs=[prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type],
            outputs=[output_image_my, output_image_hf, annotated_prompt_image],
        )

        ###################### Examples ##########################
        # Rows: target image, prompt image, model, confidence, NMS threshold.
        image_examples_list = [
            [
                f"test-data/target{index}.jpg",
                f"test-data/prompt{index}.jpg",
                default_model,
                0.9,
                0.3,
            ]
            for index in range(1, 7)
        ]
        # Rows: target image, comma-separated labels, model, confidence.
        text_prompts = [
            "logo",
            "cat,remote",
            "frog,spider,lizard",
            "cat",
            "lemon,straw",
            "beer logo",
        ]
        text_examples_list = [
            [f"test-data/target{index}.jpg", text, default_model, 0.3]
            for index, text in enumerate(text_prompts, start=1)
        ]

        text_examples = gr.Examples(
            examples=text_examples_list,
            inputs=[target_image, texts, model_id, conf_thresh],
            visible=False, cache_examples=False, label="Text Prompt Examples")
        image_examples = gr.Examples(
            examples=image_examples_list,
            inputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh],
            visible=True, cache_examples=False, label="Box Visual Prompt Examples")

        # Examples update: show the gallery matching the active tab; the NMS
        # slider is only shown for visual prompts.
        def update_text_examples():
            return gr.Dataset(visible=True), gr.Dataset(visible=False), gr.update(visible=False)

        def update_visual_examples():
            return gr.Dataset(visible=False), gr.Dataset(visible=True), gr.update(visible=True)

        example_outputs = [text_examples.dataset, image_examples.dataset, iou_thresh]
        text_tab.select(
            fn=update_text_examples,
            inputs=None,
            outputs=example_outputs,
        )
        visual_tab.select(
            fn=update_visual_examples,
            inputs=None,
            outputs=example_outputs,
        )

    return target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list
# Top-level page: title, description, the detection UI built by app(), and a
# load hook that pre-fills the inputs with one of the visual examples.
gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
<h1 style='text-align: center'>OWLv2: Zero-shot detection with visual prompt 👀</h1>
""")
    gr.Markdown("""
This demo showcases the OWLv2 model's ability to perform zero-shot object detection using visual and text prompts.
You can either provide a text prompt or an image as a visual prompt to detect objects in the target image.
Additionally, it compares different approaches for selecting a query embedding from a visual prompt. The method used in Hugging Face's `transformers` by default often underperforms because of how the visual prompt embedding is selected (see README.md for more details).
""")
    with gr.Row():
        with gr.Column():
            # Build the UI and keep handles to the input components.
            target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list = app()

    # On page load, populate the inputs with the second visual example row
    # (target image, prompt image, model, confidence, NMS threshold).
    gradio_app.load(
        fn=lambda: image_examples_list[1],
        outputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh]
    )

gradio_app.launch(allowed_paths=["figures"])