import logging
import os
from collections import defaultdict

import matplotlib
import numpy as np
import soundfile as sf
import torch
from tensorboardX import SummaryWriter
from tqdm import tqdm

from utils.tools import save_checkpoint, load_checkpoint

# set to avoid matplotlib error in CLI environment
matplotlib.use("Agg")


class Trainer(object):
    """Customized trainer module for FastSVC training."""

    def __init__(
        self,
        steps,
        epochs,
        data_loader,
        sampler,
        model,
        criterion,
        optimizer,
        scheduler,
        config,
        device=torch.device("cpu"),
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
            sampler (dict): Dict of samplers. It must contain the "train" sampler.
            model (dict): Dict of models. It must contain "generator" and "discriminator" models.
            criterion (dict): Dict of criteria. It must contain "stft", "gen_adv", and "dis_adv" criteria.
            optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
            scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
            config: Configuration loaded from a YAML file, accessed attribute-style
                (e.g. config.training_config.train_max_steps).
            device (torch.device): PyTorch device instance.
        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.sampler = sampler
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        tensorboard_dir = os.path.join(config.interval_config.out_dir, "logs")
        os.makedirs(tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(tensorboard_dir)
        self.finish_train = False
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)

    def run(self):
        """Run training."""
        self.tqdm = tqdm(
            initial=self.steps, total=self.config.training_config.train_max_steps, desc="[train]"
        )
        while True:
            # train one epoch
            self._train_epoch()

            # check whether training is finished
            if self.finish_train:
                break

        self.tqdm.close()
        logging.info("Finished training.")

    def _train_step(self, batch):
        """Train model one step."""
        # parse batch
        x, y = batch  # x: (mels, pitch, ld, spk_index), y: audio
        x = tuple([x_.to(self.device) for x_ in x])
        y = y.to(self.device)

        #######################
        #      Generator      #
        #######################
        if self.steps > 0:
            y_ = self.model["generator"](*x)

            # initialize
            gen_loss = 0.0

            # multi-resolution stft loss
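            # ("stft" is assumed to be the usual multi-resolution STFT loss; per
            # resolution, sc_loss = || |STFT(y)| - |STFT(y_)| ||_F / || |STFT(y)| ||_F
            # and mag_loss = (1/N) || log|STFT(y)| - log|STFT(y_)| ||_1, averaged
            # over all resolutions.)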
            sc_loss, mag_loss = self.criterion["stft"](y_, y)
            gen_loss += sc_loss + mag_loss
            self.total_train_loss[
                "train/spectral_convergence_loss"
            ] += sc_loss.item()
            self.total_train_loss[
                "train/log_stft_magnitude_loss"
            ] += mag_loss.item()

            # weighting aux loss
            gen_loss *= self.config.loss_config.lambda_aux

            # adversarial loss
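            # The adversarial term (and the discriminator update below) only
            # kicks in after discriminator_train_start_steps steps; until then
            # the generator is trained on the weighted STFT loss alone.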
            if self.steps > self.config.training_config.discriminator_train_start_steps:
                p_ = self.model["discriminator"](y_.unsqueeze(1))
                adv_loss = self.criterion["gen_adv"](p_)
                self.total_train_loss["train/adversarial_loss"] += adv_loss.item()

                # add adversarial loss to generator loss
                gen_loss += self.config.loss_config.lambda_adv * adv_loss

            self.total_train_loss["train/generator_loss"] += gen_loss.item()

            # update generator
            self.optimizer["generator"].zero_grad()
            self.optimizer["discriminator"].zero_grad()
            gen_loss.backward()
            if self.config.training_config.generator_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model["generator"].parameters(),
                    self.config.training_config.generator_grad_norm,
                )
            self.optimizer["generator"].step()
            self.scheduler["generator"].step()

        #######################
        #    Discriminator    #
        #######################
        if self.steps > self.config.training_config.discriminator_train_start_steps:
            # re-compute y_, which leads to better quality
            with torch.no_grad():
                y_ = self.model["generator"](*x)

            # discriminator loss
            p = self.model["discriminator"](y.unsqueeze(1))
            p_ = self.model["discriminator"](y_.unsqueeze(1).detach())
            real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
            dis_loss = real_loss + fake_loss
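            # (With a least-squares criterion, an assumption about "dis_adv",
            # real_loss is roughly E[(D(y) - 1)^2] and fake_loss E[D(y_)^2].)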
| self.total_train_loss["train/real_loss"] += real_loss.item() | |
| self.total_train_loss["train/fake_loss"] += fake_loss.item() | |
| self.total_train_loss["train/discriminator_loss"] += dis_loss.item() | |
| # update discriminator | |
| self.optimizer["discriminator"].zero_grad() | |
| dis_loss.backward() | |
| if self.config.training_config.discriminator_grad_norm > 0: | |
| torch.nn.utils.clip_grad_norm_( | |
| self.model["discriminator"].parameters(), | |
| self.config.training_config.discriminator_grad_norm, | |
| ) | |
| self.optimizer["discriminator"].step() | |
| self.scheduler["discriminator"].step() | |
| # update counts | |
| self.steps += 1 | |
| self.tqdm.update(1) | |
| self._check_train_finish() | |

    def _train_epoch(self):
        """Train model one epoch."""
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)

            # check interval
            if self.config.training_config.rank == 0:
                self._check_log_interval()
                self._check_eval_interval()
                self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch)."
        )

        # needed for shuffle in distributed training
        if self.config.training_config.distributed:
            self.sampler["train"].set_epoch(self.epochs)

    def _eval_step(self, batch):
        """Evaluate model one step."""
        # parse batch
        x, y = batch
        x = tuple([x_.to(self.device) for x_ in x])
        y = y.to(self.device)

        #######################
        #      Generator      #
        #######################
        y_ = self.model["generator"](*x)

        # initialize
        aux_loss = 0.0

        # multi-resolution stft loss
        sc_loss, mag_loss = self.criterion["stft"](y_, y)
        aux_loss += sc_loss + mag_loss
        self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item()
        self.total_eval_loss["eval/log_stft_magnitude_loss"] += mag_loss.item()

        # weighting stft loss
        aux_loss *= self.config.loss_config.lambda_aux

        # adversarial loss
        p_ = self.model["discriminator"](y_.unsqueeze(1))
        adv_loss = self.criterion["gen_adv"](p_)
        gen_loss = aux_loss + self.config.loss_config.lambda_adv * adv_loss

        #######################
        #    Discriminator    #
        #######################
        p = self.model["discriminator"](y.unsqueeze(1))
        p_ = self.model["discriminator"](y_.unsqueeze(1))

        # discriminator loss
        real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
        dis_loss = real_loss + fake_loss

        # add to total eval loss
        self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
        self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
        self.total_eval_loss["eval/real_loss"] += real_loss.item()
        self.total_eval_loss["eval/fake_loss"] += fake_loss.item()
        self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        for key in self.model.keys():
            self.model[key].eval()

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.data_loader["dev"], desc="[eval]"), 1
        ):
            # eval one step
            self._eval_step(batch)

            # save intermediate result
            if eval_steps_per_epoch == 1:
                self._generate_and_save_intermediate_result(batch)

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= eval_steps_per_epoch
            logging.info(
                f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
            )

        # record
        self._write_to_tensorboard(self.total_eval_loss)

        # reset
        self.total_eval_loss = defaultdict(float)

        # restore mode
        for key in self.model.keys():
            self.model[key].train()

    @torch.no_grad()  # inference only; also avoids calling .numpy() on tensors that require grad
    def _generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        # delayed import to avoid matplotlib backend errors in CLI environments
        import matplotlib.pyplot as plt

        # generate
        x_batch, y_batch = batch
        x_batch = tuple([x.to(self.device) for x in x_batch])
        y_batch = y_batch.to(self.device)
        y_batch_ = self.model["generator"](*x_batch)

        # check directory
        dirname = os.path.join(self.config.interval_config.out_dir, f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 1):
            # convert to ndarray
            y, y_ = y.view(-1).cpu().numpy(), y_.view(-1).cpu().numpy()

            # plot figure and save it
            figname = os.path.join(dirname, f"{idx}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # save as wavfile
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            sf.write(
                figname.replace(".png", "_ref.wav"),
                y,
                self.config.data_config.sampling_rate,
                "PCM_16",
            )
            sf.write(
                figname.replace(".png", "_gen.wav"),
                y_,
                self.config.data_config.sampling_rate,
                "PCM_16",
            )

            if idx >= self.config.interval_config.num_save_intermediate_results:
                break

    def _write_to_tensorboard(self, loss):
        """Write to tensorboard."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        if self.steps % self.config.interval_config.save_interval_steps == 0:
            self.save_checkpoint(
                os.path.join(self.config.interval_config.out_dir, f"checkpoint-{self.steps}steps.pkl"),
                self.config.training_config.distributed,
            )
            logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        if self.steps % self.config.interval_config.eval_interval_steps == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        if self.steps % self.config.interval_config.log_interval_steps == 0:
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config.interval_config.log_interval_steps
                logging.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        if self.steps >= self.config.training_config.train_max_steps:
            self.finish_train = True

    def load_checkpoint(self, cp_path, load_only_params, dst_train):
        self.steps, self.epochs = load_checkpoint(
            model=self.model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            checkpoint_path=cp_path,
            load_only_params=load_only_params,
            dst_train=dst_train,
        )

    def save_checkpoint(self, cp_path, dst_train):
        save_checkpoint(
            steps=self.steps,
            epochs=self.epochs,
            model=self.model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            checkpoint_path=cp_path,
            dst_train=dst_train,
        )
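

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): build_trainer is a hypothetical helper, and
# every object wired in below (loaders, sampler, models, criteria, optimizers,
# schedulers, attribute-style config) is assumed to be constructed elsewhere.
# ---------------------------------------------------------------------------
def build_trainer(
    train_loader,
    dev_loader,
    train_sampler,
    generator,
    discriminator,
    stft_criterion,
    gen_adv_criterion,
    dis_adv_criterion,
    gen_optimizer,
    dis_optimizer,
    gen_scheduler,
    dis_scheduler,
    config,
    device=torch.device("cpu"),
):
    """Assemble the dict arguments Trainer expects and return a Trainer."""
    return Trainer(
        steps=0,
        epochs=0,
        data_loader={"train": train_loader, "dev": dev_loader},
        sampler={"train": train_sampler},
        model={"generator": generator.to(device), "discriminator": discriminator.to(device)},
        criterion={"stft": stft_criterion, "gen_adv": gen_adv_criterion, "dis_adv": dis_adv_criterion},
        optimizer={"generator": gen_optimizer, "discriminator": dis_optimizer},
        scheduler={"generator": gen_scheduler, "discriminator": dis_scheduler},
        config=config,
        device=device,
    )


# trainer = build_trainer(...)  # after constructing the components above
# trainer.load_checkpoint(cp_path, load_only_params, dst_train)  # optional resume from a checkpoint
# trainer.run()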