import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import List
import utils.basic

def sequence_loss(
    flow_preds,
    flow_gt,
    valids,
    vis=None,
    gamma=0.8,
    use_huber_loss=False,
    loss_only_for_visible=False,
):
    """Loss function defined over a sequence of flow predictions."""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, N, D = flow_gt[j].shape
        B, S2, N = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]
            if use_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
            else:
                i_loss = (flow_pred - flow_gt[j]).abs()  # B, S, N, 2
            i_loss = torch.mean(i_loss, dim=3)  # B, S, N
            valid_ = valids[j].clone()
            if loss_only_for_visible:
                valid_ = valid_ * vis[j]
            flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)
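
# Illustrative usage sketch (not part of the original module): sequence_loss expects
# nested lists -- an outer list (e.g. over sequence chunks) and, per entry, an inner
# list of iterative refinement predictions of shape B,S,N,2, with validity masks of
# shape B,S,N. The shapes below are hypothetical; utils.basic.reduce_masked_mean is
# assumed to compute sum(x * mask) / clamp(sum(mask), min=1) as in the project utils.
def _example_sequence_loss_usage():
    B, S, N, n_iters = 2, 8, 64, 4
    flow_gt = [torch.randn(B, S, N, 2)]
    valids = [torch.ones(B, S, N)]
    flow_preds = [[torch.randn(B, S, N, 2) for _ in range(n_iters)]]
    return sequence_loss(flow_preds, flow_gt, valids, gamma=0.8)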

def sequence_loss_dense(
    flow_preds,
    flow_gt,
    valids,
    vis=None,
    gamma=0.8,
    use_huber_loss=False,
    loss_only_for_visible=False,
):
    """Loss function defined over a sequence of dense (per-pixel) flow predictions."""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, D, H, W = flow_gt[j].shape
        B, S2, _, H, W = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]  # B,S,2,H,W
            if use_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)  # B,S,2,H,W
            else:
                i_loss = (flow_pred - flow_gt[j]).abs()  # B,S,2,H,W
            i_loss_ = torch.mean(i_loss, dim=2)  # B,S,H,W
            valid_ = valids[j].reshape(B, S, H, W)
            if loss_only_for_visible:
                valid_ = valid_ * vis[j].reshape(B, -1, H, W)  # usually B,S,H,W, but maybe B,1,H,W
            flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)
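
# Illustrative usage sketch (assumed shapes, not from the original file): the dense
# variant takes per-pixel flow fields of shape B,S,2,H,W and validity masks of shape
# B,S,1,H,W; reduce_masked_mean(..., broadcast=True) is assumed to broadcast the mask
# over the loss before taking the masked average.
def _example_sequence_loss_dense_usage():
    B, S, H, W, n_iters = 1, 4, 32, 32, 3
    flow_gt = [torch.randn(B, S, 2, H, W)]
    valids = [torch.ones(B, S, 1, H, W)]
    flow_preds = [[torch.randn(B, S, 2, H, W) for _ in range(n_iters)]]
    return sequence_loss_dense(flow_preds, flow_gt, valids, gamma=0.8)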

def huber_loss(x, y, delta=1.0):
    """Calculate element-wise Huber loss between x and y."""
    diff = x - y
    abs_diff = diff.abs()
    flag = (abs_diff <= delta).float()
    return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
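
# Quick numeric check (illustrative, not part of the original file): with delta=1.0,
# an error of 0.5 falls in the quadratic branch (0.5 * 0.5**2 = 0.125) and an error
# of 2.0 falls in the linear branch (1.0 * (2.0 - 0.5) = 1.5).
def _example_huber_loss_values():
    x = torch.tensor([0.5, 2.0])
    y = torch.zeros(2)
    return huber_loss(x, y, delta=1.0)  # approx tensor([0.1250, 1.5000])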

def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
    """Binary cross-entropy loss over a sequence of visibility predictions."""
    total_bce_loss = 0.0
    for j in range(len(vis_preds)):
        n_predictions = len(vis_preds[j])
        bce_loss = 0.0
        for i in range(n_predictions):
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
            else:
                loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
            if valids is None:
                bce_loss += loss.mean()
            else:
                bce_loss += (loss * valids[j]).mean()
        bce_loss = bce_loss / n_predictions
        total_bce_loss += bce_loss
    return total_bce_loss / len(vis_preds)
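
# Illustrative usage sketch (assumed shapes, not from the original file): visibility
# predictions are nested like the flow predictions above; with use_logits=False the
# predictions must already lie in [0, 1] (e.g. after a sigmoid).
def _example_sequence_BCE_loss_usage():
    B, S, N, n_iters = 2, 8, 64, 4
    vis_gts = [(torch.rand(B, S, N) > 0.5).float()]
    vis_preds = [[torch.sigmoid(torch.randn(B, S, N)) for _ in range(n_iters)]]
    return sequence_BCE_loss(vis_preds, vis_gts)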

# def sequence_BCE_loss_dense(vis_preds, vis_gts):
#     total_bce_loss = 0.0
#     for j in range(len(vis_preds)):
#         n_predictions = len(vis_preds[j])
#         bce_loss = 0.0
#         for i in range(n_predictions):
#             vis_e = vis_preds[j][i]
#             vis_g = vis_gts[j]
#             print('vis_e', vis_e.shape, 'vis_g', vis_g.shape)
#             vis_loss = F.binary_cross_entropy(vis_e, vis_g)
#             bce_loss += vis_loss
#         bce_loss = bce_loss / n_predictions
#         total_bce_loss += bce_loss
#     return total_bce_loss / len(vis_preds)

def sequence_prob_loss(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
    use_logits=False,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than 12 pixels are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
            valid = (err <= expected_dist_thresh**2).float()
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
            else:
                loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
            loss *= visibility[j]
            loss = torch.mean(loss, dim=[1, 2])
            logprob_loss += loss
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)
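
# Illustrative usage sketch (assumed shapes, not from the original file): the
# confidence head is trained to predict whether the (detached) track error is within
# expected_dist_thresh pixels of the target; the returned loss keeps the batch
# dimension because the mean is taken over dims [1, 2] only.
def _example_sequence_prob_loss_usage():
    B, S, N, n_iters = 2, 8, 64, 4
    target_points = [torch.randn(B, S, N, 2) * 10]
    visibility = [torch.ones(B, S, N)]
    tracks = [[torch.randn(B, S, N, 2) * 10 for _ in range(n_iters)]]
    confidence = [[torch.sigmoid(torch.randn(B, S, N)) for _ in range(n_iters)]]
    return sequence_prob_loss(tracks, confidence, target_points, visibility)  # shape: B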

def sequence_prob_loss_dense(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
    use_logits=False,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than 12 pixels are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
            positive = (err <= expected_dist_thresh**2).float()
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
            else:
                loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
            loss *= visibility[j].squeeze(2)  # B,S,H,W
            loss = torch.mean(loss, dim=[1, 2, 3])
            logprob_loss += loss
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)

def masked_mean(data, mask, dim):
    """Mean of data over dim, weighted by mask (plain mean if mask is None)."""
    if mask is None:
        return data.mean(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    return mask_mean

def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    """Masked mean and variance of data over dim (plain mean/var if mask is None)."""
    if mask is None:
        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    mask_var = torch.sum(
        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
    ) / torch.clamp(mask_sum, min=1.0)
    return mask_mean.squeeze(dim), mask_var.squeeze(dim)
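
# Illustrative usage sketch (not part of the original file): reduce over the last
# dimension with a binary mask. Note that the masked branch squeezes the reduced
# dims while the mask=None branch keeps them (keepdim=True), and that passing a
# list of dims to Tensor.squeeze assumes a PyTorch version that accepts it (2.0+).
def _example_masked_mean_var_usage():
    data = torch.tensor([[1.0, 2.0, 100.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0]])  # ignore the outlier
    mean, var = masked_mean_var(data, mask, dim=[1])
    return mean, var  # approx (tensor([1.5000]), tensor([0.2500]))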