import sys

# Mock audio modules to avoid installing them
sys.modules["audioop"] = type("audioop", (), {"__file__": ""})()
sys.modules["pyaudioop"] = type("pyaudioop", (), {"__file__": ""})()

import torch
import gradio as gr
import supervision as sv
import spaces
from PIL import Image
from transformers import AutoProcessor, Owlv2ForObjectDetection, Owlv2Processor
from transformers.models.owlv2.modeling_owlv2 import (
    Owlv2ImageGuidedObjectDetectionOutput,
    center_to_corners_format,
    box_iou,
)
# from transformers.models.owlv2.image_processing_owlv2

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@spaces.GPU
def init_model(model_id):
    processor = AutoProcessor.from_pretrained(model_id)
    model = Owlv2ForObjectDetection.from_pretrained(model_id)
    model.eval()
    model.to(DEVICE)
    image_size = tuple(processor.image_processor.size.values())
    image_mean = torch.tensor(
        processor.image_processor.image_mean, device=DEVICE
    ).view(1, 3, 1, 1)
    image_std = torch.tensor(
        processor.image_processor.image_std, device=DEVICE
    ).view(1, 3, 1, 1)
    return processor, model, image_size, image_mean, image_std


@spaces.GPU
def inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type):
    processor, model, image_size, image_mean, image_std = init_model(model_id)
    annotated_image_my = None
    annotated_image_hf = None
    annotated_prompt_image = None

    if prompt_type == "Text":
        inputs = processor(
            images=target_image, text=prompts["texts"], return_tensors="pt"
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        target_sizes = torch.tensor([target_image.size[::-1]])
        result = processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_sizes, threshold=conf_thresh
        )[0]
        class_names = {k: v for k, v in enumerate(prompts["texts"])}
        # annotate the target image
        annotated_image_hf = annotate_image(result, class_names, target_image)

    elif prompt_type == "Visual":
        prompt_image = prompts["images"]
        inputs = processor(
            images=target_image, query_images=prompt_image, return_tensors="pt"
        ).to(DEVICE)
        with torch.no_grad():
            query_feature_map = model.image_embedder(pixel_values=inputs.query_pixel_values)[0]
            feature_map = model.image_embedder(pixel_values=inputs.pixel_values)[0]

            batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
            image_feats = torch.reshape(
                feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)
            )

            batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape
            query_image_feats = torch.reshape(
                query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)
            )

            # Select the query embedding using the HF default method
            query_embeds2, box_indices, pred_boxes = model.embed_image_query(
                query_image_features=query_image_feats, query_feature_map=query_feature_map
            )

            # Select the top object in the prompt image by objectness * IoU
            objectnesses = torch.sigmoid(model.objectness_predictor(query_image_feats))
            _, source_class_embeddings = model.class_predictor(query_image_feats)
            # identify the box that covers only the prompt image area, excluding padding
            pw, ph = prompt_image.size
            max_side = max(pw, ph)
            each_query_box = torch.tensor([[0, 0, pw / max_side, ph / max_side]], device=DEVICE)
            pred_boxes_as_corners = center_to_corners_format(pred_boxes)
            each_query_pred_boxes = pred_boxes_as_corners[0]
            ious, _ = box_iou(each_query_box, each_query_pred_boxes)
            comb_score = objectnesses * ious
            top_obj_idx = torch.argmax(comb_score, dim=-1)
            query_embeds = source_class_embeddings[0][top_obj_idx]

            # Predict object boxes on the target image
            target_pred_boxes = model.box_predictor(image_feats, feature_map)
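
            # Two query embeddings are scored against the target image below:
            #   query_embeds  ("my"): the prompt patch with the highest objectness * IoU score
            #   query_embeds2 ("hf"): the patch chosen by model.embed_image_query (library default)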
            # Predict for prompt: my method
            (pred_logits, class_embeds) = model.class_predictor(
                image_feats=image_feats, query_embeds=query_embeds
            )
            outputs = Owlv2ImageGuidedObjectDetectionOutput(
                logits=pred_logits,
                target_pred_boxes=target_pred_boxes,
            )

            # Post-process results
            target_sizes = torch.tensor([target_image.size[::-1]])
            result = processor.post_process_image_guided_detection(
                outputs=outputs, target_sizes=target_sizes, threshold=conf_thresh, nms_threshold=iou_thresh
            )[0]
            # prepare for supervision: add a 0 label for all boxes
            result['labels'] = torch.zeros(len(result['boxes']), dtype=torch.int64)
            class_names = {0: "object"}
            # annotate the target image
            annotated_image_my = annotate_image(result, class_names, pad_to_square(target_image))

            # Predict for prompt: hf method
            (pred_logits, class_embeds) = model.class_predictor(
                image_feats=image_feats, query_embeds=query_embeds2
            )
            outputs = Owlv2ImageGuidedObjectDetectionOutput(
                logits=pred_logits,
                target_pred_boxes=target_pred_boxes,
            )
            # Post-process results
            target_sizes = torch.tensor([target_image.size[::-1]])
            result = processor.post_process_image_guided_detection(
                outputs=outputs, target_sizes=target_sizes, threshold=conf_thresh, nms_threshold=iou_thresh
            )[0]
            # prepare for supervision: add a 0 label for all boxes
            result['labels'] = torch.zeros(len(result['boxes']), dtype=torch.int64)
            class_names = {0: "object"}
            # annotate the target image
            annotated_image_hf = annotate_image(result, class_names, pad_to_square(target_image))

            # Render the selected prompt embeddings on the prompt image
            query_pred_boxes = pred_boxes[0, [top_obj_idx, box_indices[0]]].unsqueeze(0)
            query_logits = torch.reshape(objectnesses[0, [top_obj_idx, box_indices[0]]], (1, 2, 1))
            query_outputs = Owlv2ImageGuidedObjectDetectionOutput(
                logits=query_logits,
                target_pred_boxes=query_pred_boxes,
            )
            query_result = processor.post_process_image_guided_detection(
                outputs=query_outputs,
                target_sizes=torch.tensor([prompt_image.size[::-1]]),
                threshold=0.0,
                nms_threshold=1.0,
            )[0]
            query_result['labels'] = torch.tensor([0, 1], dtype=torch.int64)
            # annotate the prompt image
            query_class_names = {0: "my", 1: "hf"}
            annotated_prompt_image = annotate_image(query_result, query_class_names, pad_to_square(prompt_image))

    return annotated_image_my, annotated_image_hf, annotated_prompt_image


def annotate_image(result, class_names, image):
    detections = sv.Detections.from_transformers(result, class_names)
    resolution_wh = image.size
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh)
    labels = [
        f"{class_name} {confidence:.2f}"
        for class_name, confidence
        in zip(detections['class_name'], detections.confidence)
    ]
    annotated_image = image.copy()
    annotated_image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate(
        scene=annotated_image, detections=detections)
    annotated_image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, smart_position=True).annotate(
        scene=annotated_image, detections=detections, labels=labels)
    return annotated_image


def pad_to_square(image, background_color=(128, 128, 128)):
    width, height = image.size
    max_side = max(width, height)
    result = Image.new(image.mode, (max_side, max_side), background_color)
    result.paste(image, (0, 0))
    return result


def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                target_image = gr.Image(type="pil", label="Target Image", visible=True, interactive=True)
                detect_button = gr.Button(value="Detect Objects")
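                # Hidden field tracking the active prompt tab ("Visual" or "Text");
                # the visual_tab/text_tab .select handlers below keep it in sync.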
Objects") prompt_type = gr.Textbox(value='Visual', visible=False) # Default prompt type with gr.Tab("Visual") as visual_tab: prompt_image = gr.Image(type="pil", label="Prompt Image", visible=True, interactive=True) with gr.Tab("Text") as text_tab: texts = gr.Textbox(label="Input Texts", value='', placeholder='person,bus', visible=True, interactive=True) model_id = gr.Dropdown( label="Model", choices=[ "google/owlv2-base-patch16-ensemble", "google/owlv2-large-patch14-ensemble" ], value="google/owlv2-base-patch16-ensemble", ) conf_thresh = gr.Slider( label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.25, ) iou_thresh = gr.Slider( label="NSM Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.70, ) with gr.Column(): output_image_hf_gr = gr.Group() with output_image_hf_gr: gr.Markdown("### Annotated Image (HF default)") output_image_hf = gr.Image(type="numpy", visible=True, show_label=False) output_image_my_gr = gr.Group() with output_image_my_gr: gr.Markdown("### Annotated Image (Objectness × IoU variant)") output_image_my = gr.Image(type="numpy", visible=True, show_label=False) annotated_prompt_image_gr = gr.Group() with annotated_prompt_image_gr: gr.Markdown("### Prompt Image with Selected Embeddings and Objectness Score") annotated_prompt_image = gr.Image(type="numpy", visible=True, show_label=False) visual_tab.select( fn=lambda: ("Visual", gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)), inputs=None, outputs=[prompt_type, prompt_image, output_image_my_gr, annotated_prompt_image_gr] ) text_tab.select( fn=lambda: ("Text", gr.update(value=None, visible=False), gr.update(visible=False), gr.update(visible=False)), inputs=None, outputs=[prompt_type, prompt_image, output_image_my_gr, annotated_prompt_image_gr] ) def run_inference(prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type): # add text/built-in prompts if prompt_type == "Text": texts = [text.strip() for text in texts.split(',')] prompts = { "texts": texts } # add visual prompt elif prompt_type == "Visual": prompts = { "images": prompt_image, } return inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type) detect_button.click( fn=run_inference, inputs=[prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type], outputs=[output_image_my, output_image_hf, annotated_prompt_image], ) ###################### Examples ########################## image_examples_list = [[ "test-data/target1.jpg", "test-data/prompt1.jpg", "google/owlv2-base-patch16-ensemble", 0.9, 0.3, ], [ "test-data/target2.jpg", "test-data/prompt2.jpg", "google/owlv2-base-patch16-ensemble", 0.9, 0.3, ], [ "test-data/target3.jpg", "test-data/prompt3.jpg", "google/owlv2-base-patch16-ensemble", 0.9, 0.3, ], [ "test-data/target4.jpg", "test-data/prompt4.jpg", "google/owlv2-base-patch16-ensemble", 0.9, 0.3, ], [ "test-data/target5.jpg", "test-data/prompt5.jpg", "google/owlv2-base-patch16-ensemble", 0.9, 0.3, ], [ "test-data/target6.jpg", "test-data/prompt6.jpg", "google/owlv2-base-patch16-ensemble", 0.9, 0.3, ] ] text_examples = gr.Examples( examples=[[ "test-data/target1.jpg", "logo", "google/owlv2-base-patch16-ensemble", 0.3], [ "test-data/target2.jpg", "cat,remote", "google/owlv2-base-patch16-ensemble", 0.3], [ "test-data/target3.jpg", "frog,spider,lizard", "google/owlv2-base-patch16-ensemble", 0.3], [ "test-data/target4.jpg", "cat", "google/owlv2-base-patch16-ensemble", 0.3 ], [ "test-data/target5.jpg", "lemon,straw", 
"google/owlv2-base-patch16-ensemble", 0.3 ], [ "test-data/target6.jpg", "beer logo", "google/owlv2-base-patch16-ensemble", 0.3 ] ], inputs=[target_image, texts, model_id, conf_thresh], visible=False, cache_examples=False, label="Text Prompt Examples") image_examples = gr.Examples( examples=image_examples_list, inputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh], visible=True, cache_examples=False, label="Box Visual Prompt Examples") # Examples update def update_text_examples(): return gr.Dataset(visible=True), gr.Dataset(visible=False), gr.update(visible=False) def update_visual_examples(): return gr.Dataset(visible=False), gr.Dataset(visible=True), gr.update(visible=True) text_tab.select( fn=update_text_examples, inputs=None, outputs=[text_examples.dataset, image_examples.dataset, iou_thresh] ) visual_tab.select( fn=update_visual_examples, inputs=None, outputs=[text_examples.dataset, image_examples.dataset, iou_thresh] ) return target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list gradio_app = gr.Blocks() with gradio_app: gr.HTML( """