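"""DenseAV training script: defines LitAVAligner, a PyTorch Lightning module
that aligns dense audio and image features, plus a Hydra entry point that runs
the pretrain and train phases."""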
import os
from collections import deque
from itertools import combinations
from os.path import join
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from peft import get_peft_model, LoraConfig
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities import grad_norm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, SequentialLR, LambdaLR
from torchmetrics.functional.classification import binary_average_precision
from huggingface_hub import PyTorchModelHubMixin
from denseav.aggregators import get_aggregator
from denseav.aligners import get_aligner, ProgressiveGrowing
from denseav.constants import *
from denseav.data.AVDatasets import AVDataModule
from denseav.shared import flatten_preds, GatherLayer, \
get_image_featurizer, get_audio_featurizer, RollingAvg, create_model_from_cfg
torch.multiprocessing.set_sharing_strategy('file_system')
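
# Rejection-sample "imposter" indices: wherever a draw collides with the true
# index, redraw and recurse until every sample differs from its true index.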
def _imposter_indices_helper(true_indices: torch.Tensor, samples: torch.Tensor):
mask = (true_indices == samples).to(torch.int64)
n = mask.shape[0]
if not mask.any():
return samples
else:
new_samples = torch.randint(0, n, size=(n,), device=true_indices.device)
comb_samples = mask * new_samples + (1 - mask) * samples
return _imposter_indices_helper(true_indices, comb_samples)
def imposter_indices(n, device):
return _imposter_indices_helper(
torch.arange(0, n, device=device),
torch.randint(0, n, size=(n,), device=device))
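
# Similarity per batch row between audio features (b, c, t) and image features
# (b, c, h, w). A mask derived from n_frames ignores padded audio frames;
# sim_type selects max-pooling over time ("*mi"), max-pooling over space
# ("mi*"), and/or masked averaging over time ("*sa") before the final mean.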
def get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type):
max_t = audio_outputs.shape[-1]
oh = F.one_hot(n_frames - 1, num_classes=max_t)
audio_mask = 1 - torch.cumsum(oh, dim=1)
audio_mask = F.pad(audio_mask, [1, 0], value=1)[:, :max_t].to(audio_outputs.dtype)
full_sim = torch.einsum("bct,bchw->bthw", audio_outputs, image_outputs)
expanded_am = audio_mask.unsqueeze(-1).unsqueeze(-1)
if sim_type.endswith("mi"):
offset = 10 * (full_sim.max() - full_sim.min())
full_sim = (full_sim - ((1 - expanded_am) * offset)).max(1, keepdim=True).values
if sim_type.startswith("mi"):
full_sim = full_sim.max(-1, keepdim=True).values.max(-2, keepdim=True).values
if sim_type.endswith("sa"):
full_sim = (full_sim * (expanded_am / expanded_am.sum(1, keepdim=True).clamp_min(1))).sum(1, keepdim=True)
return full_sim.mean(dim=[1, 2, 3])
def sampled_margin_rank_loss(image_outputs, audio_outputs, n_frames, sim_type, margin=1.):
"""
Computes the triplet margin ranking loss for each anchor image/caption pair
The impostor image/caption is randomly sampled from the minibatch
"""
assert (image_outputs.dim() == 4)
assert (audio_outputs.dim() == 3)
n = image_outputs.size(0)
imp_ind_i = imposter_indices(n, image_outputs.device)
imp_ind_a = imposter_indices(n, image_outputs.device)
true_sim = get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type)
imp_sim_i = get_sim_per_row(image_outputs[imp_ind_i], audio_outputs, n_frames, sim_type)
imp_sim_a = get_sim_per_row(image_outputs, audio_outputs[imp_ind_a], n_frames[imp_ind_a], sim_type)
a2i_loss = (margin + imp_sim_i - true_sim).clamp_min(0)
i2a_loss = (margin + imp_sim_a - true_sim).clamp_min(0)
return (a2i_loss + i2a_loss).mean() / 2
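
# Learnable similarity calibration: a positive temperature stored in log-space
# (clamped to [min_w, max_w]) with an optional bias and mean subtraction.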
class SimilarityCalibrator(torch.nn.Module):
def __init__(self, cal_init, max_w=100, min_w=.01, subtract_mean=True, use_bias=False):
super().__init__()
self.max_w = max_w
self.min_w = min_w
self.w = torch.nn.Parameter(torch.tensor([cal_init]).log())
self.use_bias = use_bias
if self.use_bias:
self.b = torch.nn.Parameter(torch.tensor([0.0]))
self.subtract_mean = subtract_mean
def get_w(self):
return torch.exp(self.w).clamp_max(self.max_w).clamp_min(self.min_w)
def forward(self, x):
sims = self.get_w() * x
if self.use_bias:
sims = sims + self.b
if self.subtract_mean:
return sims - sims.mean()
else:
return sims
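
# Dropout over spatial positions: zeroes entire (h, w) locations across all
# channels during training and is the identity at eval time.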
class SpatialDropout(torch.nn.Module):
def __init__(self, p, *args, **kwargs):
super().__init__(*args, **kwargs)
self.p = p
    def forward(self, x):
        if not self.training:
            return x
        b, c, h, w = x.shape
        # Keep-mask shared across channels; each position survives w.p. 1 - p
        keep = (torch.rand((b, 1, h, w), device=x.device) > self.p).to(x.dtype)
        return x * keep
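
# The main DenseAV module: an image backbone and an audio backbone (optionally
# LoRA-adapted or fine-tuned), aligners that project both into a shared code
# space, a multi-head similarity aggregator, and a calibrated contrastive loss
# with assorted regularizers.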
class LitAVAligner(pl.LightningModule, PyTorchModelHubMixin, repo_url="https://github.com/mhamilton723/DenseAV", license="mit", tags=["denseav"]):
def __init__(self,
code_dim,
image_model_type,
image_model_token_type,
image_aligner_type,
image_pool_width,
audio_model_type,
audio_aligner_type,
audio_pool_width,
audio_lora,
audio_lora_rank,
image_lora,
image_lora_rank,
gradient_clipping,
learn_audio_cls,
silence_l1,
silence_l2,
tv_weight,
nonneg_sim,
nonneg_pressure,
pretrain_lr,
lr,
lr_warmup,
lr_schedule,
lr_cycle_length,
optimizer,
gather_tensors,
sim_agg_type,
sim_agg_heads,
sim_use_cls,
disentangle_weight,
norm_vectors,
cal_init,
cal_balance_weight,
loss_type,
loss_margin,
mask_silence,
finetune_image_model,
finetune_audio_model,
use_cached_embs,
output_root,
neg_audio,
neg_audio_weight,
head_agg,
adaptive_clipping,
specialization_weight,
spatial_dropout,
channel_dropout,
mixup_weight,
memory_buffer_size,
loss_leak,
):
super().__init__()
self.code_dim = code_dim
self.image_model_type = image_model_type
self.image_model_token_type = image_model_token_type
self.image_aligner_type = image_aligner_type
self.image_pool_width = image_pool_width
self.audio_model_type = audio_model_type
self.audio_aligner_type = audio_aligner_type
self.audio_pool_width = audio_pool_width
self.gradient_clipping = gradient_clipping
self.learn_audio_cls = learn_audio_cls
self.silence_l1 = silence_l1
self.silence_l2 = silence_l2
self.tv_weight = tv_weight
self.nonneg_sim = nonneg_sim
self.nonneg_pressure = nonneg_pressure
self.pretrain_lr = pretrain_lr
self.lr = lr
self.lr_warmup = lr_warmup
self.lr_schedule = lr_schedule
self.lr_cycle_length = lr_cycle_length
self.optimizer = optimizer
self.gather_tensors = gather_tensors
self.sim_agg_type = sim_agg_type
self.sim_agg_heads = sim_agg_heads
self.sim_use_cls = sim_use_cls
self.disentangle_weight = disentangle_weight
self.norm_vectors = norm_vectors
self.cal_init = cal_init
self.cal_balance_weight = cal_balance_weight
self.loss_type = loss_type
self.loss_margin = loss_margin
self.mask_silence = mask_silence
self.finetune_image_model = finetune_image_model
self.finetune_audio_model = finetune_audio_model
self.use_cached_embs = use_cached_embs
self.output_root = output_root
self.audio_lora = audio_lora
self.audio_lora_rank = audio_lora_rank
self.image_lora = image_lora
self.image_lora_rank = image_lora_rank
self.neg_audio = neg_audio
self.neg_audio_weight = neg_audio_weight
self.head_agg = head_agg
self.adaptive_clipping = adaptive_clipping
self.specialization_weight = specialization_weight
self.spatial_dropout = spatial_dropout
self.channel_dropout = channel_dropout
self.mixup_weight = mixup_weight
self.memory_buffer_size = memory_buffer_size
self.memory_buffer = deque(maxlen=self.memory_buffer_size)
self.loss_leak = loss_leak
        self.full_train = False  # default; toggled later via set_full_train
if self.audio_model_type in {"audiomae", "audiomae-finetuned", "cavmae", "cavmae-mixed", "imagebind"}:
self.audio_input = "spec"
elif self.audio_model_type == "davenet":
self.audio_input = "davenet_spec"
elif self.audio_model_type == "fnac":
self.audio_input = "fnac_spec"
else:
self.audio_input = "audio"
extra_model_args = dict(output_root=output_root)
self.image_model, _, self.image_feat_dim = get_image_featurizer(
image_model_type, token_type=self.image_model_token_type, **extra_model_args)
self.image_model.eval()
if not self.finetune_image_model:
for param in self.image_model.parameters():
param.requires_grad = False
if image_model_type in {"cavmae", "cavmae-mixed", "imagebind", "fnac"}:
extra_model_args["model"] = self.image_model.model
if use_cached_embs:
_, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
else:
self.audio_model, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
self.audio_model.eval()
if not self.finetune_audio_model:
for param in self.audio_model.parameters():
param.requires_grad = False
if self.image_lora:
if self.image_model_type in {"sam", "dino8", "dinov2", "cavmae", "cavmae-mixed"}:
target_modules = ["qkv"]
elif self.image_model_type == "clip":
target_modules = ["out_proj"]
elif self.image_model_type == "imagebind":
target_modules = ["out_proj", "fc1", "fc2"]
else:
target_modules = ["q", "k", "v"]
peft_config = LoraConfig(
target_modules=target_modules,
inference_mode=False,
r=image_lora_rank,
lora_alpha=32,
lora_dropout=0.1
)
self.image_model = get_peft_model(self.image_model, peft_config)
self.image_model.print_trainable_parameters()
if self.audio_lora:
if self.audio_model_type == "hubert":
target_modules = ["q_proj", "k_proj", "v_proj"]
else:
target_modules = ["q", "k", "v"]
peft_config = LoraConfig(
inference_mode=False,
target_modules=target_modules,
r=audio_lora_rank,
lora_alpha=32,
lora_dropout=0.1
)
self.audio_model = get_peft_model(self.audio_model, peft_config)
self.audio_model.print_trainable_parameters()
shared_aligner_args = dict(out_dim=self.code_dim)
self.audio_aligner = get_aligner(
self.audio_aligner_type, self.audio_feat_dim, **shared_aligner_args)
self.image_aligner = get_aligner(
self.image_aligner_type, self.image_feat_dim, **shared_aligner_args)
if self.loss_type == "nce":
self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=True, use_bias=False)
else:
self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=False, use_bias=True)
if self.learn_audio_cls:
self.audio_cls = torch.nn.Parameter(torch.randn(self.audio_feat_dim))
if self.spatial_dropout > 0.0:
self.spatial_dropout_layer = SpatialDropout(self.spatial_dropout)
if self.channel_dropout > 0.0:
self.channel_dropout_layer = torch.nn.Dropout2d(self.channel_dropout)
self.sim_agg = get_aggregator(
self.sim_agg_type,
self.nonneg_sim,
self.mask_silence,
self.sim_agg_heads,
self.head_agg,
self.sim_use_cls,
dim=self.image_feat_dim
)
self.hparams_logged = False
self.rolling_avg = RollingAvg(50)
self.grad_avg = RollingAvg(50, nonzero=True)
self.save_hyperparameters()
def set_full_train(self, full_train):
self.full_train = full_train
def prep_feats(self, feats, is_audio):
if not is_audio and self.training and self.image_pool_width > 1:
feats = torch.nn.AvgPool2d(self.image_pool_width)(feats)
if is_audio and self.training and self.audio_pool_width > 1:
feats = torch.nn.AvgPool2d((1, self.audio_pool_width))(feats)
if self.norm_vectors:
feats = F.normalize(feats, dim=1)
return feats
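
    # Per-parameter gradient hygiene: clip any gradient norm that exceeds 5x
    # its rolling average (inactive while grad_avg.add_all below is commented
    # out) or the hard self.gradient_clipping ceiling.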
def on_before_optimizer_step(self, optimizer, optimizer_idx):
norms = grad_norm(self, norm_type=2)
avg_grads = self.grad_avg.get_all()
params = {
f"grad_2.0_norm/{name}": p
for name, p in self.named_parameters()
if p.grad is not None
}
if self.adaptive_clipping:
for k in norms.keys():
if k in params:
avg_grad = max(avg_grads.get(k, norms[k]), 1e-5)
if self.global_step > 10 and norms[k] > avg_grad * 5:
print(f"Bad grad for {k}: {norms[k]} scaling to {avg_grad * 5}")
torch.nn.utils.clip_grad_norm_(params[k], avg_grad * 5)
norms[k] = avg_grad * 5
if norms[k] > self.gradient_clipping:
# print(f"Bad grad for {k}: {norms[k]} scaling to {self.gradient_clipping}")
torch.nn.utils.clip_grad_norm_(params[k], self.gradient_clipping)
# self.grad_avg.add_all(norms)
# self.log_dict(norms)
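
    # Resample a (b, t) mask to target_length along time; when discrete,
    # threshold it and guarantee at least one active position per row.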
def interpolate_mask(self, mask, target_length, discrete):
b, t = mask.shape
mask = F.interpolate(mask.reshape(b, 1, 1, t), (1, target_length), mode="bilinear") \
.reshape(b, target_length)
if discrete:
mask = mask > 0.01
sums = mask.sum(1)
all_zeros = torch.where(sums == 0)[0]
if len(all_zeros) > 0:
print("Fixing a bad mask")
for entry in all_zeros:
mask[entry, torch.randint(0, target_length - 1, size=())] = True
        return mask
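
    # Audio branch: featurize the batch (or read cached embeddings), align into
    # the shared code space, apply channel dropout, and resize the padding and
    # positive-sample masks to the aligned feature length.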
def forward_audio(self, batch):
if self.use_cached_embs:
audio_feats = batch["audio_emb"]
if "audio_cls" in batch:
audio_cls = batch["audio_cls"]
else:
audio_cls = None
else:
audio = batch[self.audio_input]
if self.full_train:
audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
else:
with torch.no_grad():
audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
mask = batch[AUDIO_MASK] if AUDIO_MASK in batch else torch.ones_like(audio)
pos_mask = batch[AUDIO_POS_MASK] if AUDIO_POS_MASK in batch else torch.ones_like(audio)
if self.learn_audio_cls:
assert audio_cls is None
audio_cls = torch.broadcast_to(self.audio_cls.unsqueeze(0), (audio_feats.shape[0], audio_feats.shape[1]))
aligned_audio_feats, aligned_audio_cls = self.audio_aligner(audio_feats, audio_cls)
if self.channel_dropout > 0.0:
aligned_audio_feats = self.channel_dropout_layer(aligned_audio_feats)
aligned_audio_feats = self.prep_feats(aligned_audio_feats, is_audio=True)
audio_mask = self.interpolate_mask(mask, aligned_audio_feats.shape[-1], True)
audio_pos_mask = self.interpolate_mask(pos_mask, aligned_audio_feats.shape[-1], False)
ret = {
AUDIO_MASK: audio_mask,
AUDIO_POS_MASK: audio_pos_mask,
AUDIO_FEATS: aligned_audio_feats,
}
if aligned_audio_cls is not None:
ret[AUDIO_CLS] = aligned_audio_cls
return ret
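
    # Image branch: flatten frames into the batch dimension, featurize in
    # chunks of max_batch_size to bound memory, align, apply channel/spatial
    # dropout, and pool any image mask down to the feature resolution.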
# @autocast(device_type="cuda", enabled=False)
def forward_image(self, batch, max_batch_size=None):
with torch.no_grad():
image = batch[IMAGE_INPUT]
b, nf, c, h, w = image.shape
image = image.reshape(b * nf, c, h, w)
if max_batch_size is None:
max_batch_size = image.shape[0]
chunks = [image[i:i + max_batch_size] for i in range(0, image.shape[0], max_batch_size)]
all_image_feats = []
all_image_cls = []
for chunk in chunks:
if self.full_train:
image_feats, image_cls = self.image_model(chunk, include_cls=True)
else:
with torch.no_grad():
image_feats, image_cls = self.image_model(chunk, include_cls=True)
aligned_image_feats, aligned_image_cls = self.image_aligner(image_feats, image_cls)
all_image_feats.append(aligned_image_feats)
all_image_cls.append(aligned_image_cls)
# Stitch the chunks back together
aligned_image_feats = torch.cat(all_image_feats, dim=0)
aligned_image_cls = torch.cat(all_image_cls, dim=0)
if self.channel_dropout > 0.0:
aligned_image_feats = self.channel_dropout_layer(aligned_image_feats)
if self.spatial_dropout > 0.0:
aligned_image_feats = self.spatial_dropout_layer(aligned_image_feats)
aligned_image_feats = self.prep_feats(aligned_image_feats, is_audio=False)
ret = {IMAGE_FEATS: aligned_image_feats}
if IMAGE_MASK in batch:
with torch.no_grad():
mask = batch[IMAGE_MASK]
mask = mask.reshape(b * nf, 1, h, w)
b, c, h, w = aligned_image_feats.shape
mask = F.adaptive_avg_pool2d(mask.to(aligned_image_feats), output_size=(h, w))
ret[IMAGE_MASK] = mask
if aligned_image_cls is not None:
ret[IMAGE_CLS] = aligned_image_cls
return ret
def forward(self, batch):
audio_feat_dict = self.forward_audio(batch)
image_feat_dict = self.forward_image(batch)
return {**image_feat_dict, **audio_feat_dict}
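
    # Clip-level contrastive loss on the (b, b) similarity matrix, with
    # loss_margin subtracted from the diagonal. Supports margin, cross-entropy,
    # BCE, and InfoNCE objectives; loss_leak smooths labels toward negatives.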
def contrast_loss(self, sims):
b = sims.shape[0]
sims = sims - torch.eye(b, b, device=sims.device) * self.loss_margin
sims_1 = sims
sims_2 = sims.permute(1, 0)
if self.loss_leak > 0.0:
id = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
label_mask = id * (1 - self.loss_leak)
label_mask += (1 - id) * self.loss_leak / (sims_1.shape[0] - 1)
label_mask /= label_mask.sum(dim=1, keepdim=True)
else:
label_mask = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
labels = torch.arange(0, sims.shape[0], device=sims.device)
self.rolling_avg.add(f"acc/1", (sims.argmax(dim=1) == labels).to(sims).mean())
self.rolling_avg.add(f"acc/2", (sims.argmax(dim=0) == labels).to(sims).mean())
if self.loss_type == "margin":
margin_loss_tensor = (sims - torch.diag(sims)).clamp_min(0)
margin_loss = margin_loss_tensor.mean()
self.rolling_avg.add(f"loss/frac_nonzero", (margin_loss_tensor > 0).to(sims).mean())
self.rolling_avg.add(f"loss/margin", margin_loss)
return margin_loss
elif self.loss_type == "ce":
ce_loss = 1 / 2 * F.cross_entropy(sims_1, labels) + \
1 / 2 * F.cross_entropy(sims_2, labels)
self.rolling_avg.add(f"loss/ce", ce_loss)
return ce_loss
elif self.loss_type == "bce":
bce_loss = F.binary_cross_entropy_with_logits(sims_1.flatten(), label_mask.flatten())
self.rolling_avg.add(f"loss/bce", bce_loss)
return bce_loss
elif self.loss_type == "nce":
nce_loss = 1 / 2 * (-F.log_softmax(sims_1, dim=-1) * label_mask).sum(1).mean() + \
1 / 2 * (-F.log_softmax(sims_2, dim=-1) * label_mask).sum(1).mean()
self.rolling_avg.add(f"loss/nce", nce_loss)
return nce_loss
else:
raise ValueError(f"Unknown loss type {self.loss_type}")
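
    # Full training objective: the calibrated contrastive term plus optional
    # regularizers (non-negativity pressure, silence suppression, negative
    # audio suppression, temporal TV smoothness, calibration balance, head
    # disentanglement, head specialization, and mixup masking).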
def loss(self, preds):
image_feats = preds[IMAGE_FEATS]
audio_feats = preds[AUDIO_FEATS]
audio_mask = preds[AUDIO_MASK]
image_mask = preds[IMAGE_MASK]
audio_pos_mask = preds[AUDIO_POS_MASK]
if DATA_SOURCE in preds:
source = preds[DATA_SOURCE].to(torch.int64)
else:
source = None
uncal_sims = self.sim_agg(preds, agg_heads=True)
sims = self.sim_cal(uncal_sims)
_mask = 1 - torch.eye(sims.shape[0], device=sims.device)
self.log(f"sim/pos", torch.diag(sims).mean())
self.log(f"sim/neg", (sims * _mask).sum() / (_mask.sum()))
self.log(f"sim/uncal_pos", torch.diag(uncal_sims).mean())
self.log(f"sim/uncal_neg", (uncal_sims * _mask).sum() / (_mask.sum()))
b, c, h, w = image_feats.shape
b, c, f, t = audio_feats.shape
n_samples = 250
nh = self.sim_agg_heads
image_feats_by_head = image_feats.reshape(b, self.sim_agg_heads, c // nh, h, w)
audio_feats_by_head = audio_feats.reshape(b, self.sim_agg_heads, c // nh, f, t)
def maybe_clamp(t):
return t.clamp_min(0) if self.nonneg_sim else t
paired_sim_raw = self.sim_agg.get_pairwise_sims(preds, raw=True, agg_sim=False, agg_heads=False)
paired_sim = maybe_clamp(paired_sim_raw)
loss = 0.0
if self.nonneg_pressure:
afb, afk, afc, aff, aft = audio_feats_by_head.shape
ifb, ifk, ifc, ifh, ifw = image_feats_by_head.shape
assert (afb == ifb)
device = audio_feats_by_head.device
random_b = torch.randint(0, afb, size=(n_samples,), device=device)
random_t = torch.randint(0, aft, size=(n_samples,), device=device)
random_f = torch.randint(0, aff, size=(n_samples,), device=device)
random_h = torch.randint(0, ifh, size=(n_samples,), device=device)
random_w = torch.randint(0, ifw, size=(n_samples,), device=device)
random_audio_feats = audio_feats_by_head[random_b, :, :, random_f, random_t]
random_image_feats = image_feats_by_head[random_b, :, :, random_h, random_w]
random_sim_raw = torch.einsum("bkc,dkc->bdk", random_audio_feats, random_image_feats)
nonneg_loss = random_sim_raw.clamp_max(0).square().mean()
self.rolling_avg.add(f"loss/nonneg", nonneg_loss)
loss += nonneg_loss * self.nonneg_pressure
if self.silence_l1 > 0 or self.silence_l2 > 0:
masked_b, masked_t = torch.where(~audio_mask)
if len(masked_b) > n_samples:
subset = torch.randperm(len(masked_b))[:n_samples]
masked_b = masked_b[subset]
masked_t = masked_t[subset]
if len(masked_b) == n_samples:
silent_audio_feats = audio_feats_by_head[masked_b, :, :, :, masked_t].mean(-1) # d k c
silence_tensor = maybe_clamp(
torch.einsum("bkchw,dkc->bkdhw", image_feats_by_head, silent_audio_feats))
silence_l1_loss = silence_tensor.abs().mean()
self.rolling_avg.add(f"loss/silence_l1", silence_l1_loss)
loss += silence_l1_loss * self.silence_l1
silence_l2_loss = silence_tensor.square().mean()
self.rolling_avg.add(f"loss/silence_l2", silence_l2_loss)
loss += silence_l2_loss * self.silence_l2
if self.neg_audio_weight > 0 and self.neg_audio:
b, t = audio_pos_mask.shape
negative_weight = ((1 - audio_pos_mask) * audio_mask.to(sims)).reshape(b, 1, 1, 1, 1, t)
negative_weight = torch.broadcast_to(negative_weight, paired_sim.shape)
if negative_weight.sum() > 0:
neg_audio_loss = (paired_sim.square() * negative_weight).sum() \
/ negative_weight.sum().clamp_min(0.1)
self.rolling_avg.add(f"loss/neg_audio", neg_audio_loss)
self.rolling_avg.add(f"loss/neg_weight_avg", negative_weight.mean())
loss += neg_audio_loss * self.neg_audio_weight
else:
print("WARNING: No negative samples found in batch")
if self.tv_weight > 0:
tv_loss = (paired_sim[:, :, :, :, :, 1:] - paired_sim[:, :, :, :, :, :-1]).square().mean()
self.rolling_avg.add(f"loss/tv", tv_loss)
loss += tv_loss * self.tv_weight
self.log(f"cal/w", self.sim_cal.get_w())
if self.cal_balance_weight > 0.0:
cal_balance = (np.log(self.cal_init) - torch.log(self.sim_cal.get_w().clamp_min(.00000001))) \
.clamp_min(0).square().mean()
self.rolling_avg.add(f"loss/cal_balance", cal_balance)
loss += cal_balance * self.cal_balance_weight
if self.disentangle_weight > 0.0:
assert source is not None
assert self.sim_agg_heads % 2 == 0
dilation = self.sim_agg_heads // 2
sources_oh = F.one_hot(source, num_classes=2)
b, h = sources_oh.shape
sources_mask = 1 - torch.broadcast_to(sources_oh.unsqueeze(-1), (b, h, dilation)) \
.reshape(b, h * dilation).to(paired_sim)
disentangle_loss = torch.einsum("bkhwft,bk->bhwft", paired_sim, sources_mask).square().mean()
self.rolling_avg.add(f"loss/disentangle", disentangle_loss)
loss += disentangle_loss * self.disentangle_weight
if self.specialization_weight > 0.0 and self.sim_agg_heads > 1:
total_specialization_loss = 0.0
combos = list(combinations(range(self.sim_agg_heads), 2))
for i, j in combos:
specialization_loss_pair = (paired_sim[:, i].abs() * paired_sim[:, j].abs()).mean()
total_specialization_loss += specialization_loss_pair
avg_specialization_loss = total_specialization_loss / len(combos)
self.rolling_avg.add(f"loss/specialize", avg_specialization_loss)
loss += avg_specialization_loss * self.specialization_weight
if self.mixup_weight > 0.0:
b, _, h, w = image_mask.shape
neg_img_mask = torch.broadcast_to(
1 - image_mask.to(paired_sim).reshape(b, 1, h, w, 1, 1),
paired_sim.shape)
image_mixup_loss = (paired_sim * neg_img_mask).square().sum() / neg_img_mask.sum().clamp_min(0.1)
self.rolling_avg.add(f"loss/image_mixup", image_mixup_loss)
loss += image_mixup_loss * self.mixup_weight
loss += self.contrast_loss(sims)
self.rolling_avg.add(f"loss/total", loss)
return loss
def setup_hparams(self):
recalls = ['A_r1', 'A_r5', 'A_r10', 'I_r1', 'I_r5', 'I_r10']
if self.trainer.datamodule.use_extra_val_sets:
datasets = ["Places", "AudioSet"]
else:
datasets = ["Val"]
heads = ["total"]
metric_names = [
"hp/speech_basic_ap", "hp/speech_advanced_ap", "hp/sound_basic_ap",
"hp/speech_basic_iou", "hp/speech_advanced_iou", "hp/sound_basic_iou",
]
for dataset in datasets:
for head in heads:
for recall in recalls:
metric_names.append(f"hp/{dataset}/{head}/{recall}")
if self.sim_agg_heads == 2:
metric_names.extend(["hp/ap_dis", "hp/act_dis"])
if hasattr(self.trainer, "datamodule"):
all_hparams = {**self.hparams, **self.trainer.datamodule.hparams}
else:
all_hparams = self.hparams
starting_values = {n: torch.nan for n in metric_names}
self.logger.log_hyperparams(all_hparams, starting_values)
def on_train_start(self):
self.setup_hparams()
self.hparams_logged = True
def on_train_batch_start(self, batch, batch_idx):
remake_optimizers = False
if isinstance(self.image_aligner, ProgressiveGrowing):
should_remake = self.image_aligner.maybe_change_phase(self.global_step)
remake_optimizers = remake_optimizers or should_remake
if isinstance(self.audio_aligner, ProgressiveGrowing):
should_remake = self.audio_aligner.maybe_change_phase(self.global_step)
remake_optimizers = remake_optimizers or should_remake
if remake_optimizers:
raise NotImplementedError()
def _combine_preds(self, all_preds):
temp = {}
new_preds = {}
# Collect tensors for each key into lists
for d in all_preds:
for key, value in d.items():
if isinstance(value, torch.Tensor):
if key not in temp:
temp[key] = []
temp[key].append(value)
# Concatenate all tensors for each key using a single call to torch.cat
for key, tensor_list in temp.items():
new_preds[key] = torch.cat(tensor_list)
return new_preds
def training_step(self, batch, batch_idx):
assert batch[IMAGE_INPUT].shape[1] == 1
preds = self.forward(batch)
if DATA_SOURCE in batch:
preds[DATA_SOURCE] = batch[DATA_SOURCE]
if self.trainer.world_size > 1 and self.gather_tensors:
for k, v in preds.items():
new_v = v.contiguous()
preds[k] = torch.cat(GatherLayer.apply(new_v), dim=0)
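        # Optionally mix detached predictions from recent steps into this one
        # so the contrastive loss sees a larger effective batch.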
if self.memory_buffer_size > 0:
new_preds = self._combine_preds(list(self.memory_buffer) + [preds])
else:
new_preds = preds
loss = self.loss(new_preds)
if self.memory_buffer_size > 0:
self.memory_buffer.append(self._recursive_detach(preds, gather=False))
if self.trainer.is_global_zero and self.global_step % 50 == 1:
writer = self.logger.experiment
self.rolling_avg.logall(lambda k, v: writer.add_scalar(k, v, global_step=self.global_step))
if self.trainer.scaler is not None:
self.log("loss_scale", self.trainer.scaler.get_scale())
if self.global_step % 10000 == 0 and self.global_step > 0:
print("RESETTING TFEVENT FILE")
self.logger.experiment.close()
self.logger.experiment._get_file_writer()
return loss
def on_validation_start(self) -> None:
if not self.hparams_logged:
self.setup_hparams()
self.hparams_logged = True
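
    # Gather a tensor onto rank zero across devices (returns None on the other
    # ranks); single-device runs simply move it to the CPU.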
def _auto_gather(self, t):
if t.dtype == torch.bool:
t = t.to(torch.float)
if self.trainer.num_devices == 1:
return t.cpu()
t = torch.clone(t).contiguous()
if self.trainer.is_global_zero:
gather_list = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.gather(t, gather_list)
return torch.cat(gather_list, dim=0).cpu()
else:
dist.gather(t)
def validation_step(self, batch, batch_idx, dataloader_idx=0):
with torch.no_grad():
preds = self.forward(batch)
ret = {}
            for k, v in preds.items():
                ret[k] = self._auto_gather(v)
batch_keys = [IMAGE_INPUT, "spec", "semseg", "num_pixels_per_class", 'total_length']
for k in batch_keys:
if k in batch:
ret[k] = self._auto_gather(batch[k])
if "metadata" in batch:
if isinstance(batch["metadata"]["id"], torch.Tensor):
ret["id"] = self._auto_gather(batch["metadata"]["id"])
ret["index"] = self._auto_gather(batch["metadata"]["index"])
return ret
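
    # Retrieval recalls in both directions: whether the true match appears in
    # the top 1/5/10 entries of the similarity matrix along each axis.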
def _calc_recalls(self, sim):
top_10_a = sim.topk(10, 0).indices == torch.arange(sim.shape[0]).unsqueeze(0)
top_10_i = (sim.topk(10, 1).indices == torch.arange(sim.shape[0]).unsqueeze(1)).permute(1, 0)
a_recall = lambda p: top_10_a[0:p].any(0).to(sim).mean()
i_recall = lambda p: top_10_i[0:p].any(0).to(sim).mean()
return {'A_r1': a_recall(1),
'A_r5': a_recall(5),
'A_r10': a_recall(10),
'I_r1': i_recall(1),
'I_r5': i_recall(5),
'I_r10': i_recall(10)}
def calc_recalls(self, preds, dataset):
sim = self.sim_agg.forward_batched(
preds=preds,
agg_heads=False,
batch_size=4,
).cpu()
all_metrics = dict()
for k, v in self._calc_recalls(sim.sum(-1)).items():
all_metrics[f"hp/{dataset}/total/" + k] = v
return all_metrics
def retrieval_validation(self, outputs, dataset_name):
if len(outputs) == 0:
return
if self.trainer.is_global_zero:
results = flatten_preds(outputs)
if not self.trainer.sanity_checking:
print(results[IMAGE_FEATS].shape[0])
# assert (results[IMAGE_FEATS].shape[0] == 1000)
results[IMAGE_FEATS] = results[IMAGE_FEATS].cpu()
results[AUDIO_FEATS] = results[AUDIO_FEATS].cuda()
            if self.sim_use_cls:
                results[AUDIO_CLS] = results[AUDIO_CLS].cuda()
results[AUDIO_MASK] = results[AUDIO_MASK].cuda()
recalls = self.calc_recalls(results, dataset_name)
results[IMAGE_FEATS] = results[IMAGE_FEATS].cuda()
writer = self.logger.experiment
print("here")
for name, v in recalls.items():
writer.add_scalar(f"{name}", v, self.global_step + 1)
def semseg_validation(self, speech_preds, sound_preds):
if self.trainer.is_global_zero:
from eval_utils import get_paired_heatmaps
def prep_preds(preds, loader):
results = flatten_preds(preds)
metadata = loader.dataset.metadata
ordered_metadata = metadata.iloc[results["index"].numpy(), :].copy()
ordered_metadata["order"] = range(len(ordered_metadata))
return results, ordered_metadata
[_, _, speech_loader, sound_loader] = self.trainer.val_dataloaders
speech_results, speech_metadata = prep_preds(speech_preds, speech_loader)
sound_results, sound_metadata = prep_preds(sound_preds, sound_loader)
self.sound_metrics, unique_sound_indices = get_paired_heatmaps(
self, sound_results, sound_metadata["ade_class_id"], None)
self.speech_metrics, unique_word_indices = get_paired_heatmaps(
self, speech_results, speech_metadata["ade_class_id"], speech_metadata["timing"])
writer = self.logger.experiment
all_metrics = {
**{"sound_" + k: v for k, v in self.sound_metrics.items()},
**{"speech_" + k: v for k, v in self.speech_metrics.items()},
}
for k, v in all_metrics.items():
writer.add_scalar(f"hp/{k}", torch.tensor(v).mean(), self.global_step + 1)
def disentangle_validation(self, word_preds, sound_preds):
if len(word_preds) == 0 or len(sound_preds) == 0:
return
if self.trainer.is_global_zero:
word_preds = flatten_preds(word_preds)
sound_preds = flatten_preds(sound_preds)
word_scores = self.sim_agg.get_pairwise_sims(
word_preds,
raw=False,
agg_sim=True,
agg_heads=False,
)
sound_scores = self.sim_agg.get_pairwise_sims(
sound_preds,
raw=False,
agg_sim=True,
agg_heads=False,
)
all_scores = torch.cat([word_scores, sound_scores], dim=0)
all_scores -= all_scores.min(dim=0, keepdim=True).values
all_scores /= all_scores.max(dim=0, keepdim=True).values.clamp_min(.0001)
is_words = torch.cat([
torch.ones(word_scores.shape[0]),
torch.zeros(sound_scores.shape[0])], dim=0).to(torch.bool)
assert all_scores.shape[1] == 2
ap_matrix = torch.zeros(2, 2)
act_matrix = torch.zeros(2, 2)
for head in range(2):
# writer.add_histogram(f"h{head}_all_scores", all_scores[:, head])
for dataset_num in range(2):
if dataset_num == 0:
labels = is_words
else:
labels = ~is_words
ap_matrix[head, dataset_num] = binary_average_precision(
all_scores[:, head].cpu(), labels.to(torch.int64).cpu())
act_matrix[head, dataset_num] = 1 - (all_scores[:, head][labels]).mean()
ap_dis = max(.5 * (ap_matrix[0, 0] + ap_matrix[1, 1]),
.5 * (ap_matrix[0, 1] + ap_matrix[1, 0]))
act_dis = max(.5 * (act_matrix[0, 0] + act_matrix[1, 1]),
.5 * (act_matrix[0, 1] + act_matrix[1, 0]))
print("AP", ap_matrix)
print("AP dis", ap_dis)
print("Act", act_matrix)
print("Act dis", act_dis)
writer = self.logger.experiment
writer.add_scalar("hp/ap_dis", ap_dis, self.global_step + 1)
writer.add_scalar("hp/act_dis", act_dis, self.global_step + 1)
def validation_epoch_end(self, outputs) -> None:
print("Val end")
with torch.no_grad():
if self.trainer.datamodule.use_extra_val_sets:
if self.sim_agg_heads == 2:
self.disentangle_validation(outputs[0], outputs[1])
self.retrieval_validation(outputs[0], "Places")
self.retrieval_validation(outputs[1], "AudioSet")
self.semseg_validation(outputs[2], outputs[3])
else:
print("HERE!")
self.retrieval_validation(outputs, "Val")
writer = self.logger.experiment
writer.flush()
def _recursive_detach(self, obj, gather=True):
if isinstance(obj, torch.Tensor):
if gather:
return self._auto_gather(obj)
else:
                return obj.detach()
elif isinstance(obj, dict):
return {k: self._recursive_detach(v, gather) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._recursive_detach(v, gather) for v in obj]
else:
return obj
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
with torch.no_grad():
predictions = {}
for k, v in batch.items():
predictions[k] = self._recursive_detach(v)
for k, v in self.forward(batch).items():
predictions[k] = self._auto_gather(v)
return predictions
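
    # Optimizer over the aligners, calibrator, and aggregator (plus backbones
    # when full_train is on), with linear warmup and an optional SGDR
    # cosine-restart schedule, stepped once per training step.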
def _configure_optimizers(self, full_train, lr):
params = [
*self.audio_aligner.parameters(),
*self.image_aligner.parameters(),
*self.sim_cal.parameters(),
*self.sim_agg.parameters()
]
if (self.finetune_image_model or self.image_lora) and full_train:
params.extend(self.image_model.parameters())
if (self.finetune_audio_model or self.audio_lora) and full_train:
params.extend(self.audio_model.parameters())
if self.learn_audio_cls:
params.append(self.audio_cls)
last_epoch = self.global_step - 1
if self.optimizer == "adam":
opt = torch.optim.Adam(params, lr=lr, eps=1e-7)
elif self.optimizer == "nadam":
opt = torch.optim.NAdam(params, lr=lr, eps=1e-7)
else:
raise ValueError(f"Unknown optimizer {self.optimizer}")
if self.lr_schedule == "sgdr":
scheduler = CosineAnnealingWarmRestarts(
opt, self.lr_cycle_length, 2, eta_min=lr * 2e-2, last_epoch=last_epoch)
else:
scheduler = LambdaLR(opt, lr_lambda=lambda step: 1.0, last_epoch=last_epoch)
if self.lr_warmup > 0:
warmup = LambdaLR(
opt,
lr_lambda=lambda step: min(max(float(step), 0.0) / self.lr_warmup, 1.0),
last_epoch=last_epoch,
)
scheduler = SequentialLR(
opt,
schedulers=[warmup, scheduler],
milestones=[self.lr_warmup],
last_epoch=last_epoch)
scheduler = {"scheduler": scheduler, "interval": "step"}
return [opt], [scheduler]
def configure_optimizers(self):
if self.full_train:
return self._configure_optimizers(self.full_train, self.lr)
else:
return self._configure_optimizers(self.full_train, self.pretrain_lr)
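
# Hydra entry point: builds the datamodule and model from the config, runs the
# frozen-backbone "pretrain" phase when pretrain_steps is set and no "train"
# checkpoint exists yet, and otherwise runs the full fine-tuning phase.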
@hydra.main(config_path="configs", config_name="av_align.yaml", version_base=None)
def my_app(cfg: DictConfig) -> None:
print(OmegaConf.to_yaml(cfg))
seed_everything(cfg.seed, workers=True)
exp_name = f"{cfg.resume_prefix}"
if cfg.image_model_type == "dino8":
patch_size = 8 * cfg.image_pool_width
elif cfg.image_model_type == "cavmae":
patch_size = 16 * cfg.image_pool_width
elif cfg.image_model_type == "imagebind":
patch_size = 16 * cfg.image_pool_width
elif cfg.image_model_type == "clip":
patch_size = 16 * cfg.image_pool_width
elif cfg.image_model_type == "cavmae-mixed":
patch_size = 16 * cfg.image_pool_width
elif cfg.image_model_type == "dinov2":
patch_size = 14 * cfg.image_pool_width
else:
raise ValueError(f"Unknown patch size for model {cfg.image_model_type}")
datamodule = AVDataModule(
dataset_name=cfg.dataset_name,
load_size=cfg.load_size,
image_aug=cfg.image_aug,
audio_aug=cfg.audio_aug,
extra_audio_masking=cfg.extra_audio_masking,
audio_model_type=cfg.audio_model_type,
pytorch_data_dir=cfg.pytorch_data_dir,
use_cached_embs=cfg.use_cached_embs,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
audio_level=cfg.audio_level,
neg_audio=cfg.neg_audio,
use_original_val_set=not cfg.use_extra_val_sets,
use_extra_val_sets=cfg.use_extra_val_sets,
data_for_plotting=False,
quad_mixup=cfg.quad_mixup,
bg_mixup=cfg.bg_mixup,
patch_mixup=cfg.patch_mixup,
patch_size=patch_size
)
datamodule.maybe_unpack(remove_source=cfg.submitting_to_aml)
aligner = create_model_from_cfg(LitAVAligner, cfg, {})
if cfg.starting_weights is not None:
loaded = torch.load(join(cfg.output_root, cfg.starting_weights), map_location='cpu')
state = loaded["state_dict"]
aligner.load_state_dict(state, strict=cfg.load_strict)
del state
del loaded
if cfg.num_gpus > 1:
# strategy = "ddp_sharded" # _find_unused_parameters_true"
strategy = "ddp" # _find_unused_parameters_true"
else:
strategy = "auto"
if cfg.dataset_name in {"places-audio", "mixed", "audio-set", "mixed-full"}:
val_args = dict(check_val_every_n_epoch=2)
elif cfg.dataset_name in {"dolphin"}:
val_args = dict(check_val_every_n_epoch=5)
else:
val_args = dict(val_check_interval=10000)
# val_args = dict(val_check_interval=1000)
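
    # With auto_resume, resume from the single checkpoint expected in ckpt_dir.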
def maybe_get_ckpt(ckpt_dir):
if cfg.auto_resume and os.path.exists(ckpt_dir):
print(f"Attempting to resume from {ckpt_dir}")
candidates = os.listdir(ckpt_dir)
assert (len(candidates) == 1)
return join(ckpt_dir, candidates[0])
elif cfg.auto_resume:
print(f"Could not find checkpoint at {ckpt_dir}")
return None
else:
return None
log_dir = join(cfg.output_root, "logs", cfg.grouping_name, exp_name)
ckpt_dir = join(cfg.output_root, "checkpoints", cfg.grouping_name, exp_name)
import gc
torch.cuda.empty_cache()
gc.collect()
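
    # Run a single training phase (frozen-backbone "pretrain" or full "train"),
    # each with its own logger, checkpoint directory, and step budget.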
def run_exp(aligner, full_train):
trainer_args = dict(
accelerator='gpu',
strategy=strategy,
devices=cfg.num_gpus,
num_sanity_val_steps=cfg.num_sanity_val_steps,
log_every_n_steps=50,
reload_dataloaders_every_n_epochs=10,
precision="16",
# profiler="simple",
# precision="bf16",
max_steps=cfg.max_steps,
**val_args)
aligner.set_full_train(full_train)
if full_train:
suffix = "train"
else:
suffix = "pretrain"
trainer_args["max_steps"] = cfg.pretrain_steps
print(f"Starting {suffix} phase")
logger = TensorBoardLogger(join(log_dir, suffix), default_hp_metric=False)
callbacks = [
ModelCheckpoint(join(ckpt_dir, suffix), every_n_epochs=1),
LearningRateMonitor(logging_interval='step'),
]
Trainer(logger=logger,
callbacks=callbacks,
**trainer_args).fit(
aligner,
datamodule=datamodule,
ckpt_path=maybe_get_ckpt(join(ckpt_dir, suffix)))
train_chkpt = maybe_get_ckpt(join(ckpt_dir, "train"))
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if cfg.pretrain_steps > 0 and train_chkpt is None:
print("---"*10)
print("Setup with full_train = False")
run_exp(aligner, full_train=False)
print("---"*10)
else:
print("---"*10)
print("Setup with full_train = False")
run_exp(aligner, full_train=True)
print("---"*10)
if __name__ == "__main__":
my_app()