import torch
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import draw_segmentation_masks


def draw_image(image, masks, boxes, labels, alpha=0.4):
    """Overlay bounding boxes and segmentation masks on an image.

    Args:
        image: NumPy array indexed as (height, width, channels); it is
            permuted to channels-first before being handed to torchvision.
            NOTE(review): torchvision's drawing helpers presumably expect
            a uint8 image -- confirm against the installed version.
        masks: Segmentation masks, one per instance, as a tensor or
            array-like. ``None`` or an empty sequence draws no masks.
            Non-boolean masks are cast to bool, so any non-zero value
            counts as foreground.
        boxes: Bounding boxes, one row per box, as a tensor or array-like.
            ``None`` or an empty sequence draws no boxes.
        labels: Passed through unchanged as ``labels=`` to
            ``draw_bounding_boxes``; only used when there are boxes.
        alpha: Passed through as ``alpha=`` to
            ``draw_segmentation_masks``.

    Returns:
        NumPy array in (height, width, channels) layout with the overlays
        drawn in red (boxes) and cyan (masks). When nothing is drawn the
        result may share memory with ``image``.
    """
    # torch.from_numpy rejects arrays with negative strides (e.g. a
    # channel flip such as image[..., ::-1]); copy only in that case so
    # the common path stays zero-copy. getattr keeps torch's own
    # TypeError for inputs that are not arrays at all.
    if any(stride < 0 for stride in getattr(image, "strides", ())):
        image = image.copy()

    canvas = torch.from_numpy(image).permute(2, 0, 1)

    if boxes is not None and len(boxes) > 0:
        # Accept NumPy arrays / nested lists as well as tensors.
        box_tensor = torch.as_tensor(boxes)
        canvas = draw_bounding_boxes(
            canvas,
            box_tensor,
            colors=['red'] * len(box_tensor),
            labels=labels,
            width=2,
        )

    if masks is not None and len(masks) > 0:
        # draw_segmentation_masks only accepts boolean masks.
        mask_tensor = torch.as_tensor(masks).to(torch.bool)
        canvas = draw_segmentation_masks(
            canvas,
            masks=mask_tensor,
            colors=['cyan'] * len(mask_tensor),
            alpha=alpha,
        )

    return canvas.numpy().transpose(1, 2, 0)