from itertools import permutations
import torch
from torch import nn
from scipy.optimize import linear_sum_assignment
class PITLossWrapper(nn.Module):
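    """Permutation-invariant training (PIT) loss wrapper.

    Computes the loss under the best permutation of estimated vs. target
    sources. ``pit_from`` selects how the pairwise loss matrix (or the
    per-permutation losses) are obtained; the multi-decoder modes additionally
    handle models that emit one set of estimates per decoder block.
    """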
def __init__(
        self,
        loss_func,
        pit_from="pw_mtx",
        equidistant_weight=False,
        perm_reduce=None,
        threshold_byloss=True,
    ):
super().__init__()
self.loss_func = loss_func
self.pit_from = pit_from
self.perm_reduce = perm_reduce
self.threshold_byloss = threshold_byloss
self.equidistant_weight = equidistant_weight
        if self.pit_from not in [
            "pw_mtx",
            "pw_pt",
            "perm_avg",
            "pw_mtx_broadcast",
            "pw_mtx_multidecoder_keepmtx",
            "pw_mtx_multidecoder_batchmin",
        ]:
            raise ValueError(
                "Unsupported loss function type {}. Expected one of "
                "[`pw_mtx`, `pw_pt`, `perm_avg`, `pw_mtx_broadcast`, "
                "`pw_mtx_multidecoder_keepmtx`, "
                "`pw_mtx_multidecoder_batchmin`].".format(self.pit_from)
            )
def forward(self, ests, targets, return_ests=False, reduce_kwargs=None, **kwargs):
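        """Evaluate the PIT loss for a batch.

        Args:
            ests: estimated sources, [batch, n_src, ...] (a list of such
                tensors in the multi-decoder modes).
            targets: target sources, [batch, n_src, ...].
            return_ests: if True, also return the estimates reordered under
                the best permutation.
            reduce_kwargs: extra kwargs forwarded to ``perm_reduce``.
        """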
n_src = targets.shape[1]
if self.pit_from == "pw_mtx":
pw_loss = self.loss_func(ests, targets, **kwargs)
elif self.pit_from == "pw_mtx_broadcast":
pw_loss = self.loss_func[0](ests, targets, **kwargs)
elif self.pit_from == "pw_mtx_multidecoder_keepmtx":
ests_last_block = ests[-1]
pw_loss = self.loss_func[0](ests_last_block, targets, **kwargs)
elif self.pit_from == "pw_mtx_multidecoder_batchmin":
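            # Stack every block's estimates along the batch axis so the
            # pairwise loss and permutation search run over all blocks at once.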
blocks_num = len(ests)
ests = torch.cat(ests, dim=0)
targets = torch.cat([targets] * blocks_num, dim=0)
pw_loss = self.loss_func(ests, targets, **kwargs)
elif self.pit_from == "pw_pt":
pw_loss = self.get_pw_losses(self.loss_func, ests, targets, **kwargs)
elif self.pit_from == "perm_avg":
min_loss, batch_indices = self.best_perm_from_perm_avg_loss(
self.loss_func, ests, targets, **kwargs
)
mean_loss = torch.mean(min_loss)
if not return_ests:
return mean_loss
reordered = self.reordered_sources(ests, batch_indices)
return mean_loss, reordered
        else:
            raise ValueError("Unsupported pit_from mode {}".format(self.pit_from))
        assert pw_loss.ndim == 3, (
            "Something went wrong with the loss function, please read the docs."
        )
assert (
pw_loss.shape[0] == targets.shape[0]
), "PIT loss needs same batch dim as input"
reduce_kwargs = reduce_kwargs if reduce_kwargs is not None else dict()
min_loss, batch_indices = self.find_best_perm(
pw_loss, perm_reduce=self.perm_reduce, **reduce_kwargs
)
if self.pit_from == "pw_mtx_multidecoder_keepmtx":
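            # Apply the permutation found on the last block's loss matrix to
            # every block, then average the per-block losses.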
reordered = []
mean_loss = 0
for i in range(len(ests)):
reordered_ests_each_block = self.reordered_sources(ests[i], batch_indices)
reordered.append(reordered_ests_each_block)
loss_each_block = self.loss_func[1](reordered_ests_each_block, targets, **kwargs)
if self.threshold_byloss:
if loss_each_block[loss_each_block > -30].nelement() > 0:
loss_each_block = loss_each_block[loss_each_block > -30]
                if self.equidistant_weight:
                    # Later blocks get linearly larger weights: (i + 1) / n_blocks.
                    mean_loss = mean_loss + (i + 1) / len(ests) * loss_each_block.mean()
                else:
                    mean_loss = mean_loss + loss_each_block.mean() / len(ests)
reordered = torch.cat(reordered, dim=0)
if not return_ests:
return mean_loss
return mean_loss, reordered
else:
if self.threshold_byloss:
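                # Drop items whose loss is already below -30 (with a neg-SI-SDR
                # loss this means SI-SDR above 30 dB), so already-solved
                # examples do not dominate the batch mean.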
if min_loss[min_loss > -30].nelement() > 0:
min_loss = min_loss[min_loss > -30]
mean_loss = torch.mean(min_loss)
reordered = self.reordered_sources(ests, batch_indices)
            if self.pit_from == "pw_mtx_broadcast":
                # Add a weighted auxiliary term from the second loss function,
                # computed on the reordered estimates.
                mean_loss += 0.5 * self.loss_func[1](reordered, targets, **kwargs).mean()
if not return_ests:
return mean_loss
return mean_loss, reordered
def get_pw_losses(self, loss_func, ests, targets, **kwargs):
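        """Build the [batch, n_src, n_src] pairwise loss matrix by applying a
        single-pair loss function to every (estimate, target) combination."""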
B, n_src, _ = targets.shape
pair_wise_losses = targets.new_empty(B, n_src, n_src)
for est_idx, est_src in enumerate(ests.transpose(0, 1)):
for target_idx, target_src in enumerate(targets.transpose(0, 1)):
pair_wise_losses[:, est_idx, target_idx] = loss_func(
est_src, target_src, **kwargs
)
return pair_wise_losses
def best_perm_from_perm_avg_loss(self, loss_func, ests, targets, **kwargs):
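        """Find the best permutation when ``loss_func`` already returns one
        loss per batch element: evaluate all n_src! orderings directly."""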
n_src = targets.shape[1]
perms = torch.tensor(list(permutations(range(n_src))), dtype=torch.long)
        # [batch, n_src!]: the loss of every permutation (loss_func is
        # expected to average over sources internally).
        loss_set = torch.stack(
            [loss_func(ests[:, perm], targets, **kwargs) for perm in perms], dim=1
        )
min_loss, min_loss_idx = torch.min(loss_set, dim=1)
batch_indices = torch.stack([perms[m] for m in min_loss_idx], dim=0)
return min_loss, batch_indices
def reordered_sources(self, sources, batch_indices):
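        """Reorder each batch element's sources along dim 0 according to its
        best permutation indices."""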
reordered_sources = torch.stack(
[torch.index_select(s, 0, b) for s, b in zip(sources, batch_indices)]
)
return reordered_sources
def find_best_perm(self, pair_wise_losses, perm_reduce=None, **kwargs):
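        """Dispatch: exhaustive O(n_src!) search for n_src <= 3 or when a
        custom ``perm_reduce`` is given, Hungarian algorithm otherwise."""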
n_src = pair_wise_losses.shape[-1]
if perm_reduce is not None or n_src <= 3:
min_loss, batch_indices = self.find_best_perm_factorial(
pair_wise_losses, perm_reduce=perm_reduce, **kwargs
)
else:
min_loss, batch_indices = self.find_best_perm_hungarian(pair_wise_losses)
return min_loss, batch_indices
def find_best_perm_factorial(self, pair_wise_losses, perm_reduce=None, **kwargs):
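        """Exhaustive search over all n_src! permutations (an einsum against
        one-hot permutation matrices, or a custom ``perm_reduce``)."""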
n_src = pair_wise_losses.shape[-1]
# After transposition, dim 1 corresp. to sources and dim 2 to estimates
pwl = pair_wise_losses.transpose(-1, -2)
perms = pwl.new_tensor(list(permutations(range(n_src))), dtype=torch.long)
# Column permutation indices
idx = torch.unsqueeze(perms, 2)
# Loss mean of each permutation
if perm_reduce is None:
# one-hot, [n_src!, n_src, n_src]
perms_one_hot = pwl.new_zeros((*perms.size(), n_src)).scatter_(2, idx, 1)
loss_set = torch.einsum("bij,pij->bp", [pwl, perms_one_hot])
loss_set /= n_src
else:
# batch = pwl.shape[0]; n_perm = idx.shape[0]
# [batch, n_src!, n_src] : Pairwise losses for each permutation.
pwl_set = pwl[:, torch.arange(n_src), idx.squeeze(-1)]
# Apply reduce [batch, n_src!, n_src] --> [batch, n_src!]
loss_set = perm_reduce(pwl_set, **kwargs)
# Indexes and values of min losses for each batch element
min_loss, min_loss_idx = torch.min(loss_set, dim=1)
# Permutation indices for each batch.
batch_indices = torch.stack([perms[m] for m in min_loss_idx], dim=0)
return min_loss, batch_indices
def find_best_perm_hungarian(self, pair_wise_losses: torch.Tensor):
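        """Solve the assignment problem with scipy's Hungarian solver
        (``linear_sum_assignment``), O(n_src^3) per batch element."""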
pwl = pair_wise_losses.transpose(-1, -2)
# Just bring the numbers to cpu(), not the graph
pwl_copy = pwl.detach().cpu()
        # Loop over the batch; for square cost matrices linear_sum_assignment
        # returns row indices already in order, so only columns are needed.
batch_indices = torch.tensor(
[linear_sum_assignment(pwl)[1] for pwl in pwl_copy]
).to(pwl.device)
min_loss = torch.gather(pwl, 2, batch_indices[..., None]).mean([-1, -2])
return min_loss, batch_indices
if __name__ == "__main__":
    from matrix import pairwise_neg_sisdr

    ests = torch.randn(10, 2, 32000)
    targets = torch.randn(10, 2, 32000)
    pit_wrapper = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    print(pit_wrapper(ests, targets))
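    # A minimal usage sketch: return_ests=True also yields the estimates
    # reordered under the best permutation.
    loss, reordered = pit_wrapper(ests, targets, return_ests=True)
    print(loss, reordered.shape)  # reordered: [batch, n_src, time]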