""" | |
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. | |
""" | |
import os | |
import sys | |
import cv2 # Added for video processing | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as T | |
from PIL import Image, ImageDraw | |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))) | |
from src.core import YAMLConfig | |
# Human-readable names for the model's 1-indexed detection class IDs
CLASS_NAMES = {
    1: "Angular Leafspot",
    2: "Leafspot",
    3: "Anthracnose Fruit Rot",
    4: "Blossom Blight",
    5: "Gray Mold",
    6: "Powdery Mildew Fruit",
    7: "Powdery Mildew Leaf",
}

# Box/label color per class ID; draw() falls back to white for unknown IDs
color_map = {
    1: "cyan",
    2: "blue",
    3: "green",
    4: "orange",
    5: "purple",
    6: "yellow",
    7: "pink",
}

def draw(images, labels, boxes, scores, thrh=0.4, save_path=None):
    for i, im in enumerate(images):
        drawer = ImageDraw.Draw(im)
        scr = scores[i]
        # Keep only detections above the confidence threshold
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]
        scrs = scr[scr > thrh]
        for j, b in enumerate(box):
            cls_id = int(lab[j].item())
            color = color_map.get(cls_id, "white")
            label_name = CLASS_NAMES.get(cls_id, "Unknown")
            score = round(scrs[j].item(), 2)
            x0, y0, x1, y1 = b.tolist()
            drawer.rectangle([x0, y0, x1, y1], outline=color)
            drawer.text(
                (x0, y0),
                text=f"{label_name} {score}",
                fill=color,
            )
        # Only write a file when a path is given, so per-frame video calls
        # do not overwrite the same JPEG on every iteration
        if save_path is not None:
            im.save(save_path)

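# process_image() saves an annotated JPEG for a single image, while
# process_video() reuses draw() per frame and writes an annotated MP4.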
def process_image(model, device, file_path):
    im_pil = Image.open(file_path).convert("RGB")
    w, h = im_pil.size
    # The postprocessor uses the original size to rescale boxes back to it
    orig_size = torch.tensor([[w, h]]).to(device)

    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    im_data = transforms(im_pil).unsqueeze(0).to(device)

    labels, boxes, scores = model(im_data, orig_size)
    draw([im_pil], labels, boxes, scores, save_path="torch_results.jpg")

def process_video(model, device, file_path):
    cap = cv2.VideoCapture(file_path)

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Define the codec and create a VideoWriter for the annotated output
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter("torch_results.mp4", fourcc, fps, (orig_w, orig_h))

    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )

    frame_count = 0
    print("Processing video frames...")
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Convert the BGR OpenCV frame to an RGB PIL image
        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        w, h = frame_pil.size
        orig_size = torch.tensor([[w, h]]).to(device)
        im_data = transforms(frame_pil).unsqueeze(0).to(device)

        labels, boxes, scores = model(im_data, orig_size)

        # Draw detections on the frame in place; no per-frame JPEG is written
        draw([frame_pil], labels, boxes, scores)

        # Convert back to an OpenCV (BGR) image and write it to the output video
        frame = cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR)
        out.write(frame)

        frame_count += 1
        if frame_count % 10 == 0:
            print(f"Processed {frame_count} frames...")

    cap.release()
    out.release()
    print("Video processing complete. Result saved as 'torch_results.mp4'.")

def main(args):
    """Main function"""
    cfg = YAMLConfig(args.config, resume=args.resume)

    if "HGNetv2" in cfg.yaml_cfg:
        cfg.yaml_cfg["HGNetv2"]["pretrained"] = False

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        # Prefer the EMA weights when the checkpoint contains them
        if "ema" in checkpoint:
            state = checkpoint["ema"]["module"]
        else:
            state = checkpoint["model"]
    else:
        raise AttributeError("Only resuming from a checkpoint is supported for loading model.state_dict.")

    # Load the train-mode state, then convert model and postprocessor to deploy mode
    cfg.model.load_state_dict(state)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    device = args.device
    model = Model().to(device)

    # Dispatch on the file extension: known image types go through the image
    # path, everything else is treated as a video
    file_path = args.input
    if os.path.splitext(file_path)[-1].lower() in [".jpg", ".jpeg", ".png", ".bmp"]:
        process_image(model, device, file_path)
        print("Image processing complete. Result saved as 'torch_results.jpg'.")
    else:
        process_video(model, device, file_path)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, required=True, help="path to the model YAML config")
    parser.add_argument("-r", "--resume", type=str, required=True, help="path to the checkpoint to load")
    parser.add_argument("-i", "--input", type=str, required=True, help="input image or video file")
    parser.add_argument("-d", "--device", type=str, default="cpu", help="inference device, e.g. cpu or cuda:0")
    args = parser.parse_args()
    main(args)
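
# Example invocation (script and file names below are placeholders, not
# files shipped with this repo):
#
#   python <this_script>.py -c <config>.yml -r <checkpoint>.pth -i <image_or_video> -d cuda:0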