import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import List
import utils.basic
def sequence_loss(
flow_preds,
flow_gt,
valids,
vis=None,
gamma=0.8,
use_huber_loss=False,
loss_only_for_visible=False,
):
"""Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0
for j in range(len(flow_gt)):
B, S, N, D = flow_gt[j].shape
B, S2, N = valids[j].shape
assert S == S2
n_predictions = len(flow_preds[j])
flow_loss = 0.0
for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)  # weight later (more refined) iterations more heavily
flow_pred = flow_preds[j][i]
if use_huber_loss:
i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
else:
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
i_loss = torch.mean(i_loss, dim=3) # B, S, N
valid_ = valids[j].clone()
if loss_only_for_visible:
valid_ = valid_ * vis[j]
flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss
return total_flow_loss / len(flow_gt)
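
# Illustrative smoke test for sequence_loss, assuming the layout unpacked
# above: each flow_gt[j] is (B,S,N,2) and each valids[j] is (B,S,N); the inner
# list of flow_preds[j] stands in for iterative refinements. All shapes and
# values here are made up for the sketch.
def _demo_sequence_loss():
    B, S, N = 2, 8, 16
    flow_gt = [torch.randn(B, S, N, 2)]
    valids = [torch.ones(B, S, N)]
    # two refinement iterates, the later one closer to the ground truth
    flow_preds = [[flow_gt[0] + 1.0 * torch.randn(B, S, N, 2),
                   flow_gt[0] + 0.5 * torch.randn(B, S, N, 2)]]
    return sequence_loss(flow_preds, flow_gt, valids, gamma=0.8)
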
def sequence_loss_dense(
flow_preds,
flow_gt,
valids,
vis=None,
gamma=0.8,
use_huber_loss=False,
loss_only_for_visible=False,
):
"""Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0
for j in range(len(flow_gt)):
B, S, D, H, W = flow_gt[j].shape
B, S2, _, H, W = valids[j].shape
assert S == S2
n_predictions = len(flow_preds[j])
flow_loss = 0.0
for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)  # weight later (more refined) iterations more heavily
flow_pred = flow_preds[j][i] # B,S,2,H,W
if use_huber_loss:
i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0) # B,S,2,H,W
else:
i_loss = (flow_pred - flow_gt[j]).abs() # B,S,2,H,W
i_loss_ = torch.mean(i_loss, dim=2) # B,S,H,W
            valid_ = valids[j].reshape(B, S, H, W)
            if loss_only_for_visible:
                valid_ = valid_ * vis[j].reshape(B, -1, H, W)  # usually B,S,H,W, but may be B,1,H,W
flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss
return total_flow_loss / len(flow_gt)
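
# Illustrative smoke test for sequence_loss_dense, assuming the dense layout
# unpacked above: flow maps of shape (B,S,2,H,W) and validity masks of shape
# (B,S,1,H,W). Values are synthetic.
def _demo_sequence_loss_dense():
    B, S, H, W = 1, 4, 8, 8
    flow_gt = [torch.randn(B, S, 2, H, W)]
    valids = [torch.ones(B, S, 1, H, W)]
    flow_preds = [[flow_gt[0] + 0.5 * torch.randn(B, S, 2, H, W)]]
    return sequence_loss_dense(flow_preds, flow_gt, valids)
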
def huber_loss(x, y, delta=1.0):
"""Calculate element-wise Huber loss between x and y"""
diff = x - y
abs_diff = diff.abs()
flag = (abs_diff <= delta).float()
return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
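
# Sanity sketch: the piecewise form above (0.5*d^2 inside delta, linear
# outside) should agree elementwise with PyTorch's built-in F.huber_loss;
# the built-in requires PyTorch >= 1.10, which is an assumption about the
# installed environment.
def _demo_huber_loss():
    x, y = torch.randn(4, 3), torch.randn(4, 3)
    ours = huber_loss(x, y, delta=1.0)
    ref = F.huber_loss(x, y, reduction='none', delta=1.0)
    return torch.allclose(ours, ref)  # expected: True
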
def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
    """Binary cross-entropy over a sequence of visibility predictions"""
    total_bce_loss = 0.0
for j in range(len(vis_preds)):
n_predictions = len(vis_preds[j])
bce_loss = 0.0
for i in range(n_predictions):
if use_logits:
loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
else:
loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
if valids is None:
bce_loss += loss.mean()
else:
bce_loss += (loss * valids[j]).mean()
bce_loss = bce_loss / n_predictions
total_bce_loss += bce_loss
return total_bce_loss / len(vis_preds)
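
# Illustrative smoke test for sequence_BCE_loss. Predictions are kept strictly
# inside (0,1) via sigmoid so the non-logits BCE path is well-defined; the
# (B,S,N) layout is an assumption mirroring sequence_loss above.
def _demo_sequence_BCE_loss():
    B, S, N = 2, 8, 16
    vis_gts = [(torch.rand(B, S, N) > 0.5).float()]
    vis_preds = [[torch.sigmoid(torch.randn(B, S, N)),
                  torch.sigmoid(torch.randn(B, S, N))]]
    return sequence_BCE_loss(vis_preds, vis_gts)
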
def sequence_prob_loss(
    tracks: List[List[torch.Tensor]],
    confidence: List[List[torch.Tensor]],
    target_points: List[torch.Tensor],
    visibility: List[torch.Tensor],
    expected_dist_thresh: float = 12.0,
    use_logits=False,
):
"""Loss for classifying if a point is within pixel threshold of its target."""
# Points with an error larger than 12 pixels are likely to be useless; marking
# them as occluded will actually improve Jaccard metrics and give
# qualitatively better results.
total_logprob_loss = 0.0
for j in range(len(tracks)):
n_predictions = len(tracks[j])
logprob_loss = 0.0
for i in range(n_predictions):
err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
valid = (err <= expected_dist_thresh**2).float()
if use_logits:
loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
else:
loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
loss *= visibility[j]
            loss = torch.mean(loss, dim=[1, 2])  # per-example loss, shape (B,)
logprob_loss += loss
logprob_loss = logprob_loss / n_predictions
total_logprob_loss += logprob_loss
return total_logprob_loss / len(tracks)
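
# Illustrative smoke test for sequence_prob_loss. Targets and tracks are
# (B,S,N,2) with confidence (B,S,N), mirroring sequence_loss above; note that
# the mean over dims [1,2] leaves a per-example loss of shape (B,).
def _demo_sequence_prob_loss():
    B, S, N = 2, 8, 16
    target_points = [torch.randn(B, S, N, 2) * 32.0]
    tracks = [[target_points[0] + torch.randn(B, S, N, 2) * 8.0]]
    confidence = [[torch.sigmoid(torch.randn(B, S, N))]]
    visibility = [torch.ones(B, S, N)]
    return sequence_prob_loss(tracks, confidence, target_points, visibility)
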
def sequence_prob_loss_dense(
    tracks: List[List[torch.Tensor]],
    confidence: List[List[torch.Tensor]],
    target_points: List[torch.Tensor],
    visibility: List[torch.Tensor],
    expected_dist_thresh: float = 12.0,
    use_logits=False,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than 12 pixels are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
total_logprob_loss = 0.0
for j in range(len(tracks)):
n_predictions = len(tracks[j])
logprob_loss = 0.0
for i in range(n_predictions):
err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
positive = (err <= expected_dist_thresh**2).float()
if use_logits:
loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
else:
loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
            loss *= visibility[j].squeeze(2)  # B,S,H,W
            loss = torch.mean(loss, dim=[1, 2, 3])  # per-example loss, shape (B,)
logprob_loss += loss
logprob_loss = logprob_loss / n_predictions
total_logprob_loss += logprob_loss
return total_logprob_loss / len(tracks)
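
# Illustrative smoke test for sequence_prob_loss_dense, assuming the dense
# layout used above: tracks (B,S,2,H,W), confidence and visibility (B,S,1,H,W)
# so that .squeeze(2) yields (B,S,H,W). Returns a per-example loss of shape (B,).
def _demo_sequence_prob_loss_dense():
    B, S, H, W = 1, 4, 8, 8
    target_points = [torch.randn(B, S, 2, H, W) * 32.0]
    tracks = [[target_points[0] + torch.randn(B, S, 2, H, W) * 8.0]]
    confidence = [[torch.sigmoid(torch.randn(B, S, 1, H, W))]]
    visibility = [torch.ones(B, S, 1, H, W)]
    return sequence_prob_loss_dense(tracks, confidence, target_points, visibility)
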
def masked_mean(data, mask, dim):
    """Mean of `data` over `dim`, counting only positions where `mask` is on"""
    if mask is None:
return data.mean(dim=dim, keepdim=True)
mask = mask.float()
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
mask_sum, min=1.0
)
return mask_mean
def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    """Masked mean and variance of `data` over dims `dim`; `mask` may be None"""
    if mask is None:
        # match the masked branch below, which squeezes the reduced dims back out
        return data.mean(dim=dim), data.var(dim=dim)
mask = mask.float()
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
mask_sum, min=1.0
)
mask_var = torch.sum(
mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
) / torch.clamp(mask_sum, min=1.0)
return mask_mean.squeeze(dim), mask_var.squeeze(dim)
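
# Illustrative check for masked_mean_var: the keepdim'd reductions are
# squeezed back out, so reducing a (2,5) tensor over dim=[1] yields shape
# (2,). List-valued squeeze dims assume PyTorch >= 2.0.
def _demo_masked_mean_var():
    data = torch.randn(2, 5)
    mask = (torch.rand(2, 5) > 0.5).float()
    mean, var = masked_mean_var(data, mask, dim=[1])
    return mean.shape, var.shape  # both torch.Size([2])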