deki / yolo_script.py
orasul's picture
Load initial app
6ff22d6
raw
history blame
6.56 kB
import ultralytics
import cv2
from ultralytics import YOLO
import os
import glob
import argparse
import sys
import numpy as np
import uuid
def iou(box1, box2):
    """Return the intersection-over-union of two centre-format boxes.

    Each box is indexable as ``(x_center, y_center, width, height)``; only
    the first four entries are read. The result is ``0`` when the union of
    the two boxes has no area.
    """
    def corners(box):
        # Centre/size -> (left, top, right, bottom).
        cx, cy, w, h = box[0], box[1], box[2], box[3]
        return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

    left_a, top_a, right_a, bottom_a = corners(box1)
    left_b, top_b, right_b, bottom_b = corners(box2)

    # Clamp at zero so disjoint boxes contribute no intersection.
    overlap_w = max(0, min(right_a, right_b) - max(left_a, left_b))
    overlap_h = max(0, min(bottom_a, bottom_b) - max(top_a, top_b))
    intersection = overlap_w * overlap_h

    area_a = (right_a - left_a) * (bottom_a - top_a)
    area_b = (right_b - left_b) * (bottom_b - top_b)
    union = area_a + area_b - intersection

    if union > 0:
        return intersection / union
    return 0
def load_image(input_image, base_name: str = None):
    """Load an image with OpenCV and settle on a base name for it.

    Args:
        input_image: A filesystem path (``str``), raw encoded image
            ``bytes``, or any other object, which is then read through its
            ``read()`` method.
        base_name: Name to associate with the image. When ``None`` it
            defaults to the file name without its extension for path
            inputs, or to a fresh UUID4 string for in-memory inputs.

    Returns:
        An ``(image, base_name)`` tuple.

    Raises:
        ValueError: If OpenCV cannot read the path or decode the bytes.
    """
    from_path = isinstance(input_image, str)

    if from_path:
        image = cv2.imread(input_image)
        if image is None:
            raise ValueError(f"Unable to load image from path: {input_image}")
    else:
        # Anything that is not a path is either bytes already or readable.
        payload = input_image if isinstance(input_image, bytes) else input_image.read()
        image = cv2.imdecode(np.frombuffer(payload, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError("Unable to decode image from input bytes or file-like object.")

    # The default name is only worked out once the image is known to be valid.
    if base_name is None:
        if from_path:
            base_name = os.path.splitext(os.path.basename(input_image))[0]
        else:
            base_name = str(uuid.uuid4())
    return image, base_name
def _parse_label_lines(lines):
    """Turn YOLO label lines into box dicts, skipping blank or truncated lines."""
    boxes = []
    for idx, line in enumerate(lines):
        tokens = line.split()
        if len(tokens) < 5:
            # Blank or truncated line: there is no complete box to parse.
            continue
        boxes.append({
            'class_id': int(tokens[0]),
            # x_center, y_center, width, height
            'bbox': [float(token) for token in tokens[1:5]],
            # Normalise the terminator so kept lines never run together
            # when they are written back.
            'line': line.rstrip('\r\n') + '\n',
            'index': idx,
        })
    return boxes


def _suppress_same_class_overlaps(boxes, iou_threshold=0.7):
    """Greedily drop later same-class boxes whose IoU with a kept box exceeds the threshold.

    Returns ``(keep_indices, num_removed)``; indices refer to ``boxes``.
    """
    keep_indices = []
    suppressed = [False] * len(boxes)
    num_removed = 0
    for i, anchor in enumerate(boxes):
        if suppressed[i]:
            continue
        keep_indices.append(i)
        for j in range(i + 1, len(boxes)):
            if suppressed[j]:
                continue
            candidate = boxes[j]
            if (anchor['class_id'] == candidate['class_id']
                    and iou(anchor['bbox'], candidate['bbox']) > iou_threshold):
                suppressed[j] = True
                num_removed += 1
    return keep_indices, num_removed


def _draw_numbered_boxes(image, boxes):
    """Return a copy of ``image`` with each box outlined and numbered from 1."""
    canvas = image.copy()
    h_img, w_img = canvas.shape[:2]
    for number, box in enumerate(boxes, start=1):
        x_center, y_center, w_norm, h_norm = box['bbox']
        # Box coordinates are fractions of the image size; scale to pixels.
        x_center *= w_img
        y_center *= h_img
        w_box = w_norm * w_img
        h_box = h_norm * h_img
        x1 = int(x_center - w_box / 2)
        y1 = int(y_center - h_box / 2)
        x2 = int(x_center + w_box / 2)
        y2 = int(y_center + h_box / 2)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(canvas, str(number), (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (143, 10, 18), 1)
    return canvas


def process_yolo(input_image, weights_file: str, output_dir: str = './yolo_run', model_obj: YOLO = None, base_name: str = None) -> str:
    """Run YOLO on one image, drop near-duplicate boxes and save annotated copies.

    The model is called with ``save_txt=True`` so that it writes a label
    file under ``<output_dir>/labels``. That file is read back, the boxes
    are ordered by their top edge, same-class boxes overlapping an earlier
    box with IoU above 0.7 are removed, and the label file is rewritten
    with the survivors. Two images are written to ``output_dir``:
    ``<base_name>_yolo<ext>`` (the model's own plot) and
    ``<base_name>_yolo_updated<ext>`` (the original image with the
    surviving boxes outlined and numbered).

    Args:
        input_image: Image path, raw image bytes, or a readable object, as
            accepted by ``load_image``.
        weights_file: Weights passed to ``YOLO(...)`` when ``model_obj`` is
            not given.
        output_dir: Directory for all outputs; created if missing.
        model_obj: Already constructed model to use instead of loading
            ``weights_file``.
        base_name: Stem for output file names; defaults as in
            ``load_image``.

    Returns:
        Path of the ``_yolo_updated`` image.

    Raises:
        ValueError: If the image cannot be loaded (from ``load_image``).
        FileNotFoundError: If no label file is available after inference.
    """
    orig_image, base_name = load_image(input_image, base_name)
    os.makedirs(output_dir, exist_ok=True)

    is_path = isinstance(input_image, str)
    # Reuse the input's extension when there is one; otherwise fall back to
    # .png, because an output name without an extension cannot be encoded
    # by cv2.imwrite.
    ext = os.path.splitext(input_image)[1] if is_path else ''
    if not ext:
        ext = '.png'

    # Pass the path through when we have one so YOLO sees the real file
    # name; in-memory inputs are handed over as the decoded array.
    source_input = input_image if is_path else orig_image
    model = YOLO(weights_file) if model_obj is None else model_obj
    results = model(
        source=source_input,
        save_txt=True,
        project=output_dir,
        name='.',
        exist_ok=True,
    )

    # Save the initial inference image.
    img_with_boxes = results[0].plot(font_size=2, line_width=1)
    output_image_path = os.path.join(output_dir, f"{base_name}_yolo{ext}")
    cv2.imwrite(output_image_path, img_with_boxes)
    print(f"Image saved as '{output_image_path}'")

    labels_dir = os.path.join(output_dir, 'labels')
    label_file = os.path.join(labels_dir, f"{base_name}.txt")

    # YOLO names the label file after its source, not after base_name, so
    # the two differ for in-memory inputs and for caller-supplied names.
    # NOTE(review): assumes Ultralytics writes labels/<stem of
    # results[0].path>.txt -- confirm against the installed version. If the
    # assumption does not hold, the lookup below behaves as before.
    result_path = getattr(results[0], 'path', None)
    if result_path:
        yolo_stem = os.path.splitext(os.path.basename(str(result_path)))[0]
        yolo_label_file = os.path.join(labels_dir, f"{yolo_stem}.txt")
        if yolo_stem != base_name and os.path.isfile(yolo_label_file):
            os.replace(yolo_label_file, label_file)

    if not os.path.isfile(label_file):
        raise FileNotFoundError(f"No label files found for the image '{base_name}' at path '{label_file}'.")

    with open(label_file, 'r') as f:
        boxes = _parse_label_lines(f)

    # Order boxes by their top edge; this fixes both which box of an
    # overlapping pair survives and the numbers drawn on the image.
    boxes.sort(key=lambda b: b['bbox'][1] - (b['bbox'][3] / 2))

    keep_indices, num_removed = _suppress_same_class_overlaps(boxes)

    with open(label_file, 'w') as f:
        f.writelines(boxes[idx]['line'] for idx in keep_indices)
    print(f"Number of bounding boxes removed: {num_removed}")

    # Draw the surviving boxes on the original image (loaded in memory).
    drawn_image = _draw_numbered_boxes(orig_image, [boxes[idx] for idx in keep_indices])
    updated_output_image_path = os.path.join(output_dir, f"{base_name}_yolo_updated{ext}")
    cv2.imwrite(updated_output_image_path, drawn_image)
    print(f"Updated image saved as '{updated_output_image_path}'")
    return updated_output_image_path
def _run_cli():
    """Parse the command line and run ``process_yolo`` once.

    Any exception raised by the pipeline is printed and converted into
    exit status 1.
    """
    cli = argparse.ArgumentParser(description='Process YOLO inference and NMS on an image.')
    cli.add_argument('input_image', help='Path to the input image or pass raw bytes via a file-like object.')
    cli.add_argument('weights_file', help='Path to the YOLO weights file.')
    cli.add_argument('output_dir', nargs='?', default='./yolo_run', help='Output directory (optional).')
    cli.add_argument('--base_name', help='Optional base name for output files (without extension).')
    options = cli.parse_args()

    try:
        process_yolo(
            options.input_image,
            options.weights_file,
            options.output_dir,
            base_name=options.base_name,
        )
    except Exception as err:
        print(err)
        sys.exit(1)


if __name__ == '__main__':
    _run_cli()