# Commit b5eac81 by LPDoctor: "Add ThinkSound module files to repository"
import torch
import torchaudio
import wandb
from einops import rearrange
from safetensors.torch import save_file, save_model
from ema_pytorch import EMA
from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss, SpatialSTFTLoss
# import pytorch_lightning as pl
import lightning as L
from lightning.pytorch.callbacks import Callback
from ..models.autoencoders import AudioAutoencoder
from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
from .utils import create_optimizer_from_config, create_scheduler_from_config
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
class AutoencoderTrainingWrapper(L.LightningModule):
    """Lightning module for adversarial training of an ``AudioAutoencoder``.

    Optimization is manual (``automatic_optimization = False``). Once
    ``global_step`` reaches ``warmup_steps``, odd global steps update the
    discriminator and even global steps update the autoencoder. Before that,
    only the autoencoder is updated and the adversarial/feature-matching terms
    are zero.

    Optional features: an EMA copy of the autoencoder, mono down-mixing of the
    encoder input, random zero-masking of latents, and distillation against a
    teacher autoencoder whose parameters are not handed to any optimizer here.

    NOTE(review): ``OobleckDiscriminator``, ``EncodecDiscriminator`` and
    ``DACGANLoss`` are referenced below but are not imported in this file's
    visible import block; confirm they are in scope before training.
    """

    def __init__(
        self,
        autoencoder: AudioAutoencoder,
        lr: float = 1e-4,
        warmup_steps: int = 0,
        encoder_freeze_on_warmup: bool = False,
        sample_rate=48000,
        loss_config: dict = None,
        optimizer_configs: dict = None,
        use_ema: bool = True,
        ema_copy = None,
        force_input_mono = False,
        latent_mask_ratio = 0.0,
        teacher_model: AudioAutoencoder = None
    ):
        """Build losses, the discriminator and (optionally) the EMA wrapper.

        Args:
            autoencoder: The model being trained.
            lr: Learning rate used by both AdamW optimizers in the default
                ``optimizer_configs``; also stored as ``self.lr``.
            warmup_steps: Global step from which the discriminator is trained
                and the adversarial losses become active.
            encoder_freeze_on_warmup: If True, the encoder runs under
                ``torch.no_grad()`` once warm-up has ended.
            sample_rate: Passed to the STFT losses and the DAC discriminator.
            loss_config: Dict with ``discriminator``, ``spectral`` and ``time``
                sections (an optional ``bottleneck`` section is read by
                ``create_loss_modules_from_bottleneck``). Defaults to an
                Encodec discriminator plus a perceptually weighted
                multi-resolution STFT loss, with the L1 time loss disabled.
            optimizer_configs: Dict with ``autoencoder`` and ``discriminator``
                entries, each holding an ``optimizer`` config and optionally a
                ``scheduler`` config. Defaults to AdamW with betas (0.8, 0.99).
            use_ema: Keep an EMA copy of the autoencoder (beta 0.9999).
            ema_copy: Optional pre-built model passed to ``EMA`` as
                ``ema_model``.
            force_input_mono: Average multi-channel input to a single channel
                before encoding; reconstruction targets keep all channels.
            latent_mask_ratio: Probability with which each latent element is
                zeroed before decoding during training.
            teacher_model: Optional autoencoder to distill from.

        Raises:
            ValueError: If ``loss_config['discriminator']['type']`` is not one
                of ``'oobleck'``, ``'encodec'`` or ``'dac'``.
        """
        super().__init__()

        # Generator and discriminator are stepped by hand in training_step.
        self.automatic_optimization = False

        self.autoencoder = autoencoder

        self.warmed_up = False
        self.warmup_steps = warmup_steps
        self.encoder_freeze_on_warmup = encoder_freeze_on_warmup
        self.lr = lr

        self.force_input_mono = force_input_mono

        self.teacher_model = teacher_model

        if optimizer_configs is None:
            optimizer_configs ={
                "autoencoder": {
                    "optimizer": {
                        "type": "AdamW",
                        "config": {
                            "lr": lr,
                            "betas": (.8, .99)
                        }
                    }
                },
                "discriminator": {
                    "optimizer": {
                        "type": "AdamW",
                        "config": {
                            "lr": lr,
                            "betas": (.8, .99)
                        }
                    }
                }
            }

        self.optimizer_configs = optimizer_configs

        if loss_config is None:
            # Default STFT resolutions with 75% window overlap.
            scales = [2048, 1024, 512, 256, 128, 64, 32]
            hop_sizes = []
            win_lengths = []
            overlap = 0.75
            for s in scales:
                hop_sizes.append(int(s * (1 - overlap)))
                win_lengths.append(s)

            loss_config = {
                "discriminator": {
                    "type": "encodec",
                    "config": {
                        "n_ffts": scales,
                        "hop_lengths": hop_sizes,
                        "win_lengths": win_lengths,
                        "filters": 32
                    },
                    "weights": {
                        "adversarial": 0.1,
                        "feature_matching": 5.0,
                    }
                },
                "spectral": {
                    "type": "mrstft",
                    "config": {
                        "fft_sizes": scales,
                        "hop_sizes": hop_sizes,
                        "win_lengths": win_lengths,
                        "perceptual_weighting": True
                    },
                    "weights": {
                        "mrstft": 1.0,
                    }
                },
                "time": {
                    "type": "l1",
                    "config": {},
                    "weights": {
                        "l1": 0.0,
                    }
                }
            }

        self.loss_config = loss_config

        # Spectral reconstruction loss. Stereo models get a sum-and-difference
        # STFT loss (self.sdstft) plus a plain multi-resolution loss
        # (self.lrstft) for the per-channel terms; every other channel count,
        # including 4 (SpatialSTFTLoss is imported but not used here), uses the
        # multi-resolution loss only.
        stft_loss_args = loss_config['spectral']['config']

        if self.autoencoder.out_channels == 2:
            self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
            self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
        elif self.autoencoder.out_channels == 4:
            self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
        else:
            self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)

        # Discriminator. Fail fast on an unknown type: otherwise
        # self.discriminator is never assigned and training only breaks later,
        # with an unrelated-looking AttributeError in configure_optimizers.
        disc_type = loss_config['discriminator']['type']

        if disc_type == 'oobleck':
            self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config'])
        elif disc_type == 'encodec':
            self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config'])
        elif disc_type == 'dac':
            self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config'])
        else:
            raise ValueError(
                f"Unknown discriminator type {disc_type!r}; expected 'oobleck', 'encodec' or 'dac'"
            )

        self.gen_loss_modules = []

        # Adversarial and feature matching losses (values produced in training_step)
        self.gen_loss_modules += [
            ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'),
            ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'),
        ]

        if self.teacher_model is not None:
            # Distillation losses, each weighted at a quarter of the spectral weight
            stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25
            self.gen_loss_modules += [
                AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss
                AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder
                AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder
                AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder
            ]
        else:
            # Reconstruction loss
            self.gen_loss_modules += [
                AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
            ]

            if self.autoencoder.out_channels == 2:
                # Add left and right channel reconstruction losses in addition to the sum and difference
                self.gen_loss_modules += [
                    AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2),
                    AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2),
                ]
            elif self.autoencoder.out_channels == 4:
                # Add one reconstruction loss per channel (w, x, y, z), each
                # carrying a quarter of the spectral weight
                self.gen_loss_modules += [
                    AuralossLoss(self.sdstft, 'reals_w', 'decoded_w', name='stft_loss_w', weight=self.loss_config['spectral']['weights']['mrstft']/4),
                    AuralossLoss(self.sdstft, 'reals_x', 'decoded_x', name='stft_loss_x', weight=self.loss_config['spectral']['weights']['mrstft']/4),
                    AuralossLoss(self.sdstft, 'reals_y', 'decoded_y', name='stft_loss_y', weight=self.loss_config['spectral']['weights']['mrstft']/4),
                    AuralossLoss(self.sdstft, 'reals_z', 'decoded_z', name='stft_loss_z', weight=self.loss_config['spectral']['weights']['mrstft']/4),
                ]

            # NOTE(review): this appends a second 'mrstft_loss' identical to the
            # one added above, so the full-signal spectral term is counted
            # twice. Left unchanged to preserve existing loss scaling; confirm
            # whether this is intended.
            self.gen_loss_modules += [
                AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
            ]

        if self.loss_config['time']['weights']['l1'] > 0.0:
            self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss'))

        if self.autoencoder.bottleneck is not None:
            self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config)

        self.losses_gen = MultiLoss(self.gen_loss_modules)

        self.disc_loss_modules = [
            ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'),
        ]

        self.losses_disc = MultiLoss(self.disc_loss_modules)

        # Set up EMA for model weights
        self.autoencoder_ema = None

        self.use_ema = use_ema

        if self.use_ema:
            self.autoencoder_ema = EMA(
                self.autoencoder,
                ema_model=ema_copy,
                beta=0.9999,
                power=3/4,
                update_every=1,
                update_after_step=1
            )

        self.latent_mask_ratio = latent_mask_ratio

    def configure_optimizers(self):
        """Create the generator and discriminator optimizers.

        Returns:
            ``[opt_gen, opt_disc]``, or ``([opt_gen, opt_disc], [sched_gen,
            sched_disc])`` when BOTH optimizer configs contain a ``scheduler``
            entry. A scheduler on only one side is ignored.
        """
        opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters())
        opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters())

        if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']:
            sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen)
            sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc)
            return [opt_gen, opt_disc], [sched_gen, sched_disc]

        return [opt_gen, opt_disc]

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step on the generator or discriminator.

        Args:
            batch: A 2-tuple whose first element is the audio tensor (indexed
                below as ``[batch, channel, time]``); the second is ignored.
            batch_idx: Unused; required by the Lightning signature.

        Returns:
            The loss that was back-propagated on this step.
        """
        reals, _ = batch

        # Remove extra dimension added by WebDataset
        if reals.ndim == 4 and reals.shape[0] == 1:
            reals = reals[0]

        if self.global_step >= self.warmup_steps:
            self.warmed_up = True

        loss_info = {}

        loss_info["reals"] = reals

        encoder_input = reals

        if self.force_input_mono and encoder_input.shape[1] > 1:
            encoder_input = encoder_input.mean(dim=1, keepdim=True)

        loss_info["encoder_input"] = encoder_input

        data_std = encoder_input.std()

        if self.warmed_up and self.encoder_freeze_on_warmup:
            with torch.no_grad():
                latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
        else:
            latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)

        # Stored before masking; bottleneck losses read keys from encoder_info.
        loss_info["latents"] = latents

        loss_info.update(encoder_info)

        # Encode with teacher model for distillation
        if self.teacher_model is not None:
            with torch.no_grad():
                teacher_latents = self.teacher_model.encode(encoder_input, return_info=False)
                loss_info['teacher_latents'] = teacher_latents

        # Optionally mask out some latents for noise resistance
        if self.latent_mask_ratio > 0.0:
            mask = torch.rand_like(latents) < self.latent_mask_ratio
            latents = torch.where(mask, torch.zeros_like(latents), latents)

        decoded = self.autoencoder.decode(latents)

        loss_info["decoded"] = decoded

        # Per-channel views consumed by the per-channel STFT losses.
        if self.autoencoder.out_channels == 2:
            loss_info["decoded_left"] = decoded[:, 0:1, :]
            loss_info["decoded_right"] = decoded[:, 1:2, :]
            loss_info["reals_left"] = reals[:, 0:1, :]
            loss_info["reals_right"] = reals[:, 1:2, :]
        elif self.autoencoder.out_channels == 4:
            loss_info["decoded_w"] = decoded[:, 0:1, :]
            loss_info["decoded_x"] = decoded[:, 1:2, :]
            loss_info["decoded_y"] = decoded[:, 2:3, :]
            loss_info["decoded_z"] = decoded[:, 3:4, :]
            loss_info["reals_w"] = reals[:, 0:1, :]
            loss_info["reals_x"] = reals[:, 1:2, :]
            loss_info["reals_y"] = reals[:, 2:3, :]
            loss_info["reals_z"] = reals[:, 3:4, :]

        # Distillation
        if self.teacher_model is not None:
            # NOTE(review): all three decodes run under no_grad, so the
            # 'mrstft_loss_own_latents_teacher' and
            # 'mrstft_loss_teacher_latents_own' terms add to the loss value but
            # contribute no gradient to the autoencoder. Confirm this is intended.
            with torch.no_grad():
                teacher_decoded = self.teacher_model.decode(teacher_latents)
                own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher
                teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model

            loss_info['teacher_decoded'] = teacher_decoded
            loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded
            loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded

        # Adversarial terms are only computed after warm-up; zeros before.
        if self.warmed_up:
            loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded)
        else:
            loss_dis = torch.tensor(0.).to(reals)
            loss_adv = torch.tensor(0.).to(reals)
            feature_matching_distance = torch.tensor(0.).to(reals)

        loss_info["loss_dis"] = loss_dis
        loss_info["loss_adv"] = loss_adv
        loss_info["feature_matching_distance"] = feature_matching_distance

        opt_gen, opt_disc = self.optimizers()

        lr_schedulers = self.lr_schedulers()

        sched_gen = None
        sched_disc = None

        if lr_schedulers is not None:
            sched_gen, sched_disc = lr_schedulers

        if self.global_step % 2 and self.warmed_up:
            # Train the discriminator (odd global steps after warm-up)
            loss, losses = self.losses_disc(loss_info)

            log_dict = {
                'train/disc_lr': opt_disc.param_groups[0]['lr']
            }

            opt_disc.zero_grad()
            self.manual_backward(loss)
            opt_disc.step()

            if sched_disc is not None:
                # sched step every step
                sched_disc.step()
        else:
            # Train the generator
            loss, losses = self.losses_gen(loss_info)

            # EMA is updated before this step's optimizer update.
            if self.use_ema:
                self.autoencoder_ema.update()

            opt_gen.zero_grad()
            self.manual_backward(loss)
            opt_gen.step()

            if sched_gen is not None:
                # scheduler step every step
                sched_gen.step()

            log_dict = {
                'train/loss': loss.detach(),
                'train/latent_std': latents.std().detach(),
                'train/data_std': data_std.detach(),
                'train/gen_lr': opt_gen.param_groups[0]['lr']
            }

        for loss_name, loss_value in losses.items():
            log_dict[f'train/{loss_name}'] = loss_value.detach()

        self.log_dict(log_dict, prog_bar=True, on_step=True)

        return loss

    def export_model(self, path, use_safetensors=False):
        """Save the autoencoder weights to ``path``.

        The EMA weights are saved when EMA is enabled, otherwise the live
        weights. With ``use_safetensors`` the model is written with
        ``safetensors.torch.save_model``; otherwise a ``{"state_dict": ...}``
        dict is written with ``torch.save``.
        """
        if self.autoencoder_ema is not None:
            model = self.autoencoder_ema.ema_model
        else:
            model = self.autoencoder

        if use_safetensors:
            save_model(model, path)
        else:
            torch.save({"state_dict": model.state_dict()}, path)
class AutoencoderDemoCallback(Callback):
    """Periodically reconstruct a demo batch and log audio and latent plots.

    On rank zero, whenever ``(global_step - 1)`` is a multiple of
    ``demo_every``, a batch is drawn from ``demo_dl`` and encoded/decoded. The
    EMA copy is used when the module has EMA enabled. Reals and reconstructions
    are interleaved into one wav under ``demos/``, which is logged through
    ``trainer.logger.experiment.log`` (presumably a wandb run, given the
    ``wandb.Audio``/``wandb.Image`` payloads) together with latent plots.
    """

    def __init__(
        self,
        demo_dl,
        demo_every=2000,
        sample_size=65536,
        sample_rate=48000
    ):
        super().__init__()
        self.demo_every = demo_every
        # Stored for reference; not read by this callback.
        self.demo_samples = sample_size
        # Keep the original iterable so the iterator can be restarted once it
        # is exhausted (a finite DataLoader would otherwise raise StopIteration).
        self._demo_source = demo_dl
        self.demo_dl = iter(demo_dl)
        self.sample_rate = sample_rate
        self.last_demo_step = -1

    def _next_demo_batch(self):
        """Return the next demo batch, restarting the iterator once if exhausted."""
        try:
            return next(self.demo_dl)
        except StopIteration:
            # A one-shot iterator yields the same exhausted iterator here, in
            # which case the StopIteration below still propagates.
            self.demo_dl = iter(self._demo_source)
            return next(self.demo_dl)

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        """Write and log a reconstruction demo on the scheduled steps."""
        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
            return

        self.last_demo_step = trainer.global_step

        module.eval()

        try:
            demo_reals, _ = self._next_demo_batch()

            # Remove extra dimension added by WebDataset
            if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
                demo_reals = demo_reals[0]

            encoder_input = demo_reals.to(module.device)

            if module.force_input_mono:
                encoder_input = encoder_input.mean(dim=1, keepdim=True)

            demo_reals = demo_reals.to(module.device)

            with torch.no_grad():
                model = module.autoencoder_ema.ema_model if module.use_ema else module.autoencoder
                latents = model.encode(encoder_input)
                fakes = model.decode(latents)

            # Interleave reals and fakes
            reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')

            # Put the demos together
            reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')

            log_dict = {}

            filename = f'demos/recon_{trainer.global_step:08}.wav'
            # torchaudio.save does not create missing parent directories.
            import os
            os.makedirs(os.path.dirname(filename), exist_ok=True)

            # Convert to 16-bit PCM for saving.
            reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            torchaudio.save(filename, reals_fakes, self.sample_rate)

            log_dict['recon'] = wandb.Audio(filename,
                                            sample_rate=self.sample_rate,
                                            caption='Reconstructed')

            log_dict['embeddings_3dpca'] = pca_point_cloud(latents)
            log_dict['embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))

            log_dict['recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))

            trainer.logger.experiment.log(log_dict)
        except Exception as e:
            print(f'{type(e).__name__}: {e}')
            raise
        finally:
            module.train()
def create_loss_modules_from_bottleneck(bottleneck, loss_config):
    """Build the extra generator loss terms required by a bottleneck type.

    Args:
        bottleneck: The autoencoder's bottleneck module.
        loss_config: Training loss config. ``loss_config['bottleneck']['weights']``
            may optionally override the KL weight (default ``1e-6``) and the
            MMD weight (default ``100``).

    Returns:
        A (possibly empty) list of ``ValueLoss`` modules keyed on ``kl``,
        ``quantizer_loss``, ``vq/codebook_loss``/``vq/commitment_loss`` or
        ``mmd``, depending on the bottleneck class. A bottleneck matching
        several classes gets all of the corresponding terms.
    """

    def _configured_weight(name, default):
        # Missing sections (KeyError) or a None/non-dict section (TypeError)
        # fall back to the default. Other errors are no longer swallowed.
        try:
            return loss_config['bottleneck']['weights'][name]
        except (KeyError, TypeError):
            return default

    losses = []

    if isinstance(bottleneck, (VAEBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck)):
        kl_weight = _configured_weight('kl', 1e-6)
        losses.append(ValueLoss(key='kl', weight=kl_weight, name='kl_loss'))

    if isinstance(bottleneck, (RVQBottleneck, RVQVAEBottleneck)):
        losses.append(ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss'))

    if isinstance(bottleneck, (DACRVQBottleneck, DACRVQVAEBottleneck)):
        losses.append(ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss'))
        losses.append(ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss'))

    if isinstance(bottleneck, WassersteinBottleneck):
        mmd_weight = _configured_weight('mmd', 100)
        losses.append(ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss'))

    return losses