Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import dataclasses | |
import torch.nn.functional as F | |
from dataclasses import dataclass | |
from typing import Any, Optional, Dict | |
@dataclass(eq=False)
class VideoData:
    """
    Dataclass for storing video tracks data.

    The ``@dataclass`` decorator generates the keyword ``__init__`` that the
    collate functions in this file rely on, and makes instances usable with
    ``dataclasses.fields``. ``eq=False`` keeps the default identity-based
    ``__eq__``/``__hash__``: a generated field-wise ``__eq__`` would compare
    tensors element-wise, which has no single truth value for multi-element
    tensors.

    The shape comments below use the file's own B/S/C/H/W/N notation
    (presumably batch, sequence length, channels, height, width and number
    of tracked points -- confirm against the dataset code).
    """

    video: torch.Tensor  # B,S,C,H,W
    trajs: torch.Tensor  # B,S,N,2
    visibs: torch.Tensor  # B,S,N
    valids: Optional[torch.Tensor] = None  # B,S,N
    # NOTE(review): the collate functions in this file store a list of names
    # (one per sample) in `seq_name` and `dname`, not a single string.
    seq_name: Optional[str] = None
    dname: Optional[str] = None
    aug_video: Optional[torch.Tensor] = None
def collate_fn(batch):
    """
    Merge a sequence of per-sample records into one batched VideoData.

    Args:
        batch: Samples exposing `video`, `trajs`, `visibs`, `seq_name` and
            `dname` attributes. It is iterated several times.
    Returns:
        A VideoData whose `video`, `trajs` and `visibs` are the per-sample
        tensors stacked along a new leading dimension, and whose `seq_name`
        and `dname` are lists with one entry per sample. The remaining
        fields keep their defaults.
    """

    def _stack(attr):
        # torch.stack requires all per-sample tensors to share one shape.
        return torch.stack([getattr(sample, attr) for sample in batch], dim=0)

    return VideoData(
        video=_stack("video"),
        trajs=_stack("trajs"),
        visibs=_stack("visibs"),
        seq_name=[sample.seq_name for sample in batch],
        dname=[sample.dname for sample in batch],
    )
def collate_fn_train(batch):
    """
    Merge `(sample, gotit)` pairs into a batched VideoData for training.

    Args:
        batch: Sequence of 2-tuples. The first element exposes `video`,
            `trajs`, `visibs`, `valids`, `seq_name` and `dname`; the second
            is passed through untouched (presumably a flag telling whether
            the sample was loaded successfully -- confirm against the
            dataset code).
    Returns:
        A tuple `(data, gotit)`: `data` is a VideoData with the tensor
        fields stacked along a new leading dimension and the name fields
        gathered into lists; `gotit` is the list of second tuple elements
        in batch order.
    """
    flags = [flag for _, flag in batch]
    samples = [sample for sample, _ in batch]

    def _stack(attr):
        # torch.stack requires all per-sample tensors to share one shape.
        return torch.stack([getattr(sample, attr) for sample in samples], dim=0)

    data = VideoData(
        video=_stack("video"),
        trajs=_stack("trajs"),
        visibs=_stack("visibs"),
        valids=_stack("valids"),
        seq_name=[sample.seq_name for sample in samples],
        dname=[sample.dname for sample in samples],
    )
    return data, flags
def try_to_cuda(t: Any) -> Any:
    """
    Return `t` cast to float and moved to a cuda device when it supports that.

    Args:
        t: Any object. Objects exposing `.float()` and `.cuda()` (such as
            tensors) are converted; everything else is passed through.
    Returns:
        `t.float().cuda()` when those calls succeed, otherwise `t` itself
        unchanged. Only AttributeError is treated as "not supported"; any
        other error raised by the conversion propagates to the caller.
    """
    try:
        moved = t.float().cuda()
    except AttributeError:
        # Not tensor-like (e.g. str, list, None): hand it back as-is.
        return t
    return moved
def dataclass_to_cuda_(obj):
    """
    Move every field of a dataclass instance to cuda in place, if supported.

    Args:
        obj: Dataclass instance; each of its fields is reassigned.
    Returns:
        The same `obj`, with every field replaced by the result of
        `try_to_cuda` applied to its previous value.
    """
    for field in dataclasses.fields(obj):
        current = getattr(obj, field.name)
        setattr(obj, field.name, try_to_cuda(current))
    return obj