Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| import os | |
| import json5 | |
| from collections import OrderedDict | |
| from tqdm import tqdm | |
| import json | |
| import shutil | |
| from models.svc.base import SVCTrainer | |
| from modules.encoder.condition_encoder import ConditionEncoder | |
| from models.svc.comosvc.comosvc import ComoSVC | |
class ComoSVCTrainer(SVCTrainer):
    r"""Trainer for the ComoSVC acoustic model.

    Inherits from ``SVCTrainer``, implements ``_build_model`` and
    ``_forward_step``, and overrides ``_train_epoch``, ``train_loop`` and
    ``_valid_epoch``. When ``cfg.model.comosvc.distill`` is true, the weights
    at ``cfg.model.teacher_model_path`` are loaded into the model and
    ``init_consistency_training()`` is called on the ComoSVC decoder.
    """

    def __init__(self, args=None, cfg=None):
        SVCTrainer.__init__(self, args, cfg)
        self.distill = cfg.model.comosvc.distill
        # Recomputed on every ``_forward_step``; ``train_loop`` also reads it
        # to decide whether a (non-distillation) checkpoint may be saved.
        self.skip_diff = True
        # NOTE: the teacher is loaded whenever distillation is enabled,
        # including on resumed runs (no resume check is performed here).
        if self.distill:
            self.teacher_model_path = cfg.model.teacher_model_path
            self.teacher_state_dict = self._load_teacher_state_dict()
            self._load_teacher_model(self.teacher_state_dict)
            self.acoustic_mapper.decoder.init_consistency_training()

    ### Following are methods only for comoSVC models ###
    def _load_teacher_state_dict(self):
        r"""Load the teacher checkpoint from ``self.teacher_model_path``.

        Also records the path in ``self.checkpoint_file``.

        Returns:
            Whatever ``torch.load`` returns for that file (expected to be a
            state dict — TODO confirm for the checkpoints in use).
        """
        self.checkpoint_file = self.teacher_model_path
        print("Load teacher acoustic model from {}".format(self.checkpoint_file))
        # Load onto the CPU: ``load_state_dict`` copies each tensor into the
        # model's existing parameters, so the load device does not matter,
        # and this avoids failures when the saving device is unavailable.
        raw_state_dict = torch.load(self.checkpoint_file, map_location="cpu")
        return raw_state_dict

    def _load_teacher_model(self, state_dict):
        r"""Load ``state_dict`` into ``self.model``.

        A leading ``"module."`` prefix is stripped from every key, so
        checkpoints saved from a wrapped model (keys like ``module.xxx``)
        load into the unwrapped key names.

        Args:
            state_dict: mapping from parameter names to tensors.
        """
        prefix = "module."
        clean_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        self.model.load_state_dict(clean_dict)

    def _build_model(self):
        r"""Build the model for training. This function is called in ``__init__`` function.

        Returns:
            ``torch.nn.ModuleList([condition_encoder, acoustic_mapper])``; both
            sub-modules are also stored as attributes on ``self``.
        """
        # TODO: sort out the config
        # The condition encoder takes its f0 range from the preprocess config.
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = ComoSVC(self.cfg)
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
        return model

    def _forward_step(self, batch):
        r"""Forward step for training and inference. This function is called
        in ``_train_step`` & ``_test_step`` function.

        Args:
            batch: dict with at least ``"mask"`` and ``"mel"`` entries, and
                whatever ``ConditionEncoder`` reads.

        Returns:
            A dict of losses: ``{"distil_loss"}`` in distillation mode,
            otherwise ``{"ssim_loss_encoder", "prior_loss_encoder",
            "diffusion_loss_decoder"}``.
        """
        loss = {}
        mask = batch["mask"]
        mel_input = batch["mel"]
        cond = self.condition_encoder(batch)
        if self.distill:
            # Block gradients into the condition encoder while distilling.
            cond = cond.detach()
        # ``skip_diff`` is true for the first ``cfg.train.fast_steps`` steps.
        self.skip_diff = self.step < self.cfg.train.fast_steps
        ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss(
            mask, cond, mel_input, skip_diff=self.skip_diff
        )
        if self.distill:
            loss["distil_loss"] = diff_loss
        else:
            loss["ssim_loss_encoder"] = ssim_loss
            loss["prior_loss_encoder"] = prior_loss
            loss["diffusion_loss_decoder"] = diff_loss

        return loss

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Returns:
            A tuple ``(avg_loss, loss)``. ``avg_loss`` is a Python float: the
            sum of the total losses recorded on every
            ``gradient_accumulation_step``-th batch, divided by
            ``len(train_dataloader)`` and multiplied by
            ``gradient_accumulation_step``. ``loss`` is the loss dict of the
            last batch.
        """
        self.model.train()
        epoch_sum_loss: float = 0.0
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                loss = self._train_step(batch)
                total_loss = sum(loss.values())
                self.accelerator.backward(total_loss)
                # Encoder and decoder gradients are clipped separately,
                # each to a max norm of 1.
                torch.nn.utils.clip_grad_norm_(
                    self.acoustic_mapper.encoder.parameters(), max_norm=1
                )
                torch.nn.utils.clip_grad_norm_(
                    self.acoustic_mapper.decoder.parameters(), max_norm=1
                )
                self.optimizer.step()
                self.optimizer.zero_grad()
            self.batch_count += 1

            # Update info for each step
            # TODO: step means BP counts or batch counts?
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                # Accumulate a plain float so the running sum does not keep
                # each step's autograd graph referenced for the whole epoch.
                epoch_sum_loss += float(total_loss)
                log_info = {
                    "Step/Train Loss/{}".format(k): v for k, v in loss.items()
                }
                log_info["Step/Learning Rate"] = self.optimizer.param_groups[0]["lr"]
                self.accelerator.log(
                    log_info,
                    step=self.step,
                )
                self.step += 1
        self.accelerator.wait_for_everyone()
        return (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step,
            loss,
        )

    def train_loop(self):
        r"""Training loop. The public entry of training process.

        Dumps the config (main process only), then runs
        ``_train_epoch``/``_valid_epoch`` until ``self.max_epoch``. Each epoch
        it logs the losses, steps the scheduler with the validation loss,
        saves a checkpoint when a ``save_checkpoint_stride`` entry divides the
        epoch, and prunes checkpoints beyond ``keep_last``. A ``final_...``
        state is saved at the end.
        """
        # Wait everyone to prepare before we move on
        self.accelerator.wait_for_everyone()
        # dump config file
        if self.accelerator.is_main_process:
            self.__dump_cfg(self.config_save_path)
        self.model.train()
        self.optimizer.zero_grad()
        # Wait to ensure good to go
        self.accelerator.wait_for_everyone()

        # Bound up front so the final checkpoint name can be formatted even
        # when no epoch runs (e.g. resuming with epoch >= max_epoch).
        valid_loss = float("nan")
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
            ### It's inconvenient for the model with multiple losses
            # Do training & validating epoch
            train_loss, loss = self._train_epoch()
            self.logger.info("  |- Train/Loss: {:.6f}".format(train_loss))
            for k, v in loss.items():
                self.logger.info("  |- Train/Loss/{}: {:.6f}".format(k, v))
            valid_loss = self._valid_epoch()
            self.logger.info("  |- Valid/Loss: {:.6f}".format(valid_loss))
            self.accelerator.log(
                {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
                step=self.epoch,
            )

            self.accelerator.wait_for_everyone()
            # TODO: what is scheduler?
            self.scheduler.step(valid_loss)  # FIXME: use epoch track correct?

            # Check if hit save_checkpoint_stride and run_eval
            run_eval = False
            save_checkpoint = False
            hit_dix = []
            if self.accelerator.is_main_process:
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        hit_dix.append(i)
                        run_eval |= self.run_eval[i]

            self.accelerator.wait_for_everyone()
            # Outside distillation, no checkpoint is saved while ``skip_diff``
            # is still true.
            if (
                self.accelerator.is_main_process
                and save_checkpoint
                and (self.distill or not self.skip_diff)
            ):
                path = os.path.join(
                    self.checkpoint_dir,
                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_loss
                    ),
                )
                self.accelerator.save_state(path)
                with open(os.path.join(path, "ckpts.json"), "w") as ckpts_file:
                    json.dump(
                        self.checkpoints_path,
                        ckpts_file,
                        ensure_ascii=False,
                        indent=4,
                    )

                # Record the new checkpoint under every stride it hit and
                # collect entries that exceed that stride's ``keep_last``.
                to_remove = []
                for idx in hit_dix:
                    self.checkpoints_path[idx].append(path)
                    while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
                        to_remove.append((idx, self.checkpoints_path[idx].pop(0)))

                # A checkpoint still listed under another stride must be
                # kept: put it back instead of deleting it.
                still_referenced = set().union(*self.checkpoints_path)
                do_remove = set()
                for idx, old_path in reversed(to_remove):
                    if old_path in still_referenced:
                        self.checkpoints_path[idx].insert(0, old_path)
                    else:
                        do_remove.add(old_path)

                # Remove old checkpoints
                for old_path in do_remove:
                    shutil.rmtree(old_path, ignore_errors=True)
                    self.logger.debug(f"Remove old checkpoint: {old_path}")

            self.accelerator.wait_for_everyone()
            if run_eval:
                # TODO: run evaluation
                pass

            # Update info for each epoch
            self.epoch += 1

        # Finish training and save final checkpoint
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self.accelerator.save_state(
                os.path.join(
                    self.checkpoint_dir,
                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, valid_loss
                    ),
                )
            )
        self.accelerator.end_training()

    def _valid_epoch(self):
        r"""Testing epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Returns:
            The sum of every loss value returned by ``_valid_step`` over the
            epoch, divided by ``len(self.valid_dataloader)``.
        """
        self.model.eval()
        epoch_sum_loss = 0.0
        for batch in tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            batch_loss = self._valid_step(batch)
            for value in batch_loss.values():
                epoch_sum_loss += value

        self.accelerator.wait_for_everyone()
        return epoch_sum_loss / len(self.valid_dataloader)

    @staticmethod
    def __count_parameters(model):
        r"""Return the number of parameters of ``model``.

        Args:
            model: a module exposing ``parameters()``, or a dict whose values
                are such modules (their counts are summed).
        """
        if isinstance(model, dict):
            return sum(
                p.numel() for module in model.values() for p in module.parameters()
            )
        return sum(p.numel() for p in model.parameters())

    def __dump_cfg(self, path):
        r"""Write ``self.cfg`` to ``path`` as JSON5, creating its parent
        directory if one is given and missing."""
        dirname = os.path.dirname(path)
        # ``os.makedirs("")`` raises, so only create a directory when the
        # path actually has one.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(path, "w") as cfg_file:
            json5.dump(
                self.cfg,
                cfg_file,
                indent=4,
                sort_keys=True,
                ensure_ascii=False,
                quote_keys=True,
            )