Spaces:
Sleeping
Sleeping
# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import argparse | |
import copy | |
import os | |
import time | |
from collections import defaultdict, OrderedDict | |
from itertools import cycle | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
import amp_C | |
from apex.optimizers import FusedAdam, FusedLAMB | |
from torch.nn.parallel import DistributedDataParallel | |
from torch.utils.data import DataLoader | |
from torch.utils.data.distributed import DistributedSampler | |
import common.tb_dllogger as logger | |
import models | |
from common.tb_dllogger import log | |
from common.repeated_dataloader import (RepeatedDataLoader, | |
RepeatedDistributedSampler) | |
from common.text import cmudict | |
from common.utils import BenchmarkStats, Checkpointer, prepare_tmp | |
from fastpitch.attn_loss_function import AttentionBinarizationLoss | |
from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset | |
from fastpitch.loss_function import FastPitchLoss | |
import matplotlib.pyplot as plt | |
def parse_args(parser):
    """Register the training script's command-line options on ``parser``.

    Options are organised into argument groups (training, optimization,
    dataset, conditioning, audio and distributed setup).

    Args:
        parser: An ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str, default='./',
                        help='Path to dataset')
    parser.add_argument('--log-file', type=str, default=None,
                        help='Path to a DLLogger log file')

    train = parser.add_argument_group('training setup')
    train.add_argument('--epochs', type=int, required=True,
                       help='Number of total epochs to run')
    train.add_argument('--epochs-per-checkpoint', type=int, default=50,
                       help='Number of epochs per checkpoint')
    train.add_argument('--checkpoint-path', type=str, default=None,
                       help='Checkpoint path to resume training')
    train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
                       type=int, nargs='+',
                       help='Milestone checkpoints to keep from removing')
    train.add_argument('--resume', action='store_true',
                       help='Resume training from the last checkpoint')
    train.add_argument('--seed', type=int, default=1234,
                       help='Seed for PyTorch random number generators')
    train.add_argument('--amp', action='store_true',
                       help='Enable AMP')
    train.add_argument('--cuda', action='store_true',
                       help='Run on GPU using CUDA')
    train.add_argument('--cudnn-benchmark', action='store_true',
                       help='Enable cudnn benchmark mode')
    train.add_argument('--ema-decay', type=float, default=0,
                       help='Discounting factor for training weights EMA')
    train.add_argument('--grad-accumulation', type=int, default=1,
                       help='Training steps to accumulate gradients for')
    train.add_argument('--kl-loss-start-epoch', type=int, default=250,
                       help='Start adding the hard attention loss term')
    train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
                       help='Gradually increase the hard attention loss term')
    # The original help text here was a copy of the warm-up option's text.
    train.add_argument('--kl-loss-weight', type=float, default=1.0,
                       help='Final weight of the hard attention loss term')
    train.add_argument('--benchmark-epochs-num', type=int, default=20,
                       help='Number of epochs for calculating final stats')
    train.add_argument('--validation-freq', type=int, default=1,
                       help='Validate every N epochs to use less compute')

    opt = parser.add_argument_group('optimization setup')
    opt.add_argument('--optimizer', type=str, default='lamb',
                     help='Optimization algorithm')
    opt.add_argument('-lr', '--learning-rate', type=float, required=True,
                     help='Learning rate')
    opt.add_argument('--weight-decay', default=1e-6, type=float,
                     help='Weight decay')
    opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
                     help='Clip threshold for gradients')
    opt.add_argument('-bs', '--batch-size', type=int, required=True,
                     help='Batch size per GPU')
    opt.add_argument('--warmup-steps', type=int, default=1000,
                     help='Number of steps for lr warmup')
    opt.add_argument('--dur-predictor-loss-scale', type=float,
                     default=1.0, help='Rescale duration predictor loss')
    opt.add_argument('--pitch-predictor-loss-scale', type=float,
                     default=1.0, help='Rescale pitch predictor loss')
    opt.add_argument('--attn-loss-scale', type=float,
                     default=1.0, help='Rescale alignment loss')

    data = parser.add_argument_group('dataset parameters')
    data.add_argument('--training-files', type=str, nargs='*', required=True,
                      help='Paths to training filelists.')
    data.add_argument('--validation-files', type=str, nargs='*',
                      required=True, help='Paths to validation filelists')
    data.add_argument('--text-cleaners', nargs='*',
                      default=['english_cleaners'], type=str,
                      help='Type of text cleaners for input text')
    data.add_argument('--symbol-set', type=str, default='english_basic',
                      help='Define symbol set for input text')
    data.add_argument('--p-arpabet', type=float, default=0.0,
                      help='Probability of using arpabets instead of graphemes '
                           'for each word; set 0 for pure grapheme training')
    data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
                      help='Path to the list of heteronyms')
    data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
                      help='Path to the pronouncing dictionary')
    data.add_argument('--prepend-space-to-text', action='store_true',
                      help='Capture leading silence with a space token')
    data.add_argument('--append-space-to-text', action='store_true',
                      help='Capture trailing silence with a space token')
    data.add_argument('--num-workers', type=int, default=2,  # 6
                      help='Subprocesses for train and val DataLoaders')
    data.add_argument('--trainloader-repeats', type=int, default=100,
                      help='Repeats the dataset to prolong epochs')

    cond = parser.add_argument_group('data for conditioning')
    cond.add_argument('--n-speakers', type=int, default=1,
                      help='Number of speakers in the dataset. '
                           'n_speakers > 1 enables speaker embeddings')
    # ANT: added language
    cond.add_argument('--n-languages', type=int, default=1,
                      help='Number of languages in the dataset. '
                           'n_languages > 1 enables language embeddings')
    cond.add_argument('--load-pitch-from-disk', action='store_true',
                      help='Use pitch cached on disk with prepare_dataset.py')
    cond.add_argument('--pitch-online-method', default='pyin',
                      choices=['pyin'],
                      help='Calculate pitch on the fly during training')
    cond.add_argument('--pitch-online-dir', type=str, default=None,
                      help='A directory for storing pitch calculated on-line')
    cond.add_argument('--pitch-mean', type=float, default=125.626816,  # default=214.72203,
                      help='Normalization value for pitch')
    cond.add_argument('--pitch-std', type=float, default=37.52,  # default=65.72038,
                      help='Normalization value for pitch')
    cond.add_argument('--load-mel-from-disk', action='store_true',
                      help='Use mel-spectrograms cache on the disk')  # XXX

    audio = parser.add_argument_group('audio parameters')
    audio.add_argument('--max-wav-value', default=32768.0, type=float,
                       help='Maximum audiowave value')
    audio.add_argument('--sampling-rate', default=22050, type=int,
                       help='Sampling rate')
    audio.add_argument('--filter-length', default=1024, type=int,
                       help='Filter length')
    audio.add_argument('--hop-length', default=256, type=int,
                       help='Hop (stride) length')
    audio.add_argument('--win-length', default=1024, type=int,
                       help='Window length')
    audio.add_argument('--mel-fmin', default=0.0, type=float,
                       help='Minimum mel frequency')
    audio.add_argument('--mel-fmax', default=8000.0, type=float,
                       help='Maximum mel frequency')

    # Named `distributed` (not `dist`) so the module-level
    # `torch.distributed as dist` alias is not shadowed inside this function.
    distributed = parser.add_argument_group('distributed setup')
    # String defaults taken from the environment are converted by argparse
    # through `type=int`, exactly like command-line values.
    distributed.add_argument('--local_rank', type=int,
                             default=os.getenv('LOCAL_RANK', 0),
                             help='Rank of the process for multiproc; do not set manually')
    distributed.add_argument('--world_size', type=int,
                             default=os.getenv('WORLD_SIZE', 1),
                             help='Number of processes for multiproc; do not set manually')
    return parser
def reduce_tensor(tensor, num_gpus):
    """Sum ``tensor`` across all processes and divide the sum by ``num_gpus``.

    The input is left untouched; the all-reduce runs on a clone.
    Passing ``num_gpus=1`` therefore yields the plain cross-process sum.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    averaged = summed.true_divide(num_gpus)
    return averaged
def init_distributed(args, world_size, rank):
    """Bind this process to a GPU and join the default process group.

    ``world_size`` is accepted for interface compatibility but is not read
    here; the process group is configured through ``env://``.
    """
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing distributed training")

    # Pin the process to its GPU so every later CUDA call lands on it.
    device_index = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)

    # Bring up inter-process communication.
    backend = 'nccl' if args.cuda else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')
    print("Done initializing distributed training")
def _save_validation_plot(epoch, y, y_pred, out_dir='debug_epoch'):
    """Save a debug figure for the first utterance of a validation batch.

    The figure is written to ``<out_dir>/<epoch:06d>.png`` (``final.png`` when
    ``epoch`` is None). The top-right panel is intentionally left empty.

    NOTE(review): assumes ``y[0]``/``y_pred[0]`` are target/predicted
    mel-spectrograms and ``y_pred[4]``/``y_pred[5]`` are predicted/reference
    pitch, as the original local names suggested -- confirm against the model.
    """
    fig, axs = plt.subplots(2, 2, figsize=(21, 14))
    try:
        # - Mel-spectrogram (the prediction is transposed before display)
        pred_mel = y_pred[0][0, :, :].cpu().detach().numpy().astype(np.float32).T
        orig_mel = y[0][0, :, :].cpu().detach().numpy().astype(np.float32)
        axs[0, 0].imshow(orig_mel, aspect='auto', origin='lower', interpolation='nearest')
        axs[1, 0].imshow(pred_mel, aspect='auto', origin='lower', interpolation='nearest')

        # - Prosody
        f0_pred = y_pred[4][0, :].cpu().detach().numpy().astype(np.float32)
        f0_ori = y_pred[5][0, :].cpu().detach().numpy().astype(np.float32)
        axs[1, 1].plot(f0_ori)
        axs[1, 1].plot(f0_pred)

        # Create the directory that is actually written to; exist_ok avoids
        # both the check-then-create race and a failure on later epochs.
        os.makedirs(out_dir, exist_ok=True)
        tag = f'{epoch:06d}' if epoch is not None else 'final'
        fig.savefig(os.path.join(out_dir, f'{tag}.png'), bbox_inches='tight')
    finally:
        # Without this, one figure would stay alive per validation run.
        plt.close(fig)


def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
             batch_to_gpu, local_rank=None, ema=False):
    """Evaluate ``model`` on ``val_loader`` and log the aggregated metrics.

    Args:
        model: Module to evaluate; its train/eval mode is restored on exit.
        epoch: Epoch number used as the log step and plot name, or None.
        total_iter: Global step passed to the logger as ``tb_total_steps``.
        criterion: Loss callable returning ``(loss, meta)``; called with
            ``is_training=False, meta_agg='sum'``.
        val_loader: Validation data loader; ``len(val_loader.dataset)`` is
            used to normalise the summed metrics.
        distributed_run: When True, metrics are summed across processes.
        batch_to_gpu: Callable turning a batch into ``(x, y, num_frames)``.
        local_rank: Rank of this process. A debug plot of the first batch is
            saved only when it equals 0; the default None disables the plot.
        ema: True when ``model`` is the EMA copy (logged as ``val_ema``,
            never plotted).

    Returns:
        Dict of per-sample metrics plus ``'took'`` (elapsed seconds).
    """
    was_training = model.training
    model.eval()

    tik = time.perf_counter()
    with torch.no_grad():
        val_meta = defaultdict(float)
        val_num_frames = 0
        for i, batch in enumerate(val_loader):
            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x)
            loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')

            if distributed_run:
                # Dividing by 1 turns reduce_tensor into a plain sum.
                for k, v in meta.items():
                    val_meta[k] += reduce_tensor(v, 1)
                val_num_frames += reduce_tensor(num_frames.data, 1).item()
            else:
                for k, v in meta.items():
                    val_meta[k] += v
                val_num_frames += num_frames.item()

            # NOTE: visualize the first utterance of the validation corpus to
            # check whether training is progressing properly.
            if (i == 0) and (local_rank == 0) and (not ema):
                _save_validation_plot(epoch, y, y_pred)

        val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}

    val_meta['took'] = time.perf_counter() - tik

    log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
        subset='val_ema' if ema else 'val',
        data=OrderedDict([
            ('loss', val_meta['loss'].item()),
            ('mel_loss', val_meta['mel_loss'].item()),
            ('frames/s', val_num_frames / val_meta['took']),
            ('took', val_meta['took'])]),
        )

    if was_training:
        model.train()
    return val_meta
def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
    """Set the learning rate of every parameter group in ``opt``.

    The rate is ``learning_rate * scale`` where ``scale`` is:
      * ``1.0`` when warm-up is disabled (``warmup_iters`` is None or 0);
      * ``total_iter / warmup_iters ** 1.5`` while
        ``total_iter <= warmup_iters`` (linear warm-up);
      * ``1 / sqrt(total_iter)`` afterwards.

    Args:
        total_iter: Current global step.
        opt: Object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry that is overwritten.
        learning_rate: Base learning rate.
        warmup_iters: Number of warm-up steps; None or 0 disables the schedule.
    """
    # `not warmup_iters` also covers the default None, which previously
    # fell through to `total_iter > None` and raised a TypeError.
    if not warmup_iters:
        scale = 1.0
    elif total_iter > warmup_iters:
        scale = 1. / (total_iter ** 0.5)
    else:
        scale = total_iter / (warmup_iters ** 1.5)

    for param_group in opt.param_groups:
        param_group['lr'] = learning_rate * scale
def apply_ema_decay(model, ema_model, decay):
    """Blend ``model``'s state into ``ema_model`` in place.

    Each entry of ``ema_model.state_dict()`` becomes
    ``decay * ema + (1 - decay) * model``. Nothing happens when ``decay`` is
    falsy. If only ``model`` exposes a ``module`` attribute, EMA keys are
    looked up in ``model`` under a ``'module.'`` prefix.
    """
    if decay:
        source_state = model.state_dict()
        needs_prefix = (hasattr(model, 'module')
                        and not hasattr(ema_model, 'module'))
        for name, ema_value in ema_model.state_dict().items():
            key = name
            if needs_prefix and not name.startswith('module.'):
                key = 'module.' + name
            blended = decay * ema_value + (1 - decay) * source_state[key]
            ema_value.copy_(blended)
def init_multi_tensor_ema(model, ema_model):
    """Collect the inputs needed by ``apply_multi_tensor_ema``.

    Returns a 3-tuple: the list of ``model`` state tensors, the list of
    ``ema_model`` state tensors, and a one-element CUDA int tensor holding 0
    that serves as the overflow buffer.
    """
    live_tensors = [*model.state_dict().values()]
    shadow_tensors = [*ema_model.state_dict().values()]
    overflow_flag = torch.cuda.IntTensor([0])
    return live_tensors, shadow_tensors, overflow_flag
def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
    """Run the fused ``amp_C.multi_tensor_axpby`` kernel over all weights.

    ``ema_weights`` is passed both as the first operand and as the output
    list, with coefficients ``decay`` and ``1 - decay``.
    NOTE(review): presumably this updates the EMA weights in place as
    ``decay * ema + (1 - decay) * model`` -- confirm against the Apex kernel.
    """
    chunk_size = 65536
    tensor_lists = [ema_weights, model_weights, ema_weights]
    amp_C.multi_tensor_axpby(chunk_size, overflow_buf, tensor_lists,
                             decay, 1 - decay, -1)
def main():
    """Train FastPitch: parse options, build everything, run the epoch loop.

    Steps, in order: parse CLI options (training options first, then the
    model options), seed the RNGs, set up logging and (optionally)
    distributed mode, build the model / optimizer / EMA copy, restore a
    checkpoint if available, build the data loaders, then train with
    gradient accumulation, periodic validation and checkpointing.

    Raises:
        ValueError: On unknown command-line options or an unsupported
            ``--optimizer`` value.
        Exception: When the (reduced) training loss becomes NaN.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if args.p_arpabet > 0.0:
        cmudict.initialize(args.cmudict_path, args.heteronyms_path)

    distributed_run = args.world_size > 1

    # Offset the seed by the rank so each process draws different numbers.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    # Second parsing pass: model options are now known, so anything left
    # over is a genuine error.
    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)
    else:
        if args.trainloader_repeats > 1:
            print('WARNING: Disabled --trainloader-repeats, supported only for'
                  ' multi-GPU data loading.')
        args.trainloader_repeats = 1

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    attention_kl_loss = AttentionBinarizationLoss()

    # Store pitch mean/std as params to translate from Hz during inference
    model.pitch_mean[0] = args.pitch_mean
    model.pitch_std[0] = args.pitch_std

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f"Unsupported optimizer '{args.optimizer}'; "
                         "expected 'adam' or 'lamb'")

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # One predicate for every EMA decision, so a non-positive decay can never
    # reach the EMA helpers with ema_model set to None.
    use_ema = args.ema_decay > 0
    ema_model = copy.deepcopy(model) if use_ema else None

    if distributed_run:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    train_state = {'epoch': 1, 'total_iter': 1}
    checkpointer = Checkpointer(args.output, args.keep_milestones)

    checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
                            ema_model)
    start_epoch = train_state['epoch']
    total_iter = train_state['total_iter']

    criterion = FastPitchLoss(
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
        attn_loss_scale=args.attn_loss_scale)

    collate_fn = TTSCollate()

    if args.local_rank == 0:
        prepare_tmp(args.pitch_online_dir)

    trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
    valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))

    if distributed_run:
        train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
                                                   trainset, drop_last=True)
        val_sampler = DistributedSampler(valset)
        shuffle = False
    else:
        # NOTE: shuffling is deliberately off here (was True).
        train_sampler, val_sampler, shuffle = None, None, False

    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
    kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
          'collate_fn': collate_fn}
    train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
                                      shuffle=shuffle, drop_last=True,
                                      sampler=train_sampler, pin_memory=True,
                                      persistent_workers=True, **kw)
    val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
                            pin_memory=False, **kw)

    if use_ema:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    model.train()
    bmark_stats = BenchmarkStats()

    # Only meaningful (and only safe to call) when running on CUDA.
    if args.cuda:
        torch.cuda.synchronize()

    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}
        iter_start_time = time.perf_counter()

        epoch_iter = 1
        # accum_step cycles 1..grad_accumulation; an optimizer step is taken
        # each time it reaches grad_accumulation.
        for batch, accum_step in zip(train_loader,
                                     cycle(range(1, args.grad_accumulation + 1))):
            if accum_step == 1:
                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad(set_to_none=True)

            x, y, num_frames = batch_to_gpu(batch)

            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss, meta = criterion(y_pred, y)

                if (args.kl_loss_start_epoch is not None
                        and epoch >= args.kl_loss_start_epoch):

                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
                        print('Begin hard_attn loss')

                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
                    # Linear ramp over the warm-up epochs; a non-positive
                    # warm-up means "full weight immediately" instead of a
                    # division by zero.
                    if args.kl_loss_warmup_epochs > 0:
                        kl_ramp = min((epoch - args.kl_loss_start_epoch)
                                      / args.kl_loss_warmup_epochs, 1.0)
                    else:
                        kl_ramp = 1.0
                    kl_weight = kl_ramp * args.kl_loss_weight
                    meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
                    loss += kl_weight * binarization_loss

                else:
                    meta['kl_loss'] = torch.zeros_like(loss)
                    kl_weight = 0

                loss /= args.grad_accumulation

            meta = {k: v / args.grad_accumulation
                    for k, v in meta.items()}

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accum_step % args.grad_accumulation == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    # Unscale first so clipping sees the true gradient norms.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    optimizer.step()

                if use_ema:
                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)

                iter_mel_loss = iter_meta['mel_loss'].item()
                iter_kl_loss = iter_meta['kl_loss'].item()
                iter_time = time.perf_counter() - iter_start_time
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss

                num_iters = len(train_loader) // args.grad_accumulation
                log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
                    subset='train', data=OrderedDict([
                        ('loss', iter_loss),
                        ('mel_loss', iter_mel_loss),
                        ('kl_loss', iter_kl_loss),
                        ('kl_weight', kl_weight),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])]),
                    )

                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
                iter_start_time = time.perf_counter()

                # Leave before the counters advance so that epoch_iter equals
                # the number of optimizer steps taken in this epoch.
                if epoch_iter == num_iters:
                    break
                epoch_iter += 1
                total_iter += 1

        # Finished epoch
        epoch_loss /= epoch_iter
        epoch_mel_loss /= epoch_iter
        epoch_time = time.perf_counter() - epoch_start_time

        log((epoch,), tb_total_steps=None, subset='train_avg',
            data=OrderedDict([
                ('loss', epoch_loss),
                ('mel_loss', epoch_mel_loss),
                ('frames/s', epoch_num_frames / epoch_time),
                ('took', epoch_time)]),
            )
        bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
                           epoch_time)

        if epoch % args.validation_freq == 0:
            validate(model, epoch, total_iter, criterion, val_loader,
                     distributed_run, batch_to_gpu, ema=False,
                     local_rank=args.local_rank)

            if use_ema:
                validate(ema_model, epoch, total_iter, criterion, val_loader,
                         distributed_run, batch_to_gpu, args.local_rank,
                         ema=True)

        # save before making sched.step() for proper loading of LR
        checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
                                epoch, total_iter, model_config)
        logger.flush()

    # Finished training
    if len(bmark_stats) > 0:
        log((), tb_total_steps=None, subset='train_avg',
            data=bmark_stats.get(args.benchmark_epochs_num))

    # The rank must be passed explicitly, as in the per-epoch calls above;
    # the call previously omitted it.
    validate(model, None, total_iter, criterion, val_loader, distributed_run,
             batch_to_gpu, local_rank=args.local_rank)
# Run the training entry point only when executed as a script.
if __name__ == "__main__":
    main()