File size: 2,345 Bytes
77a88de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F


@dataclass(eq=False)
class VideoData:
    """
    Dataclass for storing video clips together with their point-track data.

    The same class is used both for a single sample and for a collated batch.
    ``collate_fn`` / ``collate_fn_train`` build batches by stacking the
    per-sample tensors along a new dim 0, so the leading ``B`` in the shape
    comments below is only present on collated objects.

    ``eq=False`` keeps identity-based equality: a generated field-wise
    ``__eq__`` would compare tensors element-wise, which raises for
    multi-element tensors because the result has no single truth value.
    """

    video: torch.Tensor  # B,S,C,H,W
    trajs: torch.Tensor  # B,S,N,2
    visibs: torch.Tensor  # B,S,N
    valids: Optional[torch.Tensor] = None  # B,S,N
    # A single (optional) string on a per-sample item; after collation, a list
    # with one entry per sample (entries may themselves be None).
    seq_name: Optional[Union[str, List[Optional[str]]]] = None
    dname: Optional[Union[str, List[Optional[str]]]] = None
    # NOTE(review): neither collate_fn nor collate_fn_train forwards this
    # field, so it is None on collated batches; shape presumably mirrors
    # `video` — confirm with the producing dataset.
    aug_video: Optional[torch.Tensor] = None


def collate_fn(batch):
    """
    Merge a list of VideoData samples into one batched VideoData.

    The ``video``, ``trajs`` and ``visibs`` tensors of every sample are
    stacked along a new leading dimension (dim 0), so each of those fields
    must have the same shape across samples. ``seq_name`` and ``dname`` are
    gathered into plain lists, one entry per sample, in input order.

    Note: ``valids`` and ``aug_video`` are NOT carried over; they stay ``None``
    on the returned object.
    """
    stacked = {
        field_name: torch.stack([getattr(sample, field_name) for sample in batch], dim=0)
        for field_name in ("video", "trajs", "visibs")
    }
    return VideoData(
        seq_name=[sample.seq_name for sample in batch],
        dname=[sample.dname for sample in batch],
        **stacked,
    )


def collate_fn_train(batch):
    """
    Collate ``(VideoData, gotit)`` pairs for training.

    Args:
        batch: Sequence of 2-tuples ``(sample, gotit)``. ``gotit`` is passed
            through untouched (presumably a per-sample success flag from the
            dataset — confirm against the dataset that produces it).

    Returns:
        A tuple ``(batched, gotit)``:
        - ``batched``: VideoData whose ``video``, ``trajs``, ``visibs`` and
          ``valids`` are stacked along a new dim 0 (so every sample must carry
          a ``valids`` tensor, not ``None``), and whose ``seq_name`` /
          ``dname`` are per-sample lists.
        - ``gotit``: list of the second tuple elements, in input order.
    """
    got_flags = [flag for _, flag in batch]
    samples = [sample for sample, _ in batch]

    def stack_field(field_name):
        # Add the batch dimension in front of each per-sample tensor.
        return torch.stack([getattr(s, field_name) for s in samples], dim=0)

    batched = VideoData(
        video=stack_field("video"),
        trajs=stack_field("trajs"),
        visibs=stack_field("visibs"),
        valids=stack_field("valids"),
        seq_name=[s.seq_name for s in samples],
        dname=[s.dname for s in samples],
    )
    return batched, got_flags


def try_to_cuda(t: Any) -> Any:
    """
    Best-effort conversion of `t` to a float tensor on the current CUDA device.

    Args:
        t: Any value. Objects exposing ``.float()`` and ``.cuda()`` (e.g.
            torch tensors) are converted; anything else is passed through.

    Returns:
        ``t.float().cuda()`` when both methods exist, otherwise ``t`` itself
        (the very same object, not a copy).

    Note:
        The ``.float()`` call also casts integer/bool tensors to float. Only
        ``AttributeError`` is treated as "not supported"; any other error
        raised by ``.cuda()`` (e.g. when no CUDA device is usable) propagates.
    """
    try:
        converted = t.float().cuda()
    except AttributeError:
        # No .float()/.cuda() on this object (e.g. str, list, None): keep as-is.
        return t
    return converted


def dataclass_to_cuda_(obj):
    """
    Move all fields of a dataclass instance to CUDA in place, where supported.

    Every field value is passed through ``try_to_cuda``. Values exposing
    ``.float()``/``.cuda()`` (e.g. tensors) are therefore cast to float and
    moved to the GPU. Values without those methods (strings, lists, ``None``)
    are left unchanged.

    Args:
        obj: Dataclass instance to update. It is mutated via ``setattr``.

    Returns:
        obj: The same instance, returned for convenience.
    """
    for f in dataclasses.fields(obj):
        setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
    return obj