dangminh214's picture
Clean initial commit (no large files, no LFS pointers)
b26e93d
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import glob
import os
import torch
import torch.utils.data as data
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from PIL import Image
# Disable PIL's decompression-bomb pixel-count limit so arbitrarily large images
# can be opened/saved through PIL in this module without raising.
Image.MAX_IMAGE_PIXELS = None
class ToTensor(T.ToTensor):
    """``T.ToTensor`` that leaves inputs which are already tensors untouched.

    Images decoded with ``torchvision.io`` arrive as tensors, so they are returned
    as-is; anything else (e.g. a PIL image) goes through the regular conversion.
    """

    def __call__(self, pic):
        if not isinstance(pic, torch.Tensor):
            # Non-tensor input: defer to the stock torchvision conversion.
            return super().__call__(pic)
        return pic
class PadToSize(T.Pad):
    """Pad an image on its right and bottom edges up to a fixed ``(width, height)``.

    Args:
        size: target ``(width, height)``; ``size[0]`` is the width, ``size[1]`` the height.
        fill: pixel value used for the padded area when ``padding_mode`` is "constant".
        padding_mode: padding mode forwarded to ``F.pad``.
    """

    def __init__(self, size, fill=0, padding_mode="constant"):
        # The base-class padding amount is a placeholder (0); the real amount is
        # computed per image in __call__ from the image's current size.
        super().__init__(0, fill, padding_mode)
        self.size = size
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be padded.
        Returns:
            PIL Image or Tensor: Padded image, original content anchored top-left.
        """
        cur_w, cur_h = F.get_image_size(img)
        pad_right = self.size[0] - cur_w
        pad_bottom = self.size[1] - cur_h
        # F.pad takes (left, top, right, bottom); only right/bottom grow.
        # NOTE(review): if img is larger than size these amounts are negative;
        # torchvision's pad then crops instead — confirm for the installed version.
        return F.pad(img, (0, 0, pad_right, pad_bottom), self.fill, self.padding_mode)
class Dataset(data.Dataset):
    """Serve the ``*.jpg`` files of a directory as inference-ready tensor dicts.

    Files are read with ``torchvision.io`` and decoded as RGB directly on ``device``.
    The default preprocessing resizes the image (shorter edge to ``size - 1``,
    longer edge capped at ``size``), pads it on the bottom/right with value 114 to
    ``size x size``, and converts it to a float tensor.

    Args:
        img_dir: directory searched (non-recursively) for ``*.jpg`` files.
        preprocess: optional replacement for the default transform pipeline.
        device: device passed to ``torchvision.io.decode_jpeg``.
        size: side length of the square model input (default 640); also
            reported in ``im_shape`` and used for ``scale_factor``.
    """

    def __init__(
        self,
        img_dir: str = "",
        preprocess: T.Compose = None,
        device="cuda:0",
        size: int = 640,
    ) -> None:
        super().__init__()
        self.device = device
        self.size = size
        # Sorted so the iteration order is deterministic across runs/filesystems
        # (glob returns entries in arbitrary directory order).
        self.im_path_list = sorted(glob.glob(os.path.join(img_dir, "*.jpg")))
        if preprocess is None:
            self.preprocess = T.Compose(
                [
                    # size - 1 for the shorter edge: torchvision's Resize requires
                    # max_size to be strictly greater than an int size.
                    T.Resize(size=size - 1, max_size=size),
                    PadToSize(size=(size, size), fill=114),
                    ToTensor(),
                    T.ConvertImageDtype(torch.float),
                ]
            )
        else:
            self.preprocess = preprocess

    def __len__(self):
        return len(self.im_path_list)

    def __getitem__(self, index):
        """Load, decode and preprocess image ``index``.

        Returns:
            dict with
              - "images": the preprocessed image tensor,
              - "im_shape": ``[size, size]``,
              - "scale_factor": ``[size / h, size / w]`` from the original h, w,
              - "orig_target_sizes": original ``[w, h]``,
            all on the same device as the preprocessed image.
        """
        raw = torchvision.io.read_file(self.im_path_list[index])
        im = torchvision.io.decode_jpeg(
            raw, mode=torchvision.io.ImageReadMode.RGB, device=self.device
        )
        _, h, w = im.shape  # c, h, w of the decoded (un-resized) image
        im = self.preprocess(im)
        device = im.device
        blob = {
            "images": im,
            "im_shape": torch.tensor([self.size, self.size], device=device),
            # NOTE(review): per-axis ratio to the square input; with the default
            # aspect-preserving resize + pad this is not the scale actually applied.
            # Kept as-is for compatibility with downstream consumers — confirm.
            "scale_factor": torch.tensor([self.size / h, self.size / w], device=device),
            "orig_target_sizes": torch.tensor([w, h], device=device),
        }
        return blob

    @staticmethod
    def post_process():
        # Placeholder: not implemented here.
        pass

    @staticmethod
    def collate_fn():
        # Placeholder: not implemented here.
        pass
def draw_nms_result(blob, outputs, draw_score_threshold=0.25, name=""):
    """Draw score-filtered detection boxes on each batch image and save them as JPEGs.

    Args:
        blob: dict holding a batch of float images, indexed along dim 0 and each
            (C, H, W), presumably scaled to [0, 1]. Read from the ``"images"`` key
            (the key ``Dataset.__getitem__`` produces); the legacy ``"image"`` key
            is accepted as a fallback.
        outputs: NMS result dict with keys
            'num_dets', 'det_boxes', 'det_scores', 'det_classes';
            only 'det_boxes' and 'det_scores' are read here.
        draw_score_threshold: only boxes with score strictly above this are drawn.
        name: tag embedded in the output filenames ``test_{name}_{i}.jpg``.
    """
    images = blob["images"] if "images" in blob else blob["image"]
    for i, image in enumerate(images):
        det_scores = outputs["det_scores"][i]
        det_boxes = outputs["det_boxes"][i][det_scores > draw_score_threshold]
        # Clamp before the cast so values slightly outside [0, 1] don't wrap in uint8.
        im = (image * 255).clamp(0, 255).to(torch.uint8)
        im = torchvision.utils.draw_bounding_boxes(im, boxes=det_boxes, width=2)
        Image.fromarray(im.permute(1, 2, 0).cpu().numpy()).save(f"test_{name}_{i}.jpg")