dangminh214's picture
Clean initial commit (no large files, no LFS pointers)
b26e93d
"""
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""
import os
import sys
import cv2 # Added for video processing
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image, ImageDraw
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from src.core import YAMLConfig
# Per-class display name and drawing colour, listed in class-id order
# starting at 1. Both lookup tables below are derived from this one table so
# names and colours cannot drift out of sync.
# NOTE(review): ids start at 1 here; confirm the model emits 1-based labels,
# otherwise a label of 0 will not be found in either map.
_CLASS_TABLE = (
    ("Angular Leafspot", "cyan"),
    ("Leafspot", "blue"),
    ("Anthracnose Fruit Rot", "green"),
    ("Blossom Blight", "orange"),
    ("Gray Mold", "purple"),
    ("Powdery Mildew Fruit", "yellow"),
    ("Powdery Mildew Leaf", "pink"),
)

# class id -> human-readable class name
CLASS_NAMES = {idx: name for idx, (name, _) in enumerate(_CLASS_TABLE, start=1)}

# class id -> PIL colour name used for the box outline and label text
color_map = {idx: colour for idx, (_, colour) in enumerate(_CLASS_TABLE, start=1)}
def draw(images, labels, boxes, scores, thrh=0.4, save_path="torch_results.jpg"):
    """Draw detections whose score exceeds ``thrh`` onto PIL images, in place.

    Args:
        images: Sequence of PIL images; each one is modified in place.
        labels: Per-image class ids, indexable as ``labels[i]``.
        boxes: Per-image boxes, indexable as ``boxes[i]``; each box provides at
            least four coordinates, with the first two used as the text anchor.
        scores: Per-image confidence scores. ``scores[i] > thrh`` must produce a
            boolean mask usable to index ``labels[i]`` and ``boxes[i]``
            (e.g. torch tensors).
        thrh: Detections with ``score <= thrh`` are skipped.
        save_path: File each annotated image is written to. Every image is
            saved to the same path, so with several images only the last one
            survives. Pass ``None`` to skip writing to disk (useful when
            annotating many video frames).
    """
    for i, im in enumerate(images):
        canvas = ImageDraw.Draw(im)  # named to avoid shadowing this function
        keep = scores[i] > thrh  # build the mask once, reuse it for all three
        for label, box, score in zip(labels[i][keep], boxes[i][keep], scores[i][keep]):
            cls_id = int(label.item())
            color = color_map.get(cls_id, "white")
            label_name = CLASS_NAMES.get(cls_id, "Unknown")
            score_text = round(score.item(), 2)
            # Plain floats are what PIL expects for coordinates.
            coords = [float(v) for v in box]
            canvas.rectangle(coords, outline=color)
            canvas.text(
                (coords[0], coords[1]),
                text=f"{label_name} {score_text}",
                fill=color,
            )
        if save_path is not None:
            im.save(save_path)
def process_image(model, device, file_path):
    """Run the detector on a single image file and save the annotated result.

    The image is resized to 640x640 for the model, while the original
    ``(width, height)`` is passed alongside so the model/postprocessor can map
    boxes back to the source resolution. Annotated output is produced by
    ``draw``.

    Args:
        model: Callable ``model(images, orig_sizes)`` returning
            ``(labels, boxes, scores)``.
        device: Torch device the input tensors are moved to.
        file_path: Path to the input image.
    """
    # Context manager closes the underlying file; convert() returns a loaded copy.
    with Image.open(file_path) as src:
        im_pil = src.convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]]).to(device)

    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    im_data = transforms(im_pil).unsqueeze(0).to(device)  # add batch dimension

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        labels, boxes, scores = model(im_data, orig_size)

    draw([im_pil], labels, boxes, scores)
def process_video(model, device, file_path):
    """Run the detector on every frame of a video and write an annotated copy.

    The output is written to ``torch_results.mp4`` using the mp4v codec,
    at the source frame rate and frame size.

    Args:
        model: Callable ``model(images, orig_sizes)`` returning
            ``(labels, boxes, scores)``.
        device: Torch device the input tensors are moved to.
        file_path: Path to the input video.

    Raises:
        OSError: If the input video cannot be opened or the output video
            writer cannot be created.
    """
    output_path = "torch_results.mp4"

    cap = cv2.VideoCapture(file_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open video: {file_path}")

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # also catches NaN
        # Some containers report no frame rate; fall back so the writer can be created.
        fps = 30.0
    orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (orig_w, orig_h))
    if not out.isOpened():
        cap.release()
        out.release()
        raise OSError(f"Could not create output video: {output_path}")

    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )

    frame_count = 0
    print("Processing video frames...")
    try:
        # Inference only: skip autograd bookkeeping for every frame.
        with torch.no_grad():
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break  # end of stream (or a read failure)

                # OpenCV decodes frames as BGR; convert to RGB for PIL and the model.
                frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                w, h = frame_pil.size
                orig_size = torch.tensor([[w, h]]).to(device)
                im_data = transforms(frame_pil).unsqueeze(0).to(device)

                labels, boxes, scores = model(im_data, orig_size)

                # NOTE: draw() also saves the annotated frame to torch_results.jpg,
                # overwriting that file on every frame.
                draw([frame_pil], labels, boxes, scores)

                # Back to BGR for the OpenCV writer.
                out.write(cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR))

                frame_count += 1
                if frame_count % 10 == 0:
                    print(f"Processed {frame_count} frames...")
    finally:
        # Always release the capture and finalize the output file, even on error.
        cap.release()
        out.release()

    print(f"Video processing complete. Result saved as '{output_path}'.")
def main(args):
    """Load a trained checkpoint in deploy mode and run it on ``args.input``.

    Inputs with a ``.jpg``/``.jpeg``/``.png``/``.bmp`` extension are processed as
    images; anything else is processed as a video.

    Args:
        args: Parsed CLI namespace with ``config`` (YAML config path),
            ``resume`` (checkpoint path), ``input`` (image or video path) and
            ``device`` (torch device string).

    Raises:
        AttributeError: If ``args.resume`` is empty; only loading weights from
            a checkpoint is supported.
    """
    cfg = YAMLConfig(args.config, resume=args.resume)

    # Disable backbone pretrained-weight loading; presumably the checkpoint
    # loaded below already contains the trained backbone weights.
    if "HGNetv2" in cfg.yaml_cfg:
        cfg.yaml_cfg["HGNetv2"]["pretrained"] = False

    if not args.resume:
        raise AttributeError("Only support resume to load model.state_dict by now.")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location="cpu")
    # Prefer the EMA weights when the checkpoint carries them.
    if "ema" in checkpoint:
        state = checkpoint["ema"]["module"]
    else:
        state = checkpoint["model"]

    # Load train mode state and convert to deploy mode
    cfg.model.load_state_dict(state)

    class Model(nn.Module):
        """Chains the deploy-mode detector and its postprocessor in one forward."""

        def __init__(self):
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    device = args.device
    # Explicit eval() so inference behaviour does not depend on deploy() setting it.
    model = Model().to(device).eval()

    file_path = args.input
    image_exts = {".jpg", ".jpeg", ".png", ".bmp"}
    # Inference only: disable autograd for the whole run.
    with torch.no_grad():
        if os.path.splitext(file_path)[-1].lower() in image_exts:
            process_image(model, device, file_path)
            print("Image processing complete.")
        else:
            # Any non-image extension is treated as a video.
            process_video(model, device, file_path)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Run a D-FINE model on an image or video and save annotated results."
    )
    parser.add_argument(
        "-c", "--config", type=str, required=True,
        help="Path to the model YAML config.",
    )
    parser.add_argument(
        "-r", "--resume", type=str, required=True,
        help="Path to the trained checkpoint to load.",
    )
    parser.add_argument(
        "-i", "--input", type=str, required=True,
        help="Image (.jpg/.jpeg/.png/.bmp) or video file to process.",
    )
    parser.add_argument(
        "-d", "--device", type=str, default="cpu",
        help="Torch device, e.g. 'cpu' or 'cuda:0' (default: cpu).",
    )
    args = parser.parse_args()
    main(args)