from itertools import permutations
import torch
from torch import nn
from scipy.optimize import linear_sum_assignment


class PITLossWrapper(nn.Module):
    """Permutation-invariant training (PIT) loss wrapper.

    ``pit_from`` selects how ``loss_func`` is used: ``pw_mtx`` expects the
    pairwise loss matrix directly, ``pw_pt`` builds it from a single-source
    loss, and ``perm_avg`` evaluates each permutation directly.
    ``pw_mtx_broadcast`` and ``pw_mtx_multidecoder_keepmtx`` expect a pair of
    loss functions (pairwise, then on reordered estimates), while
    ``pw_mtx_multidecoder_batchmin`` stacks all decoder blocks along the batch
    dimension before the pairwise loss.
    """

    def __init__(
        self, loss_func, pit_from="pw_mtx", equidistant_weight=False, perm_reduce=None, threshold_byloss=True
    ):
        super().__init__()
        self.loss_func = loss_func
        self.pit_from = pit_from
        self.perm_reduce = perm_reduce
        self.threshold_byloss = threshold_byloss
        self.equidistant_weight = equidistant_weight
        supported = [
            "pw_mtx", "pw_pt", "perm_avg", "pw_mtx_broadcast",
            "pw_mtx_multidecoder_keepmtx", "pw_mtx_multidecoder_batchmin",
        ]
        if self.pit_from not in supported:
            raise ValueError(
                "Unsupported `pit_from` value {}. Expected one of {}".format(self.pit_from, supported)
            )

    def forward(self, ests, targets, return_ests=False, reduce_kwargs=None, **kwargs):
        if self.pit_from == "pw_mtx":
            # Loss function already returns the [batch, n_src, n_src] pairwise matrix.
            pw_loss = self.loss_func(ests, targets, **kwargs)
        elif self.pit_from == "pw_mtx_broadcast":
            pw_loss = self.loss_func[0](ests, targets, **kwargs)
        elif self.pit_from == "pw_mtx_multidecoder_keepmtx":
            # Find the best permutation from the last decoder block only.
            ests_last_block = ests[-1]
            pw_loss = self.loss_func[0](ests_last_block, targets, **kwargs)
        elif self.pit_from == "pw_mtx_multidecoder_batchmin":
            # Stack decoder blocks along the batch dim so each block gets its own permutation.
            blocks_num = len(ests)
            ests = torch.cat(ests, dim=0)
            targets = torch.cat([targets] * blocks_num, dim=0)
            pw_loss = self.loss_func(ests, targets, **kwargs)
        elif self.pit_from == "pw_pt":
            # Build the pairwise matrix from a single-source loss function.
            pw_loss = self.get_pw_losses(self.loss_func, ests, targets, **kwargs)
        elif self.pit_from == "perm_avg":
            # Directly evaluate each permutation; no pairwise matrix needed.
            min_loss, batch_indices = self.best_perm_from_perm_avg_loss(
                self.loss_func, ests, targets, **kwargs
            )
            mean_loss = torch.mean(min_loss)
            if not return_ests:
                return mean_loss
            reordered = self.reordered_sources(ests, batch_indices)
            return mean_loss, reordered
        else:
            # Unreachable: `pit_from` is validated in __init__.
            raise ValueError("Unsupported `pit_from` value {}".format(self.pit_from))
        assert pw_loss.ndim == 3, (
            "Expected a [batch, n_src, n_src] pairwise loss matrix; something "
            "went wrong with the loss function, please read the docs."
        )
        assert (
            pw_loss.shape[0] == targets.shape[0]
        ), "PIT loss needs same batch dim as input"

        reduce_kwargs = reduce_kwargs if reduce_kwargs is not None else dict()
        min_loss, batch_indices = self.find_best_perm(
            pw_loss, perm_reduce=self.perm_reduce, **reduce_kwargs
        )
        if self.pit_from == "pw_mtx_multidecoder_keepmtx":
            # Reorder every decoder block with the permutation found on the last
            # block, then average the (optionally depth-weighted) per-block losses.
            reordered = []
            mean_loss = 0
            for i in range(len(ests)):
                reordered_ests_each_block = self.reordered_sources(ests[i], batch_indices)
                reordered.append(reordered_ests_each_block)
                loss_each_block = self.loss_func[1](reordered_ests_each_block, targets, **kwargs)
                if self.threshold_byloss:
                    # Average only over losses above -30 so already well-separated
                    # examples do not dominate the gradient.
                    if loss_each_block[loss_each_block > -30].nelement() > 0:
                        loss_each_block = loss_each_block[loss_each_block > -30]
                if self.equidistant_weight:
                    # Deeper blocks get linearly larger weights.
                    mean_loss = mean_loss + (i + 1) / len(ests) * loss_each_block.mean()
                else:
                    mean_loss = mean_loss + loss_each_block.mean() / len(ests)
            reordered = torch.cat(reordered, dim=0)
            if not return_ests:
                return mean_loss
            return mean_loss, reordered
        else:
            if self.threshold_byloss:
                # Average only over losses above -30 so already well-separated
                # examples do not dominate the gradient.
                if min_loss[min_loss > -30].nelement() > 0:
                    min_loss = min_loss[min_loss > -30]
            mean_loss = torch.mean(min_loss)
            reordered = self.reordered_sources(ests, batch_indices)
            if self.pit_from == "pw_mtx_broadcast":
                # Add the secondary loss, computed on the reordered estimates.
                mean_loss += 0.5 * self.loss_func[1](reordered, targets, **kwargs).mean()
            if not return_ests:
                return mean_loss
            return mean_loss, reordered

    def get_pw_losses(self, loss_func, ests, targets, **kwargs):
        """Build the [batch, n_src, n_src] pairwise loss matrix by applying a
        single-source loss to every (estimate, target) pair; see the example
        sketch below the class."""
        B, n_src, _ = targets.shape
        pair_wise_losses = targets.new_empty(B, n_src, n_src)
        for est_idx, est_src in enumerate(ests.transpose(0, 1)):
            for target_idx, target_src in enumerate(targets.transpose(0, 1)):
                pair_wise_losses[:, est_idx, target_idx] = loss_func(
                    est_src, target_src, **kwargs
                )
        return pair_wise_losses

    def best_perm_from_perm_avg_loss(self, loss_func, ests, targets, **kwargs):
        """Find the best permutation from a loss that is already averaged over
        sources for a given permutation; see the example sketch below the class."""
        n_src = targets.shape[1]
        perms = torch.tensor(list(permutations(range(n_src))), dtype=torch.long)
        # [batch, n_src!]: loss of each permutation.
        loss_set = torch.stack(
            [loss_func(ests[:, perm], targets, **kwargs) for perm in perms], dim=1
        )
        min_loss, min_loss_idx = torch.min(loss_set, dim=1)
        batch_indices = torch.stack([perms[m] for m in min_loss_idx], dim=0)
        return min_loss, batch_indices

    def reordered_sources(self, sources, batch_indices):
        """Reorder `sources` along the source dim with per-example `batch_indices`."""
        reordered_sources = torch.stack(
            [torch.index_select(s, 0, b) for s, b in zip(sources, batch_indices)]
        )
        return reordered_sources

    def find_best_perm(self, pair_wise_losses, perm_reduce=None, **kwargs):
        """Dispatch between exhaustive O(n_src!) search and the Hungarian algorithm."""
        n_src = pair_wise_losses.shape[-1]
        if perm_reduce is not None or n_src <= 3:
            # Exhaustive search; required whenever a custom `perm_reduce` is given.
            min_loss, batch_indices = self.find_best_perm_factorial(
                pair_wise_losses, perm_reduce=perm_reduce, **kwargs
            )
        else:
            # Polynomial-time assignment for larger source counts.
            min_loss, batch_indices = self.find_best_perm_hungarian(pair_wise_losses)
        return min_loss, batch_indices

    def find_best_perm_factorial(self, pair_wise_losses, perm_reduce=None, **kwargs):
        n_src = pair_wise_losses.shape[-1]
        # After transposition, dim 1 corresponds to sources and dim 2 to estimates.
        pwl = pair_wise_losses.transpose(-1, -2)
        perms = pwl.new_tensor(list(permutations(range(n_src))), dtype=torch.long)
        # Column permutation indices.
        idx = torch.unsqueeze(perms, 2)
        # Loss mean of each permutation.
        if perm_reduce is None:
            # One-hot permutation matrices, [n_src!, n_src, n_src].
            perms_one_hot = pwl.new_zeros((*perms.size(), n_src)).scatter_(2, idx, 1)
            loss_set = torch.einsum("bij,pij->bp", [pwl, perms_one_hot])
            loss_set /= n_src
        else:
            # [batch, n_src!, n_src]: pairwise losses for each permutation.
            pwl_set = pwl[:, torch.arange(n_src), idx.squeeze(-1)]
            # Apply reduce: [batch, n_src!, n_src] --> [batch, n_src!]
            # (see the `perm_reduce` example sketch below the class).
            loss_set = perm_reduce(pwl_set, **kwargs)
        # Indices and values of the min loss for each batch element.
        min_loss, min_loss_idx = torch.min(loss_set, dim=1)
        # Permutation indices for each batch element.
        batch_indices = torch.stack([perms[m] for m in min_loss_idx], dim=0)
        return min_loss, batch_indices

    def find_best_perm_hungarian(self, pair_wise_losses: torch.Tensor):
        pwl = pair_wise_losses.transpose(-1, -2)
        # Just bring the numbers to cpu(), not the graph.
        pwl_copy = pwl.detach().cpu()
        # Loop over the batch; row indices are always ordered for square matrices.
        batch_indices = torch.tensor(
            [linear_sum_assignment(pwl_mtx)[1] for pwl_mtx in pwl_copy]
        ).to(pwl.device)
        min_loss = torch.gather(pwl, 2, batch_indices[..., None]).mean([-1, -2])
        return min_loss, batch_indices
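

# The helpers below are illustrative sketches, not part of the wrapper: they
# show the signature each `pit_from` mode expects. The names are hypothetical;
# any callable with the same input/output shapes would work.


def example_pairwise_mse(est_src, target_src):
    """Single-source loss for `pit_from="pw_pt"` (sketch): one [batch, time]
    estimate and target in, a [batch] loss out."""
    return ((est_src - target_src) ** 2).mean(-1)


def example_permwise_mse(ests_permuted, targets):
    """Permutation-average loss for `pit_from="perm_avg"` (sketch): already
    permuted [batch, n_src, time] estimates in, a [batch] loss out."""
    return ((ests_permuted - targets) ** 2).mean([-1, -2])


def example_perm_reduce(pwl_set):
    """`perm_reduce` callable (sketch): reduces per-permutation pairwise
    losses from [batch, n_src!, n_src] to [batch, n_src!]."""
    return pwl_set.mean(-1)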


if __name__ == "__main__":
    from matrix import pairwise_neg_sisdr

    ests = torch.randn(10, 2, 32000)
    targets = torch.randn(10, 2, 32000)

    pit_wrapper_1 = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    # Same metric, but average over the full batch instead of thresholding.
    pit_wrapper_2 = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx", threshold_byloss=False)
    print(pit_wrapper_1(ests, targets))
    print(pit_wrapper_2(ests, targets))
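
    # A minimal sanity check, assuming the example sketches defined above:
    # `pw_pt` and `perm_avg` should agree when wrapping the same metric.
    pit_pw_pt = PITLossWrapper(example_pairwise_mse, pit_from="pw_pt")
    pit_perm_avg = PITLossWrapper(example_permwise_mse, pit_from="perm_avg")
    print(pit_pw_pt(ests, targets), pit_perm_avg(ests, targets))

    # With more than 3 sources the wrapper switches to the Hungarian solver;
    # its minimum should match the exhaustive search forced via `perm_reduce`.
    ests_5 = torch.randn(4, 5, 8000)
    targets_5 = torch.randn(4, 5, 8000)
    hungarian = PITLossWrapper(example_pairwise_mse, pit_from="pw_pt")
    exhaustive = PITLossWrapper(
        example_pairwise_mse, pit_from="pw_pt", perm_reduce=example_perm_reduce
    )
    print(hungarian(ests_5, targets_5), exhaustive(ests_5, targets_5))

    # Recover the reordered estimates alongside the loss.
    loss, reordered = pit_pw_pt(ests, targets, return_ests=True)
    print(loss, reordered.shape)  # reordered: [batch, n_src, time]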