Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,345 Bytes
77a88de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
@dataclass(eq=False)
class VideoData:
    """
    Dataclass for storing video tracks data.

    The same class holds both a single sample and a collated batch. The
    collate functions in this module stack tensors along a new dim 0, so the
    shapes below (which include the leading batch dim ``B``) describe a
    collated batch. A single sample presumably has the same layout without
    ``B`` (TODO confirm against the dataset code).

    ``eq=False`` keeps identity-based ``==`` / ``hash``. This avoids
    element-wise tensor comparison when instances are compared.
    """

    video: torch.Tensor  # B,S,C,H,W
    trajs: torch.Tensor  # B,S,N,2
    visibs: torch.Tensor  # B,S,N
    valids: Optional[torch.Tensor] = None  # B,S,N
    # A single string on a sample; the collate functions replace it with a
    # list holding one entry per sample.
    seq_name: Optional[Union[str, List[str]]] = None
    # Same convention as seq_name: a string per sample, a list after collation.
    dname: Optional[Union[str, List[str]]] = None
    # NOTE: neither collate function in this module copies this field, so it
    # is None on collated batches.
    aug_video: Optional[torch.Tensor] = None
def collate_fn(batch):
    """
    Merge a sequence of VideoData samples into one batched VideoData.

    Args:
        batch: Sequence of VideoData samples, presumably supplied by a
            DataLoader that uses this as its collate_fn. The video, trajs
            and visibs tensors of all samples must have matching shapes so
            they can be stacked.
    Returns:
        A VideoData with these fields set:
        - video, trajs, visibs: gain a new leading batch dimension.
        - seq_name, dname: plain lists with one entry per sample.
        - valids, aug_video: not passed, so they keep the VideoData defaults.
    """

    def _stack(attr):
        # Stack one tensor field across samples; new dim 0 indexes the sample.
        return torch.stack([getattr(sample, attr) for sample in batch], dim=0)

    return VideoData(
        video=_stack("video"),
        trajs=_stack("trajs"),
        visibs=_stack("visibs"),
        seq_name=[sample.seq_name for sample in batch],
        dname=[sample.dname for sample in batch],
    )
def collate_fn_train(batch):
    """
    Merge ``(VideoData, gotit)`` pairs into a batched VideoData for training.

    Args:
        batch: Sequence of 2-tuples. The first item is a VideoData sample
            whose video, trajs, visibs and valids tensors are all stacked, so
            valids must not be None. The second item (``gotit``) is passed
            through untouched.
    Returns:
        A 2-tuple ``(batched VideoData, list of gotit values)``:
        - video, trajs, visibs, valids: gain a new leading batch dimension.
        - seq_name, dname: lists with one entry per sample.
        - gotit values: kept in the same order as ``batch``.
    """
    flags = [flag for _, flag in batch]
    samples = [sample for sample, _ in batch]

    def _stack(attr):
        # Stack one tensor field across samples; new dim 0 indexes the sample.
        return torch.stack([getattr(s, attr) for s in samples], dim=0)

    collated = VideoData(
        video=_stack("video"),
        trajs=_stack("trajs"),
        visibs=_stack("visibs"),
        valids=_stack("valids"),
        seq_name=[s.seq_name for s in samples],
        dname=[s.dname for s in samples],
    )
    return collated, flags
def try_to_cuda(t: Any) -> Any:
    """
    Best-effort move of ``t`` to a CUDA device.

    Calls ``t.float().cuda()``. If either attribute lookup fails with
    ``AttributeError`` (e.g. ``t`` is None, a string or a list), the original
    ``t`` is returned unchanged.

    Only ``AttributeError`` is suppressed. Any other exception raised by
    ``.float()`` or ``.cuda()`` propagates to the caller. This presumably
    includes the error ``.cuda()`` raises when no CUDA device is available
    (TODO confirm).

    Args:
        t: Input.
    Returns:
        ``t`` cast with ``.float()`` and moved with ``.cuda()`` if supported,
        otherwise ``t`` itself.
    """
    try:
        as_float = t.float()
        moved = as_float.cuda()
    except AttributeError:
        return t
    return moved
def dataclass_to_cuda_(obj):
    """
    Move every field of a dataclass instance to a CUDA device, in place.

    Each field value is passed through ``try_to_cuda`` and written back onto
    ``obj`` with ``setattr``, so ``obj`` itself is mutated. Fields that
    ``try_to_cuda`` cannot move are left as they are; with the
    ``try_to_cuda`` in this module that means values lacking
    ``.float()``/``.cuda()``, such as None, strings or lists. Tensor fields
    are converted with ``.float()`` as well as moved.

    Args:
        obj: Dataclass instance (e.g. a VideoData batch).
    Returns:
        The same ``obj``, returned for call-chaining convenience.
    Raises:
        TypeError: From ``dataclasses.fields`` if ``obj`` is not a dataclass.
    """
    for field in dataclasses.fields(obj):
        name = field.name
        setattr(obj, name, try_to_cuda(getattr(obj, name)))
    return obj
|