# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel

from optimizer.optimizers import Eve, ScaledAdam
from schedulers.scheduler import NoamScheduler, Eden
from models.tts.valle.valle_dataset import VALLEDataset, VALLECollator
from models.tts.base import TTSTrainer
from models.tts.valle.valle import VALLE


class VALLETrainer(TTSTrainer):
    def __init__(self, args, cfg):
        TTSTrainer.__init__(self, args, cfg)

    def _build_model(self):
        model = VALLE(self.cfg.model)
        return model

    def _build_dataset(self):
        return VALLEDataset, VALLECollator
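    # Build the optimizer. When `--train_stage` is non-zero, only the parameters
    # of that stage (1: AR decoder, 2: NAR decoder) are optimized; otherwise all
    # model parameters are used. The optimizer type comes from cfg.train.optimizer.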
    def _build_optimizer(self):
        if self.args.train_stage:
            if isinstance(self.model, DistributedDataParallel):
                model = self.model.module
            else:
                model = self.model
            model_parameters = model.stage_parameters(self.args.train_stage)
        else:
            model_parameters = self.model.parameters()

        if self.cfg.train.optimizer == "ScaledAdam":
            parameters_names = []
            if self.args.train_stage != 0:
                parameters_names.append(
                    [
                        name_param_pair[0]
                        for name_param_pair in model.stage_named_parameters(
                            self.args.train_stage
                        )
                    ]
                )
            else:
                # Use self.model here: `model` is only bound when a specific
                # training stage is selected above.
                parameters_names.append(
                    [
                        name_param_pair[0]
                        for name_param_pair in self.model.named_parameters()
                    ]
                )
            optimizer = ScaledAdam(
                model_parameters,
                lr=self.cfg.train.base_lr,
                betas=(0.9, 0.95),
                clipping_scale=2.0,
                parameters_names=parameters_names,
                show_dominant_parameters=False,
                clipping_update_period=1000,
            )
        elif self.cfg.train.optimizer == "Eve":
            optimizer = Eve(
                model_parameters,
                lr=self.cfg.train.base_lr,
                betas=(0.9, 0.98),
                target_rms=0.1,
            )
        elif self.cfg.train.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(
                model_parameters,
                lr=self.cfg.train.base_lr,
                betas=(0.9, 0.95),
                weight_decay=1e-2,
                eps=1e-8,
            )
        elif self.cfg.train.optimizer == "Adam":
            optimizer = torch.optim.Adam(
                model_parameters,
                lr=self.cfg.train.base_lr,
                betas=(0.9, 0.95),
                eps=1e-8,
            )
        else:
            raise NotImplementedError()

        return optimizer
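    # Build the learning-rate scheduler named in cfg.train.scheduler: "eden" and
    # "noam" come from this repository's schedulers module, while "cosine" uses
    # PyTorch's CosineAnnealingLR.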
    def _build_scheduler(self):
        if self.cfg.train.scheduler.lower() == "eden":
            scheduler = Eden(
                self.optimizer, 5000, 4, warmup_batches=self.cfg.train.warmup_steps
            )
        elif self.cfg.train.scheduler.lower() == "noam":
            scheduler = NoamScheduler(
                self.cfg.train.base_lr,
                self.optimizer,
                self.cfg.model.decoder_dim,
                warmup_steps=self.cfg.train.warmup_steps,
            )
        elif self.cfg.train.scheduler.lower() == "cosine":
            # CosineAnnealingLR takes the optimizer first, then T_max.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.cfg.train.warmup_steps,
                eta_min=self.cfg.train.base_lr,
            )
        else:
            raise NotImplementedError(f"{self.cfg.train.scheduler}")

        return scheduler
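    # One pass over the training dataloader. Gradient accumulation is delegated
    # to accelerate's `accumulate` context; the scheduler is stepped
    # `gradient_accumulation_step` times per optimizer update so its internal
    # count tracks micro-batches, and per-step losses are logged through the
    # accelerator.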
    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].train()
        else:
            self.model.train()

        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        epoch_step: int = 0
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses = self._train_step(batch)
                self.accelerator.backward(total_loss)
                self.optimizer.step()
                self.optimizer.zero_grad()
            self.batch_count += 1

            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                if self.cfg.train.optimizer not in ["ScaledAdam", "Eve"]:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                for k in range(self.cfg.train.gradient_accumulation_step):
                    if isinstance(self.scheduler, Eden):
                        self.scheduler.step_batch(self.step)
                    else:
                        self.scheduler.step()

                epoch_sum_loss += total_loss.detach().cpu().item()
                if isinstance(train_losses, dict):
                    for key, value in train_losses.items():
                        if key not in epoch_losses.keys():
                            epoch_losses[key] = value
                        else:
                            epoch_losses[key] += value

                if isinstance(train_losses, dict):
                    for key, loss in train_losses.items():
                        self.accelerator.log(
                            {"Step/Train {}".format(key): "{:.6f}".format(loss)},
                            step=self.step,
                        )
                else:
                    # `train_losses` is not a dict here, so log the total loss.
                    self.accelerator.log(
                        {"Step/Train Loss": total_loss.detach().cpu().item()},
                        step=self.step,
                    )
                self.accelerator.log(
                    {"Step/lr": self.scheduler.get_last_lr()[0]},
                    step=self.step,
                )

                # print loss every log_epoch_step steps
                # if epoch_step % self.cfg.train.log_epoch_step == 0:
                #     for key, loss in train_losses.items():
                #         self.logger.info("Step/Train {}: {:.6f}".format(key, loss))
                #         print("Step/Train {}: {:.6f}".format(key, loss))

                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()

        epoch_sum_loss = (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )
        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )

        return epoch_sum_loss, epoch_losses
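    # Single forward pass shared by training and validation. The model returns a
    # summed loss plus a dict of per-term losses; both are normalized by the
    # total number of acoustic frames in the batch before being reported.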
    def _train_step(self, batch, is_training=True):
        text_tokens = batch["phone_seq"].to(self.device)
        text_tokens_lens = batch["phone_len"].to(self.device)
        assert text_tokens.ndim == 2

        audio_features = batch["acoustic_token"].to(self.device)
        audio_features_lens = batch["target_len"].to(self.device)
        assert audio_features.ndim == 3

        with torch.set_grad_enabled(is_training):
            loss, losses = self.model(
                x=text_tokens,
                x_lens=text_tokens_lens,
                y=audio_features,
                y_lens=audio_features_lens,
                train_stage=self.args.train_stage,
            )

        assert loss.requires_grad == is_training

        loss_dict = {}
        frames_sum = audio_features_lens.sum()

        avg_loss = loss / frames_sum
        loss_dict["loss"] = avg_loss.detach().cpu().item()
        for l in losses:
            loss_dict[l] = losses[l].detach().cpu().item() / frames_sum.item()

        return avg_loss, loss_dict
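    # Validation reuses `_train_step` with gradients disabled and returns plain
    # Python numbers for logging.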
    def _valid_step(self, batch):
        valid_losses = {}
        total_loss = 0
        valid_stats = {}

        total_loss, valid_losses = self._train_step(
            batch=batch,
            is_training=False,
        )
        assert total_loss.requires_grad is False

        total_loss = total_loss.detach().cpu().item()

        return total_loss, valid_losses, valid_stats
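# Command-line arguments specific to VALL-E training, intended to be called on
# the entry script's argument parser.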
def add_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--train_stage",
        type=int,
        default=1,
        help="0: train all modules, 1: AR Decoder, 2: NAR Decoder",
    )
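# A minimal usage sketch (hypothetical; the real entry point is Amphion's
# training scripts, `cfg` is assumed to be an already-loaded config object with
# the fields referenced above, and `args` needs whatever extra fields
# TTSTrainer.__init__ expects):
#
#     parser = argparse.ArgumentParser()
#     add_arguments(parser)
#     args = parser.parse_args()
#     trainer = VALLETrainer(args, cfg)
#     trainer.train_loop()  # training entry point referenced in the docstring above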