# Initial release (commit ac59957, author: egorchistov)
import torch
# Default upper bound on ground-truth flow magnitude: `sequence_loss` ignores
# positions whose ground-truth displacement magnitude is at or above this
# value (extremely large displacements).
MAX_FLOW = 512


def sequence_loss(
    outputs: dict[str, list[torch.Tensor]],
    flow_gts: torch.Tensor,
    valids: torch.Tensor,
    gamma: float = 0.8,
    max_flow: float = MAX_FLOW,
) -> torch.Tensor:
    """Calculate sequence loss for optical flow estimation.

    Each entry of ``outputs['nf']`` is averaged over the positions that are
    valid, have a ground-truth magnitude below ``max_flow`` and hold a finite
    loss value. The per-prediction means are then summed with exponentially
    increasing weights, so the last prediction has weight 1 and the
    prediction ``k`` steps before it has weight ``gamma ** k``.

    Parameters
    ----------
    outputs : Dict[str, List[torch.Tensor]]
        Dictionary containing model outputs:
        - 'flow': List of predicted flow fields, each of shape (B, 2, 2, H, W).
          Only its length is used here (number of predictions).
        - 'nf': List of normalized flow losses, each of shape (B, 2, 2, H, W).
          Must hold at least as many entries as 'flow'.
    flow_gts : torch.Tensor
        Ground truth flow fields of shape (B, 2, 2, H, W)
    valids : torch.Tensor
        Validity masks of shape (B, 2, H, W); values >= 0.5 count as valid
    gamma : float, optional
        Weight decay factor for sequence loss, by default 0.8
    max_flow : float, optional
        Maximum flow magnitude threshold, by default MAX_FLOW

    Returns
    -------
    torch.Tensor
        Scalar loss value. A prediction with no usable position contributes
        zero (while staying attached to the autograd graph). If
        ``outputs['flow']`` is empty the plain float ``0.0`` is returned.
    """
    n_predictions = len(outputs["flow"])
    flow_loss = 0.0

    # Exclude invalid pixels and extremely large displacements.
    mag = torch.sum(flow_gts**2, dim=2).sqrt()
    valid = (valids >= 0.5) & (mag < max_flow)

    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        loss_i = outputs["nf"][i]

        # Usable positions: finite loss value AND valid ground truth.
        final_mask = torch.isfinite(loss_i.detach()) & valid[:, :, None]
        n_used = final_mask.sum()

        # Select instead of multiplying by the mask: ``0 * nan`` and
        # ``0 * inf`` are both nan, so a multiplicative mask would let a
        # single non-finite entry poison the whole sum.
        masked_loss = torch.where(final_mask, loss_i, torch.zeros_like(loss_i))

        # Clamping the count to >= 1 makes an empty mask yield exactly 0
        # (the numerator is 0 as well) without a data-dependent branch.
        flow_loss += i_weight * (masked_loss.sum() / n_used.clamp(min=1))

    return flow_loss