|
import os |
|
import time |
|
import shutil |
|
|
|
import torch |
|
import cv2 |
|
import torch.optim as optim |
|
import numpy as np |
|
from glob import glob |
|
from torch.cuda.amp import GradScaler, autocast |
|
from torch.nn.parallel.distributed import DistributedDataParallel |
|
from torch.utils.data import Dataset, DataLoader |
|
from tqdm import tqdm |
|
from utils.image_processing import denormalize_input, preprocess_images, resize_image |
|
from losses import LossSummary, AnimeGanLoss, to_gray_scale |
|
from utils import load_checkpoint, save_checkpoint, read_image |
|
from utils.common import set_lr |
|
from color_transfer import color_transfer_pytorch |
|
|
|
|
|
def transfer_color_and_rescale(src, target):
    """Transfer color from src image to target then rescale to [-1, 1].

    The result of ``color_transfer_pytorch(src, target)`` is mapped with
    ``x / 0.5 - 1``.
    NOTE(review): this reaches [-1, 1] only if the helper returns values in
    [0, 1] -- confirm against ``color_transfer_pytorch``.
    """
    transferred = color_transfer_pytorch(src, target)
    return (transferred / 0.5) - 1
|
|
|
def gaussian_noise():
    """Draw a single noise value from a normal distribution.

    Returns:
        A 0-dim float tensor sampled with mean 0.0 and standard deviation 0.1.
        Because it is one scalar, adding it to an image shifts every element
        by the same amount.
    """
    return torch.normal(mean=torch.tensor(0.0), std=torch.tensor(0.1))
|
|
|
def convert_to_readable(seconds):
    """Format a duration given in seconds as ``HH:MM:SS``.

    Unlike ``time.strftime('%H:%M:%S', time.gmtime(seconds))``, the hour field
    does not wrap at 24, so durations of a day or more stay readable
    (90000 seconds -> ``25:00:00``).

    Args:
        seconds: Duration in seconds (int, float or numpy scalar). Fractions
            are truncated; negative values are clamped to zero.

    Returns:
        The duration as a zero-padded ``HH:MM:SS`` string.
    """
    total_seconds = max(int(seconds), 0)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
|
|
|
|
def revert_to_np_image(image_tensor):
    """Convert one 3-D image tensor into a numpy image.

    The tensor is moved to the CPU, its axes are reordered with
    ``transpose(1, 2, 0)``, it is passed through ``denormalize_input`` with
    ``dtype=np.int16`` and finally the last axis is reversed.
    NOTE(review): the reversal is presumably RGB -> BGR for OpenCV, assuming
    a channels-first (C, H, W) input -- confirm against the dataset.
    """
    channels_last = image_tensor.cpu().numpy().transpose(1, 2, 0)
    restored = denormalize_input(channels_last, dtype=np.int16)
    return restored[..., ::-1]
|
|
|
|
|
def save_generated_images(images: torch.Tensor, save_dir: str):
    """Save a batch of generated images to disk as ``G{i}.jpg``.

    Args:
        images: 4-D tensor; it is transposed with ``(0, 2, 3, 1)`` before
            saving (documented upstream as `(*, 3, H, W)` in range [-1, 1]).
        save_dir: Output directory, created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Work on a detached CPU copy, reordered from (N, C, H, W) to (N, H, W, C).
    batch = images.clone().detach().cpu().numpy().transpose(0, 2, 3, 1)

    for idx, frame in enumerate(batch):
        frame = denormalize_input(frame, dtype=np.int16)
        out_path = os.path.join(save_dir, f"G{idx}.jpg")
        # Last axis is reversed before handing the array to OpenCV.
        cv2.imwrite(out_path, frame[..., ::-1])
|
|
|
|
|
class DDPTrainer:
    """Mixin providing distributed (DDP) and mixed-precision (AMP) setup.

    Expects the concrete class to define ``self.cfg`` (with ``ddp``,
    ``local_rank`` and ``world_size``), ``self.logger``, ``self.G`` and
    ``self.D`` before these methods are called.
    """

    def _init_distributed(self):
        """Join the NCCL process group and move both models to the local GPU.

        Does nothing when ``cfg.ddp`` is falsy.
        """
        if not self.cfg.ddp:
            return

        rank = self.cfg.local_rank
        self.logger.info("Setting up DDP")
        self.pg = torch.distributed.init_process_group(
            backend="nccl",
            rank=rank,
            world_size=self.cfg.world_size,
        )
        # Swap BatchNorm layers for their synchronized counterparts.
        self.G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.G, self.pg)
        self.D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.D, self.pg)
        torch.cuda.set_device(rank)
        self.G.cuda(rank)
        self.D.cuda(rank)
        self.logger.info("Setting up DDP Done")

    def _init_amp(self, enabled=False):
        """Create one GradScaler per optimizer and, under DDP, wrap the models.

        Args:
            enabled: Whether gradient scaling is active.
        """
        self.scaler_g = GradScaler(enabled=enabled)
        self.scaler_d = GradScaler(enabled=enabled)

        if not self.cfg.ddp:
            return

        def wrap(module):
            # Each process drives exactly one device: its local rank.
            return DistributedDataParallel(
                module,
                device_ids=[self.cfg.local_rank],
                output_device=self.cfg.local_rank,
                find_unused_parameters=False,
            )

        self.G = wrap(self.G)
        self.D = wrap(self.D)
        self.logger.info("Set DistributedDataParallel")
|
|
|
|
|
class Trainer(DDPTrainer):
    """
    Base Trainer class

    Owns the generator/discriminator pair, their Adam optimizers, the loss
    function and the experiment directory, and drives generator pre-training
    followed by adversarial training.
    """

    def __init__(
        self,
        generator,
        discriminator,
        config,
        logger,
    ) -> None:
        """Build optimizers, loss, working directories and DDP/AMP state.

        Args:
            generator: Generator module; must expose a ``name`` attribute.
            discriminator: Discriminator module; must expose ``name``.
            config: Settings object. This class reads ``device``, ``lr_g``,
                ``lr_d``, ``ddp``, ``local_rank``, ``gray_adv``, ``amp``,
                ``exp_dir`` and further training options from it.
            logger: Logger used for progress messages.
        """
        self.G = generator
        self.D = discriminator
        self.cfg = config
        # Upper bound used for gradient-norm clipping of both networks.
        self.max_norm = 10
        self.device_type = 'cuda' if self.cfg.device.startswith('cuda') else 'cpu'
        self.optimizer_g = optim.Adam(self.G.parameters(), lr=self.cfg.lr_g, betas=(0.5, 0.999))
        self.optimizer_d = optim.Adam(self.D.parameters(), lr=self.cfg.lr_d, betas=(0.5, 0.999))
        self.loss_tracker = LossSummary()
        if self.cfg.ddp:
            # Under DDP every process is pinned to the GPU of its local rank.
            self.device = torch.device(f"cuda:{self.cfg.local_rank}")
            logger.info(f"---------{self.cfg.local_rank} {self.device}")
        else:
            self.device = torch.device(self.cfg.device)
        self.loss_fn = AnimeGanLoss(self.cfg, self.device, self.cfg.gray_adv)
        self.logger = logger
        self._init_working_dir()
        self._init_distributed()
        self._init_amp(enabled=self.cfg.amp)

    def _init_working_dir(self):
        """Init working directory for saving checkpoint, ..."""
        os.makedirs(self.cfg.exp_dir, exist_ok=True)
        Gname = self.G.name
        Dname = self.D.name
        self.checkpoint_path_G_init = os.path.join(self.cfg.exp_dir, f"{Gname}_init.pt")
        self.checkpoint_path_G = os.path.join(self.cfg.exp_dir, f"{Gname}.pt")
        self.checkpoint_path_D = os.path.join(self.cfg.exp_dir, f"{Dname}.pt")
        self.save_image_dir = os.path.join(self.cfg.exp_dir, "generated_images")
        self.example_image_dir = os.path.join(self.cfg.exp_dir, "train_images")
        os.makedirs(self.save_image_dir, exist_ok=True)
        os.makedirs(self.example_image_dir, exist_ok=True)

    def init_weight_G(self, weight: str):
        """Init Generator weight from `weight`; returns `load_checkpoint`'s result."""
        return load_checkpoint(self.G, weight)

    def init_weight_D(self, weight: str):
        """Init Discriminator weight from `weight`; returns `load_checkpoint`'s result."""
        return load_checkpoint(self.D, weight)

    @staticmethod
    def _set_sampler_epoch(loader, epoch):
        """Reseed a DistributedSampler-backed loader for `epoch` (no-op otherwise)."""
        sampler = getattr(loader, "sampler", None)
        if isinstance(sampler, torch.utils.data.distributed.DistributedSampler):
            sampler.set_epoch(epoch)

    def pretrain_generator(self, train_loader, start_epoch):
        """
        Pretrain Generator to reconstruct input image.

        Trains G alone with the VGG content loss from `start_epoch` up to
        `cfg.init_epochs` at learning rate `cfg.init_lr`, then restores the
        generator learning rate to `cfg.lr_g`.

        Args:
            train_loader: Loader yielding dicts with an "image" tensor.
            start_epoch: First pre-training epoch to run.
        """
        # Running sum instead of re-summing an ever-growing list each step.
        loss_total = 0.0
        loss_count = 0
        set_lr(self.optimizer_g, self.cfg.init_lr)
        for epoch in range(start_epoch, self.cfg.init_epochs):
            # The same loader is reused across epochs, so a distributed
            # sampler must be told the epoch to reshuffle.
            self._set_sampler_epoch(train_loader, epoch)

            pbar = tqdm(train_loader)
            for data in pbar:
                img = data["image"].to(self.device)

                self.optimizer_g.zero_grad()

                with autocast(enabled=self.cfg.amp):
                    fake_img = self.G(img)
                    loss = self.loss_fn.content_loss_vgg(img, fake_img)

                self.scaler_g.scale(loss).backward()
                self.scaler_g.step(self.optimizer_g)
                self.scaler_g.update()

                if self.cfg.ddp:
                    torch.distributed.barrier()

                loss_total += loss.item()
                loss_count += 1
                avg_content_loss = loss_total / loss_count
                pbar.set_description(f'[Init Training G] content loss: {avg_content_loss:2f}')

            # NOTE(review): this checkpoint is written by every rank, unlike
            # the rank-0-only saves in `train` -- confirm this is intended.
            save_checkpoint(self.G, self.checkpoint_path_G_init, self.optimizer_g, epoch)
            if self.cfg.local_rank == 0:
                self.generate_and_save(self.cfg.test_image_dir, subname='initg')
                self.logger.info(f"Epoch {epoch}/{self.cfg.init_epochs}")

        set_lr(self.optimizer_g, self.cfg.lr_g)

    def train_epoch(self, epoch, train_loader):
        """Run one adversarial epoch: a D step then a G step per batch.

        Args:
            epoch: Current epoch index (not used inside this method).
            train_loader: Loader yielding dicts with "image", "anime",
                "anime_gray" and "smooth_gray" tensors.

        Raises:
            ValueError: If the adversarial generator loss contains NaN.
        """
        pbar = tqdm(train_loader, total=len(train_loader))
        for data in pbar:
            img = data["image"].to(self.device)
            anime = data["anime"].to(self.device)
            anime_gray = data["anime_gray"].to(self.device)
            anime_smt_gray = data["smooth_gray"].to(self.device)

            # ---------------- discriminator step ----------------
            self.optimizer_d.zero_grad()

            with autocast(enabled=self.cfg.amp):
                fake_img = self.G(img)

                if self.cfg.d_noise:
                    # One scalar offset per tensor, applied in place.
                    fake_img += gaussian_noise()
                    anime += gaussian_noise()
                    anime_gray += gaussian_noise()
                    anime_smt_gray += gaussian_noise()

                if self.cfg.gray_adv:
                    fake_img = to_gray_scale(fake_img)

                fake_d = self.D(fake_img)
                real_anime_d = self.D(anime)
                real_anime_gray_d = self.D(anime_gray)
                real_anime_smt_gray_d = self.D(anime_smt_gray)

                loss_d = self.loss_fn.compute_loss_D(
                    fake_d,
                    real_anime_d,
                    real_anime_gray_d,
                    real_anime_smt_gray_d
                )

            self.scaler_d.scale(loss_d).backward()
            # Unscale before clipping so max_norm applies to true gradients.
            self.scaler_d.unscale_(self.optimizer_d)
            torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_norm=self.max_norm)
            self.scaler_d.step(self.optimizer_d)
            self.scaler_d.update()
            if self.cfg.ddp:
                torch.distributed.barrier()
            self.loss_tracker.update_loss_D(loss_d)

            # ------------------ generator step ------------------
            self.optimizer_g.zero_grad()

            with autocast(enabled=self.cfg.amp):
                fake_img = self.G(img)

                if self.cfg.gray_adv:
                    fake_d = self.D(to_gray_scale(fake_img))
                else:
                    fake_d = self.D(fake_img)

                (
                    adv_loss, con_loss,
                    gra_loss, col_loss,
                    tv_loss
                ) = self.loss_fn.compute_loss_G(
                    fake_img,
                    img,
                    fake_d,
                    anime_gray,
                )
                loss_g = adv_loss + con_loss + gra_loss + col_loss + tv_loss

            if torch.isnan(adv_loss).any():
                self.logger.info("----------------------------------------------")
                self.logger.info(fake_d)
                self.logger.info(adv_loss)
                self.logger.info("----------------------------------------------")
                raise ValueError("NAN loss!!")

            self.scaler_g.scale(loss_g).backward()
            # The generator's gradients were scaled by scaler_g, so they must
            # be unscaled by scaler_g as well (not by the discriminator's).
            self.scaler_g.unscale_(self.optimizer_g)
            grad = torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_norm=self.max_norm)
            self.scaler_g.step(self.optimizer_g)
            self.scaler_g.update()
            if self.cfg.ddp:
                torch.distributed.barrier()

            # tv_loss contributes to loss_g but is not tracked separately.
            self.loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)
            pbar.set_description(f"{self.loss_tracker.get_loss_description()} - {grad:.3f}")

    def get_train_loader(self, dataset, epoch=0):
        """Build the training DataLoader for `dataset`.

        Args:
            dataset: Dataset to load from.
            epoch: Epoch used to seed the distributed sampler's shuffle so
                each epoch sees a different order under DDP. Ignored without
                DDP, where the DataLoader shuffles by itself.

        Returns:
            A DataLoader with `cfg.batch_size`, `cfg.num_workers`, pinned
            memory and incomplete last batches dropped.
        """
        if self.cfg.ddp:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            train_sampler.set_epoch(epoch)
        else:
            train_sampler = None
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            # `shuffle` and `sampler` are mutually exclusive.
            shuffle=train_sampler is None,
            sampler=train_sampler,
            drop_last=True,
        )

    def maybe_increase_imgsz(self, epoch, train_dataset):
        """
        Increase image size at specific epoch
            + 50% epochs train at imgsz[0]
            + the rest 50% will increase every `len(epochs) / 2 / (len(imgsz) - 1)`

        Args:
            epoch: Current epoch
            train_dataset: Dataset

        Examples:
            ```
            epochs = 100
            imgsz = [256, 352, 416, 512]
            => [(0, 256), (50, 352), (66, 416), (82, 512)]
            ```
        """
        epochs = self.cfg.epochs
        imgsz = self.cfg.imgsz
        num_size_remains = len(imgsz) - 1
        half_epochs = epochs // 2

        if len(imgsz) == 1 or epoch < half_epochs:
            new_size = imgsz[0]
        else:
            per_epoch_increment = int(half_epochs / num_size_remains)
            # Past the last threshold the largest size is used.
            new_size = imgsz[-1]
            for i, size in enumerate(imgsz):
                if epoch < half_epochs + per_epoch_increment * i:
                    new_size = size
                    break

        self.logger.info(f"Check {imgsz}, {new_size}, {train_dataset.imgsz}")
        if new_size != train_dataset.imgsz:
            train_dataset.set_imgsz(new_size)
            self.logger.info(f"Increase image size to {new_size} at epoch {epoch}")

    def train(self, train_dataset: Dataset, start_epoch=0, start_epoch_g=0):
        """
        Train Generator and Discriminator.

        Pre-trains the generator, dumps the first three dataset samples as
        example images, then runs adversarial epochs with periodic
        checkpointing and preview generation on rank 0.

        Args:
            train_dataset: Training dataset; must support `imgsz` and
                `set_imgsz` for progressive resizing.
            start_epoch: First adversarial epoch to run.
            start_epoch_g: First generator pre-training epoch to run.
        """
        self.logger.info(self.device)
        self.G.to(self.device)
        self.D.to(self.device)

        self.pretrain_generator(self.get_train_loader(train_dataset), start_epoch_g)

        if self.cfg.local_rank == 0:
            self.logger.info(f"Start training for {self.cfg.epochs} epochs")

        # Dump a few raw training samples for visual inspection.
        for i, data in enumerate(train_dataset):
            for k in data.keys():
                image = data[k]
                cv2.imwrite(
                    os.path.join(self.example_image_dir, f"data_{k}_{i}.jpg"),
                    revert_to_np_image(image)
                )
            if i == 2:
                break

        per_epoch_times = []
        for epoch in range(start_epoch, self.cfg.epochs):
            self.maybe_increase_imgsz(epoch, train_dataset)

            start = time.time()
            self.train_epoch(epoch, self.get_train_loader(train_dataset, epoch))

            if epoch % self.cfg.save_interval == 0 and self.cfg.local_rank == 0:
                save_checkpoint(self.G, self.checkpoint_path_G, self.optimizer_g, epoch)
                save_checkpoint(self.D, self.checkpoint_path_D, self.optimizer_d, epoch)
                self.generate_and_save(self.cfg.test_image_dir)

                if epoch % 10 == 0:
                    self.copy_results(epoch)

            if self.cfg.local_rank == 0:
                per_epoch_times.append(time.time() - start)
                # This epoch is finished, so it no longer counts as remaining.
                remaining_epochs = self.cfg.epochs - epoch - 1
                eta = convert_to_readable(np.mean(per_epoch_times) * remaining_epochs)
                self.logger.info(f"epoch {epoch}/{self.cfg.epochs}, ETA: {eta}")

    def generate_and_save(
        self,
        image_dir,
        max_imgs=15,
        subname='gen'
    ):
        '''
        Generate and save images

        Runs the generator on up to `max_imgs` files from `image_dir` and
        writes each input next to its output as `{subname}_{i}.jpg` in the
        experiment's generated-images directory. The generator's train/eval
        mode is restored afterwards.

        Args:
            image_dir: Directory whose files are used as inputs.
            max_imgs: Stop after this many images.
            subname: Filename prefix of the saved previews.
        '''
        start = time.time()
        fake_imgs = []
        real_imgs = []
        # Sorted so preview index i maps to the same input on every call.
        image_files = sorted(glob(os.path.join(image_dir, "*")))

        # Inference must not leave the generator stuck in eval mode, because
        # this method is called in the middle of training.
        was_training = self.G.training
        self.G.eval()
        try:
            for i, image_file in enumerate(image_files):
                image = read_image(image_file)
                image = resize_image(image)
                real_imgs.append(image.copy())
                image = preprocess_images(image)
                image = image.to(self.device)
                with torch.no_grad():
                    with autocast(enabled=self.cfg.amp):
                        fake_img = self.G(image)

                    fake_img = fake_img.detach().cpu().numpy()

                    fake_img = fake_img.transpose(0, 2, 3, 1)
                    fake_imgs.append(denormalize_input(fake_img, dtype=np.int16)[0])

                if i + 1 == max_imgs:
                    break
        finally:
            self.G.train(was_training)

        for i, (real_img, fake_img) in enumerate(zip(real_imgs, fake_imgs)):
            # Input on the left, generated output on the right.
            img = np.concatenate((real_img, fake_img), axis=1)
            save_path = os.path.join(self.save_image_dir, f'{subname}_{i}.jpg')
            if not cv2.imwrite(save_path, img[..., ::-1]):
                self.logger.info(f"Save generated image failed, {save_path}, {img.shape}")
        elapsed = time.time() - start
        self.logger.info(f"Generated {len(fake_imgs)} images in {elapsed:.3f}s.")

    def copy_results(self, epoch):
        """Copy result (Weight + Generated images) to each epoch folder
        Every N epoch

        Args:
            epoch: Epoch index used to name the `epoch_{epoch}` snapshot folder.
        """
        copy_dir = os.path.join(self.cfg.exp_dir, f"epoch_{epoch}")
        os.makedirs(copy_dir, exist_ok=True)

        shutil.copy2(
            self.checkpoint_path_G,
            copy_dir
        )

        dest = os.path.join(copy_dir, os.path.basename(self.save_image_dir))
        shutil.copytree(
            self.save_image_dir,
            dest,
            dirs_exist_ok=True
        )
|
|