File size: 1,745 Bytes
ac59957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch

# Ground-truth flow vectors whose magnitude reaches this threshold are treated
# as implausibly large displacements and excluded from the loss.
MAX_FLOW = 512


def sequence_loss(
    outputs: dict[str, list[torch.Tensor]],
    flow_gts: torch.Tensor,
    valids: torch.Tensor,
    gamma: float = 0.8,
    max_flow: float = MAX_FLOW,
) -> torch.Tensor:
    """Calculate the exponentially weighted sequence loss for optical flow.

    For every prediction ``i``, the loss map ``outputs["nf"][i]`` is averaged
    over the elements that are marked valid in ``valids``, whose ground-truth
    flow magnitude is below ``max_flow``, and that are finite. The mean is
    weighted by ``gamma ** (n_predictions - i - 1)`` (so the last prediction
    has weight 1), and the weighted means are summed.

    Parameters
    ----------
    outputs : dict[str, list[torch.Tensor]]
        Dictionary containing model outputs:
        - 'flow': List of predicted flow fields, each of shape (B, 2, 2, H, W).
          Only the length of this list is used here (number of predictions).
        - 'nf': List of normalized flow losses, each of shape (B, 2, 2, H, W).
          Must have at least as many entries as 'flow'.
    flow_gts : torch.Tensor
        Ground truth flow fields of shape (B, 2, 2, H, W). The flow magnitude
        is taken over dim 2.
    valids : torch.Tensor
        Validity masks of shape (B, 2, H, W); entries >= 0.5 count as valid.
    gamma : float, optional
        Weight decay factor for sequence loss, by default 0.8
    max_flow : float, optional
        Ground-truth flow magnitudes at or above this threshold are ignored,
        by default MAX_FLOW

    Returns
    -------
    torch.Tensor
        Scalar loss value. A prediction with no usable element contributes
        zero while remaining connected to the autograd graph. An empty
        prediction list yields a 0-d zero tensor.
    """
    n_predictions = len(outputs["flow"])
    if n_predictions == 0:
        return flow_gts.new_zeros(())

    # Exclude invalid pixels and extremely large displacements. The mask is
    # given a singleton dim 2 so it broadcasts over the loss-map components.
    mag = torch.sum(flow_gts**2, dim=2).sqrt()
    valid = ((valids >= 0.5) & (mag < max_flow))[:, :, None]

    flow_loss = 0.0
    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        loss_i = outputs["nf"][i]
        # Non-finite entries are excluded as well; the mask itself needs no grad.
        final_mask = torch.isfinite(loss_i.detach()) & valid

        # Select with torch.where instead of multiplying by the mask:
        # 0 * nan and 0 * inf are nan, so a multiplicative mask would let the
        # excluded non-finite values poison the sum.
        masked_sum = torch.where(final_mask, loss_i, torch.zeros_like(loss_i)).sum()
        # clamp(min=1) avoids 0/0 when nothing is usable (the numerator is then
        # 0) and avoids a host-device sync from branching on a tensor value.
        n_valid = final_mask.sum().clamp(min=1)
        flow_loss = flow_loss + i_weight * (masked_sum / n_valid)

    return flow_loss