Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,950 Bytes
e85fecb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import importlib.metadata
import re

from torch import Tensor


def _torchvision_release(version: str) -> tuple:
    """Return the leading dotted-numeric part of ``version`` as a tuple of ints.

    Anything after the numeric release segment (local build tags such as
    ``+cu118``, pre-release markers such as ``a0``) is ignored, e.g.
    ``"0.15.2+cu118"`` -> ``(0, 15, 2)``. An empty tuple is returned when the
    string does not start with a digit.
    """
    matched = re.match(r"\d+(?:\.\d+)*", version)
    if matched is None:
        return ()
    return tuple(int(part) for part in matched.group().split("."))


# Compare versions numerically. Plain string comparison is lexicographic, so
# e.g. "0.9.1" >= "0.17" is True and would select the wrong import branch.
_tv_release = _torchvision_release(importlib.metadata.version("torchvision"))

if (0, 15, 2) <= _tv_release < (0, 16):
    # This branch uses the ``datapoints`` module, the singular class names and
    # the ``spatial_size`` keyword; the names are aliased to the plural
    # spelling used by the other branches.
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision.datapoints import BoundingBox as BoundingBoxes
    from torchvision.datapoints import BoundingBoxFormat, Image, Mask, Video
    from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes

    _boxes_keys = ["format", "spatial_size"]
elif (0, 16) <= _tv_release < (0, 17):
    # This branch uses ``tv_tensors`` and the ``canvas_size`` keyword, and
    # still silences the beta-transforms warning.
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]
elif _tv_release >= (0, 17):
    # Same imports as the 0.16 branch, without the warning call.
    import torchvision
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]
else:
    raise RuntimeError("Please make sure torchvision version >= 0.15.2")
def convert_to_tv_tensor(tensor: Tensor, key: str, box_format="xyxy", spatial_size=None) -> Tensor:
    """Wrap ``tensor`` in the torchvision tensor subclass selected by ``key``.

    Args:
        tensor (Tensor): input tensor to wrap.
        key (str): kind of data to produce, either ``"boxes"`` or ``"masks"``.
        box_format (str): name of a ``BoundingBoxFormat`` member; it is
            upper-cased before lookup. Only used when ``key == "boxes"``.
        spatial_size: value passed to ``BoundingBoxes`` under the second
            keyword listed in ``_boxes_keys``. Only used when
            ``key == "boxes"``.

    Returns:
        A ``BoundingBoxes`` instance for ``"boxes"``, a ``Mask`` instance for
        ``"masks"``.

    Raises:
        AssertionError: if ``key`` is neither ``"boxes"`` nor ``"masks"``.
    """
    if key not in ("boxes", "masks"):
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; the exception type and message are kept for
        # callers that already rely on them.
        raise AssertionError("Only support 'boxes' and 'masks'")

    if key == "masks":
        return Mask(tensor)

    # key == "boxes": resolve the format name to its enum member, then pair the
    # values with the keyword names stored in ``_boxes_keys``.
    box_format = getattr(BoundingBoxFormat, box_format.upper())
    _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size]))
    return BoundingBoxes(tensor, **_kwargs)
|