"""
YOLOv10 Single Object Feature Extractor

This script extracts RoIAlign-pooled features for each detected object and
renders a composite heatmap layout per image. It can be used to build feature
databases or for targeted object analysis.
"""

from pathlib import Path
from types import MethodType
import argparse

import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.ops import RoIAlign

from ultralytics import YOLO
from ultralytics.utils.plotting import feature_visualization

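# Typical invocations (a sketch; the script filename and paths are placeholders):
#
#   python yolov10_feature_extractor.py --image path/to/image.jpg --output ./heatmaps
#   python yolov10_feature_extractor.py --input-dir path/to/images/ --conf 0.1
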
def _predict_once(self, x, profile=False, visualize=False, embed=None):
    """
    Patched version of ultralytics' BaseModel._predict_once.

    Identical to the stock forward pass, except that when `embed` layers are
    requested it returns their raw (unpooled) feature maps, so that RoIAlign
    can later extract per-box features from them.
    """
    y, dt, embeddings = [], [], []
    for m in self.model:
        if m.f != -1:  # input does not come directly from the previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
        if profile:
            self._profile_one_layer(m, x, dt)
        x = m(x)  # run the layer
        y.append(x if m.i in self.save else None)  # save output if a later layer needs it
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)

        if embed and m.i in embed:
            embeddings.append(x)  # keep the full spatial feature map
            if m.i == max(embed):
                return embeddings
    return x

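# For comparison, the stock ultralytics BaseModel._predict_once pools each
# embedding before returning it, roughly (paraphrased; details vary by version):
#
#   embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))
#   return torch.unbind(torch.cat(embeddings, 1), dim=0)
#
# The patch above skips that pooling so RoIAlign can still operate on the
# spatial layout of each feature map.
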
def get_yolov10_object_features_with_pooler(feat_list, idxs, boxes, orig_img_shape):
    """
    Extracts object features from YOLOv10 feature maps using RoIAlign.
    Concatenates features from all levels for each detected object.

    Note: `idxs` and `orig_img_shape` are kept for API compatibility but are
    not used here; `boxes` must already be in the coordinate space of the
    (letterboxed) network input.
    """
    # Strides of the P3/P4/P5 feature maps relative to the network input.
    spatial_scales = [1.0 / 8, 1.0 / 16, 1.0 / 32]

    num_rois = len(boxes)
    if num_rois == 0:
        return [torch.empty(0)], []

    # RoIAlign expects rois as (batch_index, x1, y1, x2, y2); all boxes come
    # from a single image, so every batch index is 0.
    zeros = torch.zeros((num_rois, 1), device=boxes.device, dtype=boxes.dtype)
    rois = torch.cat((zeros, boxes), dim=1)

    # One pooler per pyramid level, each mapping input-space boxes onto its map.
    poolers = [
        RoIAlign(output_size=(7, 7), spatial_scale=ss, sampling_ratio=2) for ss in spatial_scales
    ]

    pooled_feats = []
    for feat_map, pooler in zip(feat_list, poolers):
        pooled_feats.append(pooler(feat_map, rois))

    # Collapse each 7x7 pooled map to one vector per object, then concatenate
    # the vectors from all levels.
    avg_pool = nn.AdaptiveAvgPool2d((1, 1))
    pooled_feats_flat = [avg_pool(pf).view(num_rois, -1) for pf in pooled_feats]
    final_feats = torch.cat(pooled_feats_flat, dim=1)

    return [final_feats], pooled_feats

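# Example (a minimal sketch): once `result.feats` has been populated by
# get_result_with_features_yolov10_simple() below, per-object vectors can be
# compared directly. The indices 0 and 1 are hypothetical and assume at least
# two detections:
#
#   import torch.nn.functional as F
#   sim = F.cosine_similarity(result.feats[0:1], result.feats[1:2])
#   print(f"cosine similarity between objects 0 and 1: {sim.item():.3f}")
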
def get_result_with_features_yolov10_simple(model, imgs, embed_layers, conf=0.25):
    """
    Simplified approach: run standard YOLO inference first, then re-run the
    network with `embed` set to collect feature maps and pool per-box features.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]

    # Standard inference; this also instantiates model.predictor.
    results = model(imgs, verbose=False, conf=conf)

    for result in results:
        if hasattr(result, 'boxes') and len(result.boxes) > 0:
            # Re-create the exact (letterboxed) network input for this image.
            prepped = model.predictor.preprocess([result.orig_img])

            # Temporarily enable feature embedding, run inference, then restore.
            prev_embed = getattr(model.predictor.args, "embed", None)
            model.predictor.args.embed = embed_layers
            features = model.predictor.inference(prepped)
            model.predictor.args.embed = prev_embed

            # The last entry is the Detect layer's output; the rest are the
            # P3/P4/P5 feature maps we want to pool from.
            feature_maps = features[:-1]

            # Map boxes from original-image coordinates into the letterboxed
            # input's coordinate space (scale, then center padding), mirroring
            # ultralytics' LetterBox preprocessing. Note that scale_boxes()
            # maps in the opposite direction (network input -> original image),
            # so the letterbox transform is applied explicitly here.
            orig_h, orig_w = result.orig_img.shape[:2]
            in_h, in_w = prepped.shape[2:]
            gain = min(in_h / orig_h, in_w / orig_w)
            pad_w = (in_w - orig_w * gain) / 2
            pad_h = (in_h - orig_h * gain) / 2
            boxes_for_features = result.boxes.xyxy.clone()
            boxes_for_features[:, [0, 2]] = boxes_for_features[:, [0, 2]] * gain + pad_w
            boxes_for_features[:, [1, 3]] = boxes_for_features[:, [1, 3]] * gain + pad_h

            dummy_idxs = [torch.arange(len(boxes_for_features))]

            obj_feats, pooled_feats = get_yolov10_object_features_with_pooler(
                feature_maps, dummy_idxs, boxes_for_features, result.orig_img.shape
            )

            # Attach the pooled features to the result object for downstream use.
            result.feats = obj_feats[0] if obj_feats else torch.empty(0)
            result.pooled_feats = pooled_feats

    return results

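# Example (a sketch of the "feature database" use case from the module
# docstring; the file names are placeholders):
#
#   results = get_result_with_features_yolov10_simple(model, "image.jpg", embed_layers)
#   feats = results[0].feats  # (num_objects, C) tensor, one row per detection
#   np.save("image_feats.npy", feats.cpu().numpy())
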
def draw_debug_image(img, boxes, class_names, save_path="debug_detections.png", highlight_idx=None):
    """Draw bounding boxes on the original image for debugging."""
    debug_img = img.copy()
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box.cpu().numpy().astype(int)

        # Clip the box to the image bounds.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2)

        # Highlighted box in red, all others in green (BGR).
        color = (0, 0, 255) if i == highlight_idx else (0, 255, 0)
        thickness = 3 if i == highlight_idx else 2

        cv2.rectangle(debug_img, (x1, y1), (x2, y2), color, thickness)
        cv2.putText(debug_img, f"{class_names[i]} #{i}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

    cv2.imwrite(save_path, debug_img)
    print(f"Debug image with bounding boxes saved to {save_path}")
    return debug_img

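# Example (sketch): inspect raw detections with the helper above, highlighting
# object 0 in red:
#
#   draw_debug_image(result.orig_img, result.boxes.xyxy,
#                    [model.model.names[int(c)] for c in result.boxes.cls],
#                    highlight_idx=0)
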
def draw_feature_heatmap(image, box, feature_map):
    """
    Draws a feature map as a heatmap on a specific region of an image.
    """
    feature_map = feature_map.detach().cpu()

    # Average across channels to get a single 2D activation map.
    heatmap = torch.mean(feature_map, dim=0).numpy()

    # Normalize to [0, 255]; if the map is constant, fall back to zeros so the
    # uint8 conversion below stays well defined.
    h_min, h_max = np.min(heatmap), np.max(heatmap)
    if h_max > h_min:
        heatmap = (heatmap - h_min) / (h_max - h_min)
    else:
        heatmap = np.zeros_like(heatmap)
    heatmap = (heatmap * 255).astype(np.uint8)

    # Clip the box to the image bounds.
    x1, y1, x2, y2 = box.cpu().numpy().astype(int)
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)

    bbox_w, bbox_h = x2 - x1, y2 - y1
    if bbox_w <= 0 or bbox_h <= 0:
        return image

    # Resize the heatmap to the box, colorize it, and blend it onto the ROI.
    heatmap_resized = cv2.resize(heatmap, (bbox_w, bbox_h), interpolation=cv2.INTER_LINEAR)
    heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
    roi = image[y1:y2, x1:x2]
    overlay = cv2.addWeighted(roi, 0.6, heatmap_colored, 0.4, 0)

    output_image = image.copy()
    output_image[y1:y2, x1:x2] = overlay

    return output_image

def draw_filled_rounded_rectangle(img, pt1, pt2, color, radius):
    """Draws a filled rounded rectangle."""
    x1, y1 = pt1
    x2, y2 = pt2

    # Four filled circles form the rounded corners...
    cv2.circle(img, (x1 + radius, y1 + radius), radius, color, -1)
    cv2.circle(img, (x2 - radius, y1 + radius), radius, color, -1)
    cv2.circle(img, (x1 + radius, y2 - radius), radius, color, -1)
    cv2.circle(img, (x2 - radius, y2 - radius), radius, color, -1)

    # ...and two overlapping rectangles fill the body.
    cv2.rectangle(img, (x1 + radius, y1), (x2 - radius, y2), color, -1)
    cv2.rectangle(img, (x1, y1 + radius), (x2, y2 - radius), color, -1)

def draw_modern_bbox(image, box, label, color):
    """Draws a modern-style bounding box with a semi-transparent, rounded label."""
    x1, y1, x2, y2 = box.astype(int)

    cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness=2)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1
    (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, font_thickness)

    # Place the label above the box; if it would leave the image, move it inside.
    label_bg_pt1 = (x1, y1 - text_h - 15)
    label_bg_pt2 = (x1 + text_w + 10, y1)
    if label_bg_pt1[1] < 0:
        label_bg_pt1 = (x1, y1 + 5)
        label_bg_pt2 = (x1 + text_w + 10, y1 + text_h + 20)

    # Blend a rounded label background onto the image for a semi-transparent look.
    overlay = image.copy()
    draw_filled_rounded_rectangle(overlay, label_bg_pt1, label_bg_pt2, color, radius=8)
    alpha = 0.6
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)

    text_pt = (label_bg_pt1[0] + 5, label_bg_pt1[1] + text_h + 5)
    cv2.putText(image, label, text_pt, font, font_scale, (0, 0, 0), font_thickness, cv2.LINE_AA)

def generate_feature_heatmaps(model, img_path, embed_layers, output_dir="./", conf=0.25):
    """
    Generates a single composite image containing the main image with bounding
    boxes and separate heatmap snippets for each detected object.

    Args:
        model: YOLOv10 model.
        img_path: Path to the input image.
        embed_layers: List of layer indices to extract features from.
        output_dir: Directory to save outputs.
        conf: Object detection confidence threshold.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image at {img_path}")

    print(f"Processing image: {img_path}")

    results_with_feat = get_result_with_features_yolov10_simple(model, img_path, embed_layers, conf=conf)

    if not results_with_feat or not isinstance(results_with_feat, list) or len(results_with_feat) == 0:
        print("No results returned.")
        return

    result = results_with_feat[0]
    if not hasattr(result, 'boxes') or len(result.boxes) == 0:
        print("No objects detected in the image.")
        return

    num_objects = len(result.boxes)
    print(f"Total objects detected: {num_objects}. Generating composite layout...")

    all_class_names = [model.model.names[int(cls)] for cls in result.boxes.cls]

    # Draw styled boxes on a copy of the original image, cycling through colors.
    main_image_with_boxes = img.copy()
    colors = [(71, 224, 253), (159, 128, 255), (159, 227, 128), (255, 191, 0), (255, 165, 0), (255, 0, 255)]
    for i in range(num_objects):
        label = f"{all_class_names[i]} {result.boxes.conf[i]:.2f}"
        color = colors[i % len(colors)]
        draw_modern_bbox(main_image_with_boxes, result.boxes.xyxy[i].cpu().numpy(), label, color)

    # Build one labeled heatmap snippet per object from the deepest feature level.
    heatmap_snippets = []
    if hasattr(result, 'pooled_feats') and result.pooled_feats:
        last_layer_pooled_feats = result.pooled_feats[-1]
        for i in range(num_objects):
            box = result.boxes.xyxy[i]
            feature_map = last_layer_pooled_feats[i]

            heatmap_on_full = draw_feature_heatmap(img.copy(), box, feature_map)
            x1, y1, x2, y2 = box.cpu().numpy().astype(int)
            # Clip to the image bounds before cropping, and skip degenerate boxes.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2)
            if x2 <= x1 or y2 <= y1:
                continue
            snippet = heatmap_on_full[y1:y2, x1:x2]

            label_text = f"Obj #{i}: {all_class_names[i]}"
            font = cv2.FONT_HERSHEY_SIMPLEX
            (text_w, text_h), _ = cv2.getTextSize(label_text, font, 0.6, 1)

            h, w, _ = snippet.shape

            # Paste the snippet onto a white canvas wide enough for its label.
            new_w = max(w, text_w + 10)
            snippet_with_label = np.full((h + text_h + 15, new_w, 3), 255, dtype=np.uint8)
            paste_x = (new_w - w) // 2
            snippet_with_label[0:h, paste_x:paste_x + w] = snippet

            # Centered label below the snippet, plus a thin gray border.
            text_x = (new_w - text_w) // 2
            cv2.putText(snippet_with_label, label_text, (text_x, h + text_h + 5), font, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
            cv2.rectangle(snippet_with_label, (0, 0), (new_w - 1, h + text_h + 14), (180, 180, 180), 1)
            heatmap_snippets.append(snippet_with_label)

    if not heatmap_snippets:
        print("No heatmaps generated. Saving image with bounding boxes only.")
        image_name = Path(img_path).stem
        save_path = Path(output_dir) / f"{image_name}_layout.png"
        cv2.imwrite(str(save_path), main_image_with_boxes)
        return

    # Lay the snippets out in a single horizontal row, vertically centered.
    main_h, main_w, _ = main_image_with_boxes.shape
    padding = 20

    snippets_row_h = max(s.shape[0] for s in heatmap_snippets)
    total_snippets_w = sum(s.shape[1] for s in heatmap_snippets) + (len(heatmap_snippets) - 1) * 10

    snippets_row = np.full((snippets_row_h, total_snippets_w, 3), 255, dtype=np.uint8)
    current_x = 0
    for snippet in heatmap_snippets:
        h, w, _ = snippet.shape
        paste_y = (snippets_row_h - h) // 2
        snippets_row[paste_y:paste_y + h, current_x:current_x + w] = snippet
        current_x += w + 10

    # Stack the main image above the snippet row on a padded white canvas.
    canvas_h = main_h + snippets_row_h + 3 * padding
    canvas_w = max(main_w, total_snippets_w) + 2 * padding
    final_image = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)

    x_offset_main = (canvas_w - main_w) // 2
    final_image[padding:padding + main_h, x_offset_main:x_offset_main + main_w] = main_image_with_boxes

    x_offset_snippets = (canvas_w - total_snippets_w) // 2
    y_offset_snippets = main_h + 2 * padding
    final_image[y_offset_snippets:y_offset_snippets + snippets_row_h, x_offset_snippets:x_offset_snippets + total_snippets_w] = snippets_row

    image_name = Path(img_path).stem
    heatmap_path = Path(output_dir) / f"{image_name}_heatmap_layout.png"
    cv2.imwrite(str(heatmap_path), final_image)
    print(f" - Saved composite heatmap layout to: {heatmap_path}")

def main():
    parser = argparse.ArgumentParser(description='Generate a composite feature heatmap for all detected objects in an image or a directory of images.')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--image', '-i', type=str, help='Path to a single input image.')
    group.add_argument('--input-dir', '-d', type=str, help='Path to a directory of input images.')

    parser.add_argument('--model', '-m', type=str, default='yolov10n.pt', help='Path to YOLOv10 model.')
    parser.add_argument('--output', '-o', type=str, default='./heatmaps', help='Output directory for generated layouts.')
    parser.add_argument('--conf', type=float, default=0.25, help='Object detection confidence threshold (e.g., 0.1 for more detections).')

    args = parser.parse_args()

    Path(args.output).mkdir(parents=True, exist_ok=True)

    print(f"Loading model: {args.model}")
    model = YOLO(args.model)

    # Patch the model so that embedding returns raw feature maps (see _predict_once).
    model.model._predict_once = MethodType(_predict_once, model.model)

    # Warm-up inference on a dummy image so model.predictor is instantiated.
    model(np.zeros((640, 640, 3), dtype=np.uint8), verbose=False)

    # Locate the Detect head; its input layers are the feature maps to embed.
    detect_layer_index = -1
    for i, m in enumerate(model.model.model):
        if 'Detect' in type(m).__name__:
            detect_layer_index = i
            break

    if detect_layer_index != -1:
        input_layers_indices = model.model.model[detect_layer_index].f
        embed_layers = sorted(input_layers_indices) + [detect_layer_index]
        print(f"Auto-detected feature layers at indices: {input_layers_indices}")
        print(f"Embedding features from layers: {embed_layers}")
    else:
        print("Could not find Detect layer, falling back to hardcoded indices")
        embed_layers = [16, 19, 22, 23]

    if args.input_dir:
        input_path = Path(args.input_dir)
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tif', '*.tiff']
        image_files = []
        for ext in image_extensions:
            image_files.extend(input_path.glob(ext))

        if not image_files:
            print(f"No images found in '{args.input_dir}'.")
            return

        print(f"\nFound {len(image_files)} images in '{args.input_dir}'. Processing...")
        for img_path in image_files:
            generate_feature_heatmaps(
                model=model,
                img_path=str(img_path),
                embed_layers=embed_layers,
                output_dir=args.output,
                conf=args.conf
            )
    else:
        generate_feature_heatmaps(
            model=model,
            img_path=args.image,
            embed_layers=embed_layers,
            output_dir=args.output,
            conf=args.conf
        )

    print(f"\nProcessing complete. All layouts saved to '{args.output}'.")

if __name__ == "__main__":
    import sys
    if len(sys.argv) == 1:
        # Test mode: no CLI arguments, so run heatmap generation on a sample image.
        print("No arguments provided. Running heatmap generation on a test image.")

        print("Loading default model: yolov10n.pt")
        model = YOLO('yolov10n.pt')
        model.model._predict_once = MethodType(_predict_once, model.model)
        model(np.zeros((640, 640, 3), dtype=np.uint8), verbose=False)

        # Same Detect-layer discovery as in main().
        detect_layer_index = -1
        for i, m in enumerate(model.model.model):
            if 'Detect' in type(m).__name__:
                detect_layer_index = i
                break

        if detect_layer_index != -1:
            input_layers_indices = model.model.model[detect_layer_index].f
            embed_layers = sorted(input_layers_indices) + [detect_layer_index]
            print(f"Auto-detected feature layers at indices: {input_layers_indices}")
        else:
            embed_layers = [16, 19, 22, 23]

        # Adjust this path to a local test image.
        img_path = "/home/hew/yolov10FX_obj/id-1.jpg"

        print("Using a lower confidence of 0.1 for test mode to find more objects.")
        generate_feature_heatmaps(
            model=model,
            img_path=img_path,
            embed_layers=embed_layers,
            output_dir="./",
            conf=0.1
        )
        print("\nHeatmap generation completed successfully for test image!")
    else:
        main()