lorocksUMD committed · Commit bf82294 · verified · 1 Parent(s): c2d138b

Update DenseAV/denseav/aggregators.py

Files changed (1):
  1. DenseAV/denseav/aggregators.py (+518 −517)
DenseAV/denseav/aggregators.py CHANGED
@@ -1,517 +1,518 @@
from abc import abstractmethod

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from DenseAV.denseav.constants import *


@torch.jit.script
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int):
    mask = mask.to(x)
    return (x * mask).sum(dim, keepdim=True) / mask.sum(dim, keepdim=True).clamp_min(.001)


@torch.jit.script
def masked_max(x: torch.Tensor, mask: torch.Tensor, dim: int):
    mask = mask.to(torch.bool)
    eps = 1e7
    # eps = torch.finfo(x.dtype).max
    return (x - (~mask) * eps).max(dim, keepdim=True).values


def masked_lse(x: torch.Tensor, mask: torch.Tensor, dim: int, temp):
    x = x.to(torch.float32)
    mask = mask.to(torch.float32)
    x_masked = (x - (1 - mask) * torch.finfo(x.dtype).max)
    return (torch.logsumexp(x_masked * temp, dim, keepdim=True) - torch.log(mask.sum(dim, keepdim=True))) / temp


class BaseAggregator(torch.nn.Module):

    def __init__(self, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
        super().__init__()

        self.nonneg_sim = nonneg_sim
        self.mask_silence = mask_silence
        self.num_heads = num_heads
        self.head_agg = head_agg
        self.use_cls = use_cls

    @abstractmethod
    def _agg_sim(self, sim, mask):
        pass

    def prepare_sims(self, sim, mask, agg_sim, agg_heads):
        sim_size = sim.shape
        assert len(mask.shape) == 2
        assert len(sim_size) in {6, 7}, f"sim has wrong number of dimensions: {sim.shape}"
        pairwise = len(sim_size) == 6

        if not self.mask_silence:
            mask = torch.ones_like(mask)

        if self.nonneg_sim:
            sim = sim.clamp_min(0)

        if pairwise:
            head_dim = 1
        else:
            head_dim = 2

        if self.head_agg == "max_elementwise" and agg_heads:
            sim = sim.max(head_dim, keepdim=True).values

        if agg_sim:
            sim = self._agg_sim(sim, mask)

        if agg_heads:
            if self.head_agg == "sum" or self.head_agg == "max_elementwise":
                sim = sim.sum(head_dim)
            elif self.head_agg == "max":
                sim = sim.max(head_dim).values
            else:
                raise ValueError(f"Unknown head_agg: {self.head_agg}")

        return sim

    def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim or agg_heads or raw:
            assert (agg_sim or agg_heads) != raw, "Cannot have raw on at the same time as agg_sim or agg_heads"

        audio_feats = preds[AUDIO_FEATS]
        audio_mask = preds[AUDIO_MASK]
        image_feats = preds[IMAGE_FEATS]

        b1, c2, f, t1 = audio_feats.shape
        b2, t2 = audio_mask.shape
        d, c1, h, w = image_feats.shape
        assert b1 == b2 and c1 == c2 and t1 == t2
        assert c1 % self.num_heads == 0
        new_c = c1 // self.num_heads
        audio_feats = audio_feats.reshape(b1, self.num_heads, new_c, f, t1)
        image_feats = image_feats.reshape(d, self.num_heads, new_c, h, w)
        raw_sims = torch.einsum(
            "akcft,vkchw->avkhwft",
            audio_feats.to(torch.float32),
            image_feats.to(torch.float32))

        if self.use_cls:
            audio_cls = preds[AUDIO_CLS].reshape(b1, self.num_heads, new_c)
            image_cls = preds[IMAGE_CLS].reshape(d, self.num_heads, new_c)
            cls_sims = torch.einsum(
                "akc,vkc->avk",
                audio_cls.to(torch.float32),
                image_cls.to(torch.float32))
            raw_sims += cls_sims.reshape(b1, d, self.num_heads, 1, 1, 1, 1)

        if raw:
            return raw_sims
        else:
            return self.prepare_sims(raw_sims, audio_mask, agg_sim, agg_heads)

    def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim or agg_heads or raw:
            assert (agg_sim or agg_heads) != raw, "Cannot have raw on at the same time as agg_sim or agg_heads"

        audio_feats = preds[AUDIO_FEATS]
        audio_mask = preds[AUDIO_MASK]
        image_feats = preds[IMAGE_FEATS]

        a1, c1, f, t1 = audio_feats.shape
        a2, t2 = audio_mask.shape

        assert c1 % self.num_heads == 0
        new_c = c1 // self.num_heads
        audio_feats = audio_feats.reshape(a1, self.num_heads, new_c, f, t1)

        if len(image_feats.shape) == 5:
            print("Using similarity for video, should only be called during plotting")
            v, vt, c2, h, w = image_feats.shape
            image_feats = image_feats.reshape(v, vt, self.num_heads, new_c, h, w)
            raw_sims = torch.einsum(
                "bkcft,bskchw,bt->bskhwft",
                audio_feats.to(torch.float32),
                image_feats.to(torch.float32),
                audio_mask.to(torch.float32))

            if self.use_cls:
+               print(preds[AUDIO_CLS].shape)
                audio_cls = preds[AUDIO_CLS].reshape(v, self.num_heads, new_c)
                image_cls = preds[IMAGE_CLS].reshape(v, vt, self.num_heads, new_c)
                cls_sims = torch.einsum(
                    "bkc,bskc->bsk",
                    audio_cls.to(torch.float32),
                    image_cls.to(torch.float32))
                raw_sims += cls_sims.reshape(v, vt, self.num_heads, 1, 1, 1, 1)

        elif len(image_feats.shape) == 4:
            v, c2, h, w = image_feats.shape
            image_feats = image_feats.reshape(v, self.num_heads, new_c, h, w)
            raw_sims = torch.einsum(
                "bkcft,bkchw,bt->bkhwft",
                audio_feats.to(torch.float32),
                image_feats.to(torch.float32),
                audio_mask.to(torch.float32))

            if self.use_cls:
                audio_cls = preds[AUDIO_CLS].reshape(v, self.num_heads, new_c)
                image_cls = preds[IMAGE_CLS].reshape(v, self.num_heads, new_c)
                cls_sims = torch.einsum(
                    "bkc,bkc->bk",
                    audio_cls.to(torch.float32),
                    image_cls.to(torch.float32))
                raw_sims += cls_sims.reshape(v, self.num_heads, 1, 1, 1, 1)
        else:
            raise ValueError(f"Improper image shape: {image_feats.shape}")

        assert a1 == a2 and c1 == c2 and t1 == t2

        if raw:
            return raw_sims
        else:
            return self.prepare_sims(raw_sims, audio_mask, agg_sim, agg_heads)

    def forward(self, preds, agg_heads):
        return self._get_full_sims(
            preds, raw=False, agg_sim=True, agg_heads=agg_heads)

    def forward_batched(self, preds, agg_heads, batch_size):
        new_preds = {k: v for k, v in preds.items()}
        big_image_feats = new_preds.pop(IMAGE_FEATS)
        if self.use_cls:
            big_image_cls = new_preds.pop(IMAGE_CLS)

        n = big_image_feats.shape[0]
        n_steps = math.ceil(n / batch_size)
        outputs = []
        for step in tqdm(range(n_steps), "Calculating Sim", leave=False):
            new_preds[IMAGE_FEATS] = big_image_feats[step * batch_size:(step + 1) * batch_size].cuda()
            if self.use_cls:
                new_preds[IMAGE_CLS] = big_image_cls[step * batch_size:(step + 1) * batch_size].cuda()

            sim = self.forward(new_preds, agg_heads=agg_heads)
            outputs.append(sim.cpu())
        return torch.cat(outputs, dim=1)


class ImageThenAudioAggregator(BaseAggregator):

    def __init__(self, image_agg_type, audio_agg_type, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
        super().__init__(nonneg_sim, mask_silence, num_heads, head_agg, use_cls)
        if image_agg_type == "max":
            self.image_agg = lambda x, dim: x.max(dim=dim, keepdim=True).values
        elif image_agg_type == "avg":
            self.image_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)
        else:
            raise ValueError(f"Unknown image_agg_type {image_agg_type}")

        if audio_agg_type == "max":
            self.time_agg = masked_max
        elif audio_agg_type == "avg":
            self.time_agg = masked_mean
        else:
            raise ValueError(f"Unknown audio_agg_type {audio_agg_type}")

        self.freq_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)

    def _agg_sim(self, sim, mask):
        sim_shape = sim.shape
        new_mask_shape = [1] * len(sim_shape)
        new_mask_shape[0] = sim_shape[0]
        new_mask_shape[-1] = sim_shape[-1]
        mask = mask.reshape(new_mask_shape)
        sim = self.image_agg(sim, -3)
        sim = self.image_agg(sim, -4)
        sim = self.freq_agg(sim, -2)
        sim = self.time_agg(sim, mask, -1)
        return sim.squeeze(-1).squeeze(-1).squeeze(-1).squeeze(-1)


class PairedAggregator(BaseAggregator):

    def __init__(self, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
        super().__init__(nonneg_sim, mask_silence, num_heads, head_agg, use_cls)
        self.image_agg_max = lambda x, dim: x.max(dim=dim, keepdim=True).values
        self.image_agg_mean = lambda x, dim: x.mean(dim=dim, keepdim=True)

        self.time_agg_max = masked_max
        self.time_agg_mean = masked_mean

        self.freq_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)

    def _agg_sim(self, sim, mask):
        sim_shape = sim.shape
        new_mask_shape = [1] * len(sim_shape)
        new_mask_shape[0] = sim_shape[0]
        new_mask_shape[-1] = sim_shape[-1]
        mask = mask.reshape(new_mask_shape)

        sim_1 = self.image_agg_max(sim, -3)
        sim_1 = self.image_agg_max(sim_1, -4)
        sim_1 = self.freq_agg(sim_1, -2)
        sim_1 = self.time_agg_mean(sim_1, mask, -1)

        sim_2 = self.freq_agg(sim, -2)
        sim_2 = self.time_agg_max(sim_2, mask, -1)
        sim_2 = self.image_agg_mean(sim_2, -3)
        sim_2 = self.image_agg_mean(sim_2, -4)

        sim = 1 / 2 * (sim_1 + sim_2)

        return sim.squeeze(-1).squeeze(-1).squeeze(-1).squeeze(-1)


class CAVMAEAggregator(BaseAggregator):

    def __init__(self, *args, **kwargs):
        super().__init__(False, False, 1, "sum", False)

    def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            audio_feats = preds[AUDIO_FEATS]
            image_feats = preds[IMAGE_FEATS]
            pool_audio_feats = F.normalize(audio_feats.mean(dim=[-1, -2]), dim=1)
            pool_image_feats = F.normalize(image_feats.mean(dim=[-1, -2]), dim=1)
            sims = torch.einsum(
                "bc,dc->bd",
                pool_audio_feats.to(torch.float32),
                pool_image_feats.to(torch.float32))
            if agg_heads:
                return sims
            else:
                return sims.unsqueeze(-1)
        else:
            return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)

    def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            audio_feats = preds[AUDIO_FEATS]
            image_feats = preds[IMAGE_FEATS]
            pool_audio_feats = F.normalize(audio_feats.mean(dim=[-1, -2]), dim=1)
            pool_image_feats = F.normalize(image_feats.mean(dim=[-1, -2]), dim=1)
            sims = torch.einsum(
                "bc,bc->b",
                pool_audio_feats.to(torch.float32),
                pool_image_feats.to(torch.float32))
            if agg_heads:
                return sims
            else:
                return sims.unsqueeze(-1)
        else:
            return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)


class ImageBindAggregator(BaseAggregator):

    def __init__(self, num_heads, *args, **kwargs):
        super().__init__(False, False, num_heads, "sum", False)

    def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            sims = torch.einsum(
                "bc,dc->bd",
                preds[AUDIO_CLS].to(torch.float32),
                preds[IMAGE_CLS].to(torch.float32))
            if agg_heads:
                return sims
            else:
                sims = sims.unsqueeze(-1)
                return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
        else:
            return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)

    def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            sims = torch.einsum(
                "bc,bc->b",
                preds[AUDIO_CLS].to(torch.float32),
                preds[IMAGE_CLS].to(torch.float32))
            if agg_heads:
                return sims
            else:
                sims = sims.unsqueeze(-1)
                return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
        else:
            return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)

    def forward_batched(self, preds, agg_heads, batch_size):
        return self.forward(preds, agg_heads)


class SimPool(nn.Module):
    def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=None, use_beta=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.norm_patches = nn.LayerNorm(dim, eps=1e-6)

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)

        if gamma is not None:
            self.gamma = torch.tensor([gamma])
            if use_beta:
                self.beta = nn.Parameter(torch.tensor([0.0]))
        self.eps = torch.tensor([1e-6])

        self.gamma = gamma
        self.use_beta = use_beta

    def prepare_input(self, x):
        if len(x.shape) == 3:  # Transformer
            # Input tensor dimensions:
            # x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
            B, N, d = x.shape
            gap_cls = x.mean(-2)  # (B, N, d) -> (B, d)
            gap_cls = gap_cls.unsqueeze(1)  # (B, d) -> (B, 1, d)
            return gap_cls, x
        elif len(x.shape) == 4:  # CNN
            # Input tensor dimensions:
            # x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
            B, d, H, W = x.shape
            gap_cls = x.mean([-2, -1])  # (B, d, H, W) -> (B, d)
            x = x.reshape(B, d, H * W).permute(0, 2, 1)  # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
            gap_cls = gap_cls.unsqueeze(1)  # (B, d) -> (B, 1, d)
            return gap_cls, x
        else:
            raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")

    def forward(self, x):
        self.eps = self.eps.to(x.device)
        # Prepare input tensor and perform GAP as initialization
        gap_cls, x = self.prepare_input(x)

        # Prepare queries (q), keys (k), and values (v)
        q, k, v = gap_cls, self.norm_patches(x), self.norm_patches(x)

        # Extract dimensions after normalization
        Bq, Nq, dq = q.shape
        Bk, Nk, dk = k.shape
        Bv, Nv, dv = v.shape

        # Check dimension consistency across batches and channels
        assert Bq == Bk == Bv
        assert dq == dk == dv

        # Apply linear transformation for queries and keys then reshape
        qq = self.wq(q).reshape(Bq, Nq, self.num_heads, dq // self.num_heads).permute(0, 2, 1, 3)  # (Bq, Nq, dq) -> (B, num_heads, Nq, dq/num_heads)
        kk = self.wk(k).reshape(Bk, Nk, self.num_heads, dk // self.num_heads).permute(0, 2, 1, 3)  # (Bk, Nk, dk) -> (B, num_heads, Nk, dk/num_heads)

        vv = v.reshape(Bv, Nv, self.num_heads, dv // self.num_heads).permute(0, 2, 1, 3)  # (Bv, Nv, dv) -> (B, num_heads, Nv, dv/num_heads)

        # Compute attention scores
        attn = (qq @ kk.transpose(-2, -1)) * self.scale
        # Apply softmax for normalization
        attn = attn.softmax(dim=-1)

        # If gamma scaling is used
        if self.gamma is not None:
            # Apply gamma scaling on values and compute the weighted sum using attention scores
            x = torch.pow(attn @ torch.pow((vv - vv.min() + self.eps), self.gamma), 1 / self.gamma)  # (B, num_heads, Nv, dv/num_heads) -> (B, 1, 1, d)
            # If use_beta, add a learnable translation
            if self.use_beta:
                x = x + self.beta
        else:
            # Compute the weighted sum using attention scores
            x = (attn @ vv).transpose(1, 2).reshape(Bq, Nq, dq)

        return x.squeeze()


class SimPoolAggregator(BaseAggregator):

    def __init__(self, num_heads, dim, *args, **kwargs):
        super().__init__(False, False, num_heads, "sum", False)
        self.pool = SimPool(dim, gamma=1.25)

    def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            device = self.pool.wq.weight.data.device
            pooled_audio = self.pool(preds[AUDIO_FEATS].to(torch.float32).to(device))
            pooled_image = self.pool(preds[IMAGE_FEATS].to(torch.float32).to(device))

            sims = torch.einsum(
                "bc,dc->bd",
                pooled_audio,
                pooled_image)
            if agg_heads:
                return sims
            else:
                sims = sims.unsqueeze(-1)
                return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
        else:
            return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)

    def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            device = self.pool.wq.weight.data.device
            pooled_audio = self.pool(preds[AUDIO_FEATS].to(torch.float32).to(device))
            pooled_image = self.pool(preds[IMAGE_FEATS].to(torch.float32).to(device))

            sims = torch.einsum(
                "bc,bc->b",
                pooled_audio,
                pooled_image)
            if agg_heads:
                return sims
            else:
                sims = sims.unsqueeze(-1)
                return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
        else:
            return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)

    def forward_batched(self, preds, agg_heads, batch_size):
        return self.forward(preds, agg_heads)


def get_aggregator(sim_agg_type, nonneg_sim, mask_silence, num_heads, head_agg, use_cls, dim):
    shared_args = dict(
        nonneg_sim=nonneg_sim,
        mask_silence=mask_silence,
        num_heads=num_heads,
        head_agg=head_agg,
        use_cls=use_cls,
    )

    if sim_agg_type == "paired":
        agg1 = PairedAggregator(**shared_args)
    elif sim_agg_type == "misa":
        agg1 = ImageThenAudioAggregator("max", "avg", **shared_args)
    elif sim_agg_type == "mima":
        agg1 = ImageThenAudioAggregator("max", "max", **shared_args)
    elif sim_agg_type == "sisa":
        agg1 = ImageThenAudioAggregator("avg", "avg", **shared_args)
    elif sim_agg_type == "cavmae":
        agg1 = CAVMAEAggregator()
    elif sim_agg_type == "imagebind":
        agg1 = ImageBindAggregator(num_heads=shared_args["num_heads"])
    elif sim_agg_type == "simpool":
        agg1 = SimPoolAggregator(num_heads=shared_args["num_heads"], dim=dim)
    else:
        raise ValueError(f"Unknown sim_agg_type {sim_agg_type}")

    return agg1
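The masked reductions at the top of this file are the primitives every aggregator builds on: masked_mean averages only over unmasked entries, and masked_max pushes masked entries far below the valid range before taking the max. A minimal sketch of their semantics on a toy tensor, assuming the DenseAV repo root is on the import path (matching the file's own DenseAV.denseav import style):

    import torch
    from DenseAV.denseav.aggregators import masked_mean, masked_max

    x = torch.tensor([[1.0, 5.0, 3.0]])     # (batch=1, time=3)
    mask = torch.tensor([[1.0, 1.0, 0.0]])  # last frame treated as silence

    print(masked_mean(x, mask, 1))  # (1.0 + 5.0) / 2 = 3.0, keepdim -> shape (1, 1)
    print(masked_max(x, mask, 1))   # max over unmasked entries only = 5.0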
 
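SimPool itself is a single-query attention pooling layer: the global-average-pooled vector queries the normalized patch tokens, and the attended sum becomes one descriptor per clip. A hedged usage sketch under the same import assumption; dim=64 and the random feature map are dummy values for illustration:

    import torch
    from DenseAV.denseav.aggregators import SimPool

    pool = SimPool(dim=64, gamma=1.25)  # mirrors how SimPoolAggregator constructs it
    feats = torch.randn(2, 64, 7, 7)    # CNN-style (B, d, H, W) input
    pooled = pool(feats)                # GAP query -> attention -> pooled vector
    print(pooled.shape)                 # torch.Size([2, 64])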
 
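Finally, a sketch (not part of the commit) of driving the get_aggregator factory end to end. The tensor shapes follow the unpacking in BaseAggregator._get_full_sims, the dictionary keys are the constants the file star-imports, and the dummy sizes are assumptions for illustration:

    import torch
    from DenseAV.denseav.aggregators import get_aggregator
    from DenseAV.denseav.constants import AUDIO_FEATS, AUDIO_MASK, IMAGE_FEATS

    agg = get_aggregator(
        sim_agg_type="paired", nonneg_sim=True, mask_silence=True,
        num_heads=2, head_agg="sum", use_cls=False, dim=None)

    preds = {
        AUDIO_FEATS: torch.randn(4, 64, 1, 10),  # (batch, channels, freq, time)
        AUDIO_MASK: torch.ones(4, 10),           # (batch, time), 1 = audible
        IMAGE_FEATS: torch.randn(3, 64, 7, 7),   # (images, channels, height, width)
    }

    sims = agg(preds, agg_heads=True)  # full audio-vs-image similarity matrix
    print(sims.shape)                  # torch.Size([4, 3])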