# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import shutil
from pathlib import Path

import accelerate
import json5
import torch
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm

# from models.svc.base import SVCTrainer
from models.svc.base.svc_dataset import SVCOfflineCollator, SVCOfflineDataset
from models.svc.vits.vits import *
from models.svc.base import SVCTrainer
from utils.mel import mel_spectrogram_torch
from models.vocoders.gan.discriminator.mpd import (
    MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
)


class VitsSVCTrainer(SVCTrainer):
    def __init__(self, args, cfg):
        self.args = args
        self.cfg = cfg
        SVCTrainer.__init__(self, args, cfg)
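
    # A minimal driver sketch (hypothetical names; in Amphion, `args` and `cfg`
    # normally come from the SVC entry script and the parsed experiment config):
    #
    #     trainer = VitsSVCTrainer(args, cfg)
    #     trainer.train_loop()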

    def _accelerator_prepare(self):
        (
            self.train_dataloader,
            self.valid_dataloader,
        ) = self.accelerator.prepare(
            self.train_dataloader,
            self.valid_dataloader,
        )

        # Model, optimizer, and scheduler may each be a dict (separate
        # generator/discriminator entries); prepare every entry individually.
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key] = self.accelerator.prepare(self.model[key])
        else:
            self.model = self.accelerator.prepare(self.model)

        if isinstance(self.optimizer, dict):
            for key in self.optimizer.keys():
                self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if isinstance(self.scheduler, dict):
            for key in self.scheduler.keys():
                self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
        else:
            self.scheduler = self.accelerator.prepare(self.scheduler)

    def _load_model(
        self,
        checkpoint_dir: str = None,
        checkpoint_path: str = None,
        resume_type: str = "",
    ):
        r"""Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path.
        **Only use this method after** ``accelerator.prepare()``.
        """
        if checkpoint_path is None:
            # Checkpoint directories are named "epoch-{:04d}_step-{:07d}_loss-{:.6f}",
            # so the newest one has the largest epoch number.
            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]
        self.logger.info("Resume from {}...".format(checkpoint_path))

        if resume_type in ["resume", ""]:
            # Load everything: model weights, optimizer, scheduler, and random states.
            self.accelerator.load_state(input_dir=checkpoint_path)

            # Recover the epoch and step counters from the directory name.
            self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
            self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
        elif resume_type == "finetune":
            # Load only the model weights. Accelerate saves the first registered
            # model as "pytorch_model.bin" and the second as "pytorch_model_1.bin"
            # (assuming the generator was registered first).
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model["generator"]),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model["discriminator"]),
                os.path.join(checkpoint_path, "pytorch_model_1.bin"),
            )
            self.logger.info("Load model weights for finetune...")
        else:
            raise ValueError("resume_type must be `resume` or `finetune`.")

        return checkpoint_path
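
    # Name-parsing example for the directory format used in ``train_loop``,
    # "epoch-{:04d}_step-{:07d}_loss-{:.6f}":
    #     "epoch-0012_step-0034567_loss-1.234567".split("_")[-3]  -> "epoch-0012"
    #     "epoch-0012".split("-")[-1]                             -> "0012" -> epoch 12
    #     ...split("_")[-2].split("-")[-1]                        -> "0034567" -> step 34567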

    def _build_model(self):
        net_g = SynthesizerTrn(
            # Number of linear-spectrogram bins: n_fft // 2 + 1
            self.cfg.preprocess.n_fft // 2 + 1,
            # Training segment length in frames
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            # Directly pass the whole cfg
            self.cfg,
        )
        net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm)
        model = {"generator": net_g, "discriminator": net_d}
        return model
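
    # Sketch of the config fields this trainer reads (non-exhaustive; actual
    # values live in the experiment config, shown here in JSON5 form):
    #
    #     {
    #         preprocess: {n_fft: ..., hop_size: ..., segment_size: ...},
    #         model: {vits: {use_spectral_norm: false}},
    #         train: {
    #             learning_rate: ..., AdamW: {betas: [...], eps: ...},
    #             lr_decay: ..., c_mel: ..., c_kl: ...,
    #             batch_size: ..., gradient_accumulation_step: ...,
    #         },
    #     }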

    def _build_dataset(self):
        return SVCOfflineDataset, SVCOfflineCollator

    def _build_optimizer(self):
        optimizer_g = torch.optim.AdamW(
            self.model["generator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer_d = torch.optim.AdamW(
            self.model["discriminator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
        return optimizer

    def _build_scheduler(self):
        scheduler_g = ExponentialLR(
            self.optimizer["optimizer_g"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler_d = ExponentialLR(
            self.optimizer["optimizer_d"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
        return scheduler
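
    # ExponentialLR applies lr_n = learning_rate * lr_decay ** n, where n is
    # the number of scheduler steps taken; last_epoch=self.epoch - 1 keeps the
    # decay aligned with the epoch counter when resuming.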

    def _build_criterion(self):
        class GeneratorLoss(nn.Module):
            def __init__(self, cfg):
                super(GeneratorLoss, self).__init__()
                self.cfg = cfg
                self.l1_loss = nn.L1Loss()

            def generator_loss(self, disc_outputs):
                # Least-squares GAN loss: push discriminator scores for
                # generated audio towards 1 ("real").
                loss = 0
                gen_losses = []
                for dg in disc_outputs:
                    dg = dg.float()
                    l = torch.mean((1 - dg) ** 2)
                    gen_losses.append(l)
                    loss += l
                return loss, gen_losses

            def feature_loss(self, fmap_r, fmap_g):
                # Feature matching: L1 distance between discriminator feature
                # maps of real (detached) and generated audio, summed over all
                # sub-discriminators and layers.
                loss = 0
                for dr, dg in zip(fmap_r, fmap_g):
                    for rl, gl in zip(dr, dg):
                        rl = rl.float().detach()
                        gl = gl.float()
                        loss += torch.mean(torch.abs(rl - gl))
                return loss * 2

            def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
                """
                z_p, logs_q: [b, h, t_t]
                m_p, logs_p: [b, h, t_t]
                """
                z_p = z_p.float()
                logs_q = logs_q.float()
                m_p = m_p.float()
                logs_p = logs_p.float()
                z_mask = z_mask.float()

                kl = logs_p - logs_q - 0.5
                kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
                kl = torch.sum(kl * z_mask)
                l = kl / torch.sum(z_mask)
                return l
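
            # For reference: for diagonal Gaussians q = N(m_q, e^{2*logs_q})
            # and p = N(m_p, e^{2*logs_p}),
            #     KL(q || p) = logs_p - logs_q - 0.5
            #                  + (e^{2*logs_q} + (m_q - m_p)^2) / (2 * e^{2*logs_p}).
            # Since z_p is a (flow-transformed) sample from the posterior q,
            # the term 0.5 * (z_p - m_p)^2 * e^{-2*logs_p} above is a one-sample
            # estimate of the last fraction, averaged over unmasked frames.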

            def forward(
                self,
                outputs_g,
                outputs_d,
                y_mel,
                y_hat_mel,
            ):
                loss_g = {}

                # Mel-spectrogram reconstruction loss
                loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
                loss_g["loss_mel"] = loss_mel

                # KL loss between posterior and prior
                loss_kl = (
                    self.kl_loss(
                        outputs_g["z_p"],
                        outputs_g["logs_q"],
                        outputs_g["m_p"],
                        outputs_g["logs_p"],
                        outputs_g["z_mask"],
                    )
                    * self.cfg.train.c_kl
                )
                loss_g["loss_kl"] = loss_kl

                # Feature-matching loss
                loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
                loss_g["loss_fm"] = loss_fm

                # Adversarial (GAN) loss
                loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
                loss_g["loss_gen"] = loss_gen
                loss_g["loss_gen_all"] = loss_mel + loss_kl + loss_fm + loss_gen

                return loss_g

        class DiscriminatorLoss(nn.Module):
            def __init__(self, cfg):
                super(DiscriminatorLoss, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")

            def __call__(self, disc_real_outputs, disc_generated_outputs):
                loss_d = {}

                # Least-squares GAN loss: real scores towards 1, fake towards 0.
                loss = 0
                r_losses = []
                g_losses = []
                for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
                    dr = dr.float()
                    dg = dg.float()
                    r_loss = torch.mean((1 - dr) ** 2)
                    g_loss = torch.mean(dg**2)
                    loss += r_loss + g_loss
                    r_losses.append(r_loss.item())
                    g_losses.append(g_loss.item())

                loss_d["loss_disc_all"] = loss
                return loss_d

        criterion = {
            "generator": GeneratorLoss(self.cfg),
            "discriminator": DiscriminatorLoss(self.cfg),
        }
        return criterion
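
    # For reference, the combined generator objective assembled above is
    #     loss_gen_all = c_mel * L1(y_mel, y_hat_mel)
    #                    + c_kl * KL(q || p)
    #                    + 2 * sum over layers of L1(fmap_real, fmap_gen)
    #                    + sum_i mean((1 - D_i(y_hat))^2)
    # where c_mel and c_kl come from cfg.train.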

    # Keep legacy unchanged
    def write_summary(
        self,
        losses,
        stats,
        images={},
        audios={},
        audio_sampling_rate=24000,
        tag="train",
    ):
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self.sw.add_scalar(
            "learning_rate",
            self.optimizer["optimizer_g"].param_groups[0]["lr"],
            self.step,
        )

        if len(images) != 0:
            for key, value in images.items():
                self.sw.add_image(key, value, self.global_step, dataformats="HWC")
        if len(audios) != 0:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)

    def write_valid_summary(
        self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
    ):
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)

        if len(images) != 0:
            for key, value in images.items():
                self.sw.add_image(key, value, self.global_step, dataformats="HWC")
        if len(audios) != 0:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)

    def _get_state_dict(self):
        state_dict = {
            "generator": self.model["generator"].state_dict(),
            "discriminator": self.model["discriminator"].state_dict(),
            "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
            "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
            "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
            "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def get_state_dict(self):
        # Public alias of _get_state_dict.
        return self._get_state_dict()

    def load_model(self, checkpoint):
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]

        self.model["generator"].load_state_dict(checkpoint["generator"])
        self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
        self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
        self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
        self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
        self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])

    def _valid_step(self, batch):
        r"""Validation forward step. Should return the average loss of a sample
        over one batch. Invoking ``_forward_step`` is recommended except for
        special cases. See ``_valid_epoch`` for usage.
        """

        valid_losses = {}
        total_loss = 0
        valid_stats = {}

        # Generator output
        outputs_g = self.model["generator"](batch)

        # Slice the ground-truth mel to the same random segment the generator
        # synthesized (ids_slice is in frames).
        y_mel = slice_segments(
            batch["mel"].transpose(1, 2),
            outputs_g["ids_slice"],
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
        )
        y_hat_mel = mel_spectrogram_torch(
            outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
        )
        # Audio is sliced in samples: frame index * hop_size.
        y = slice_segments(
            batch["audio"].unsqueeze(1),
            outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
            self.cfg.preprocess.segment_size,
        )

        # Discriminator loss (generated audio detached)
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        valid_losses.update(loss_d)

        # Generator loss (second discriminator pass, not detached)
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        valid_losses.update(loss_g)

        for item in valid_losses:
            valid_losses[item] = valid_losses[item].item()

        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
        return (
            total_loss.item(),
            valid_losses,
            valid_stats,
        )

    # No gradients are needed during validation.
    @torch.no_grad()
    def _valid_epoch(self):
        r"""Validation epoch. Should return the average loss of a batch (sample)
        over one epoch. See ``train_loop`` for usage.
        """
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].eval()
        else:
            self.model.eval()

        epoch_sum_loss = 0.0
        epoch_losses = dict()
        for batch in tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            total_loss, valid_losses, valid_stats = self._valid_step(batch)
            epoch_sum_loss += total_loss
            if isinstance(valid_losses, dict):
                for key, value in valid_losses.items():
                    if key not in epoch_losses.keys():
                        epoch_losses[key] = value
                    else:
                        epoch_losses[key] += value

        epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
        for key in epoch_losses.keys():
            epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)

        self.accelerator.wait_for_everyone()
        return epoch_sum_loss, epoch_losses

    ### THIS IS MAIN ENTRY ###
    def train_loop(self):
        r"""Training loop. The public entry of the training process."""
        # Wait for everyone to be ready before we move on
        self.accelerator.wait_for_everyone()

        # Dump the config file
        if self.accelerator.is_main_process:
            self.__dump_cfg(self.config_save_path)
        # Wait to make sure everyone is good to go
        self.accelerator.wait_for_everyone()

        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            # Run the training and validation epochs
            train_total_loss, train_losses = self._train_epoch()
            if isinstance(train_losses, dict):
                for key, loss in train_losses.items():
                    self.logger.info("  |- Train/{} Loss: {:.6f}".format(key, loss))
                    self.accelerator.log(
                        {"Epoch/Train {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            valid_total_loss, valid_losses = self._valid_epoch()
            if isinstance(valid_losses, dict):
                for key, loss in valid_losses.items():
                    self.logger.info("  |- Valid/{} Loss: {:.6f}".format(key, loss))
                    self.accelerator.log(
                        {"Epoch/Valid {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            self.logger.info("  |- Train/Loss: {:.6f}".format(train_total_loss))
            self.logger.info("  |- Valid/Loss: {:.6f}".format(valid_total_loss))
            self.accelerator.log(
                {
                    "Epoch/Train Loss": train_total_loss,
                    "Epoch/Valid Loss": valid_total_loss,
                },
                step=self.epoch,
            )

            self.accelerator.wait_for_everyone()

            # Check whether we hit a save_checkpoint_stride and should run eval
            run_eval = False
            if self.accelerator.is_main_process:
                save_checkpoint = False
                hit_idx = []
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        hit_idx.append(i)
                        run_eval |= self.run_eval[i]

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                path = os.path.join(
                    self.checkpoint_dir,
                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_total_loss
                    ),
                )
                self.tmp_checkpoint_save_path = path
                self.accelerator.save_state(path)

                json.dump(
                    self.checkpoints_path,
                    open(os.path.join(path, "ckpts.json"), "w"),
                    ensure_ascii=False,
                    indent=4,
                )
                self._save_auxiliary_states()

                # Collect checkpoints that fall out of each keep_last window
                to_remove = []
                for idx in hit_idx:
                    self.checkpoints_path[idx].append(path)
                    while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
                        to_remove.append((idx, self.checkpoints_path[idx].pop(0)))

                # A checkpoint may still be referenced by another stride's list;
                # only delete paths that no list keeps.
                total = set()
                for i in self.checkpoints_path:
                    total |= set(i)
                do_remove = set()
                for idx, path in to_remove[::-1]:
                    if path in total:
                        self.checkpoints_path[idx].insert(0, path)
                    else:
                        do_remove.add(path)

                # Remove old checkpoints
                for path in do_remove:
                    shutil.rmtree(path, ignore_errors=True)
                    self.logger.debug(f"Remove old checkpoint: {path}")

            self.accelerator.wait_for_everyone()
            if run_eval:
                # TODO: run evaluation
                pass

            # Update info for each epoch
            self.epoch += 1

        # Finish training and save the final checkpoint
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            path = os.path.join(
                self.checkpoint_dir,
                "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                    self.epoch, self.step, valid_total_loss
                ),
            )
            self.tmp_checkpoint_save_path = path
            self.accelerator.save_state(path)

            json.dump(
                self.checkpoints_path,
                open(os.path.join(path, "ckpts.json"), "w"),
                ensure_ascii=False,
                indent=4,
            )
            self._save_auxiliary_states()

        self.accelerator.end_training()

    def _train_step(self, batch):
        r"""One training step over a batch: update the discriminator first,
        then the generator (standard GAN alternation). See ``_train_epoch``
        for usage.
        """

        train_losses = {}
        total_loss = 0
        training_stats = {}

        ## Train Discriminator
        # Generator output
        outputs_g = self.model["generator"](batch)

        # Slice the ground-truth mel to the random segment the generator
        # synthesized (ids_slice is in frames).
        y_mel = slice_segments(
            batch["mel"].transpose(1, 2),
            outputs_g["ids_slice"],
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
        )
        y_hat_mel = mel_spectrogram_torch(
            outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
        )
        y = slice_segments(
            # [1, 168418] -> [1, 1, 168418]; audio is sliced in samples,
            # i.e. frame index * hop_size
            batch["audio"].unsqueeze(1),
            outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
            self.cfg.preprocess.segment_size,
        )

        # Discriminator output; y_hat is detached so this backward pass only
        # updates the discriminator.
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
        # Discriminator loss
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        train_losses.update(loss_d)

        # BP and grad update
        self.optimizer["optimizer_d"].zero_grad()
        self.accelerator.backward(loss_d["loss_disc_all"])
        self.optimizer["optimizer_d"].step()

        ## Train Generator
        # Second discriminator pass without detaching, so gradients flow back
        # into the generator.
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        train_losses.update(loss_g)

        # BP and grad update
        self.optimizer["optimizer_g"].zero_grad()
        self.accelerator.backward(loss_g["loss_gen_all"])
        self.optimizer["optimizer_g"].step()

        for item in train_losses:
            train_losses[item] = train_losses[item].item()

        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
        return (
            total_loss.item(),
            train_losses,
            training_stats,
        )

    def _train_epoch(self):
        r"""Training epoch. Should return the average loss of a batch (sample)
        over one epoch. See ``train_loop`` for usage.
        """
        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        epoch_step: int = 0
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses, training_stats = self._train_step(batch)
            self.batch_count += 1

            # Update info for each step
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss += total_loss
                for key, value in train_losses.items():
                    if key not in epoch_losses.keys():
                        epoch_losses[key] = value
                    else:
                        epoch_losses[key] += value

                self.accelerator.log(
                    {
                        "Step/Generator Loss": train_losses["loss_gen_all"],
                        "Step/Discriminator Loss": train_losses["loss_disc_all"],
                        "Step/Generator Learning Rate": self.optimizer[
                            "optimizer_g"
                        ].param_groups[0]["lr"],
                        "Step/Discriminator Learning Rate": self.optimizer[
                            "optimizer_d"
                        ].param_groups[0]["lr"],
                    },
                    step=self.step,
                )
                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()

        epoch_sum_loss = (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )
        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )
        return epoch_sum_loss, epoch_losses
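
    # Note on the normalization above: losses are only accumulated once every
    # gradient_accumulation_step batches, i.e. roughly
    # len(train_dataloader) / gradient_accumulation_step times per epoch, so
    # multiplying by gradient_accumulation_step turns the sum into a per-step
    # average.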

    def __dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            json5.dump(
                self.cfg,
                f,
                indent=4,
                sort_keys=True,
                ensure_ascii=False,
                quote_keys=True,
            )