ohayonguy committed
Commit 1b8b226
Parent(s): 2ef4159
first commit
Files changed:
- app.py (+170, -0)
- arch/__init__.py (+2, -0)
- lightning_models/mmse_rectified_flow.py (+317, -0)
app.py ADDED
@@ -0,0 +1,170 @@
import os

import cv2
import gradio as gr
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from realesrgan.utils import RealESRGANer
import spaces

from lightning_models.mmse_rectified_flow import MMSERectifiedFlow

torch.set_grad_enabled(False)

if os.getenv('SPACES_ZERO_GPU') == "true":
    os.environ['SPACES_ZERO_GPU'] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not os.path.exists('pretrained_models'):
    os.makedirs('pretrained_models')
realesr_model_path = 'pretrained_models/RealESRGAN_x4plus.pth'
if not os.path.exists(realesr_model_path):
    # Download the compact realesr-general-x4v3 weights to realesr_model_path.
    # (The original command wrote to experiments/pretrained_models/, a directory
    # that is never created here; the output path now matches realesr_model_path.)
    os.system(
        "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
        f" -O {realesr_model_path}")

pmrf_model_path = 'blind_face_restoration_pmrf.ckpt'

# Background enhancer with RealESRGAN.
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
half = True if torch.cuda.is_available() else False
upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)

pmrf = MMSERectifiedFlow.load_from_checkpoint(pmrf_model_path,
                                              mmse_model_arch='swinir_L',
                                              mmse_model_ckpt_path=None,
                                              map_location='cpu').to(device)

os.makedirs('output', exist_ok=True)


@torch.inference_mode()
@spaces.GPU()
def enhance_face(img, face_helper, has_aligned, only_center_face=False, paste_back=True, scale=2):
    face_helper.clean_all()

    if has_aligned:  # the input is already an aligned face crop
        img = cv2.resize(img, (512, 512))
        face_helper.cropped_faces = [img]
    else:
        face_helper.read_image(img)
        face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
        # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels.
        # TODO: even with eye_dist_threshold, wrong detections and restorations can still occur.
        # Align and warp each face.
        face_helper.align_warp_face()

    # Face restoration.
    for cropped_face in face_helper.cropped_faces:
        # Prepare data.
        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

        try:
            dummy_x = torch.zeros_like(cropped_face_t)
            # generate_reconstructions returns (xhat, x_t_seq, source_dist_samples);
            # only the reconstruction xhat is needed here.
            output, _, _ = pmrf.generate_reconstructions(dummy_x, cropped_face_t, None, 25, device)
            restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(0, 1))
        except RuntimeError as error:
            print(f'\tFailed inference for PMRF: {error}.')
            restored_face = cropped_face

        restored_face = restored_face.astype('uint8')
        face_helper.add_restored_face(restored_face)

    if not has_aligned and paste_back:
        # Upsample the background.
        if upsampler is not None:
            # Only RealESRGAN is currently supported for background upsampling.
            bg_img = upsampler.enhance(img, outscale=scale)[0]
        else:
            bg_img = None

        face_helper.get_inverse_affine(None)
        # Paste each restored face back into the input image.
        restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img)
        return face_helper.cropped_faces, face_helper.restored_faces, restored_img
    else:
        return face_helper.cropped_faces, face_helper.restored_faces, None


@torch.inference_mode()
@spaces.GPU()
def inference(img, aligned, scale):
    if scale > 4:
        scale = 4  # avoid too large a scale value
    try:
        extension = os.path.splitext(os.path.basename(str(img)))[1]
        img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        elif len(img.shape) == 2:  # grayscale input
            img_mode = None
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img_mode = None

        h, w = img.shape[0:2]
        if h > 3500 or w > 3500:
            print('Image size too large.')
            return None, None

        if h < 300:
            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        face_helper = FaceRestoreHelper(
            scale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=device,
            model_rootpath=None)

        try:
            has_aligned = aligned == 'aligned'
            _, restored_aligned, restored_img = enhance_face(img, face_helper, has_aligned,
                                                             only_center_face=False, paste_back=True)
            if has_aligned:
                output = restored_aligned[0]
            else:
                output = restored_img
        except RuntimeError as error:
            print('Error', error)
            # Without this early return, `output` would be undefined below.
            return None, None

        try:
            if scale != 2:
                interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
                h, w = img.shape[0:2]
                output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
        except Exception as error:
            print('Wrong scale input.', error)
        if img_mode == 'RGBA':  # RGBA images should be saved in PNG format
            extension = 'png'
        else:
            extension = 'jpg'
        save_path = f'output/out.{extension}'
        cv2.imwrite(save_path, output)

        output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        return output, save_path
    except Exception as error:
        print('Global exception:', error)
        return None, None


css = r"""
"""

demo = gr.Interface(
    inference, [
        gr.Image(type="filepath", label="Input"),
        gr.Radio(['aligned', 'unaligned'], type="value", value='unaligned', label='Image Alignment'),
        gr.Number(label="Rescaling factor", value=2),
    ], [
        gr.Image(type="numpy", label="Output (The whole image)"),
        gr.File(label="Download the output image")
    ],
    css=css,
)

if __name__ == '__main__':
    # Launch the demo when this file is executed directly.
    demo.launch()
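For a quick local smoke test without the web UI, `inference` can be called directly. This is a minimal sketch, assuming the checkpoints referenced above are already in place and the required packages (including `spaces`) are installed; `my_photo.jpg` is a hypothetical input path, not a file from this commit:

# run_local.py -- hypothetical smoke test, not part of this commit.
from app import inference  # importing app.py loads all models at module level

restored_rgb, saved_path = inference('my_photo.jpg', aligned='unaligned', scale=2)
if restored_rgb is not None:
    print(f'Saved restored image to {saved_path}, shape={restored_rgb.shape}')
else:
    print('Restoration failed; see the printed errors above.')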
arch/__init__.py ADDED
@@ -0,0 +1,2 @@

from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
from arch.swinir.swinir import SwinIR
lightning_models/mmse_rectified_flow.py ADDED
@@ -0,0 +1,317 @@
import os
from contextlib import contextmanager, nullcontext

import torch
import wandb
from pytorch_lightning import LightningModule
from torch.nn.functional import mse_loss
from torch.nn.functional import sigmoid
from torch.optim import AdamW
from torch_ema import ExponentialMovingAverage as EMA
from torchmetrics.image import FrechetInceptionDistance, InceptionScore
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import save_image

from utils.create_arch import create_arch
from utils.img_utils import create_grid
from huggingface_hub import PyTorchModelHubMixin


class MMSERectifiedFlow(LightningModule,
                        PyTorchModelHubMixin,
                        pipeline_tag="image-to-image",
                        license="mit",
                        ):
    def __init__(self,
                 stage,
                 arch,
                 conditional=False,
                 mmse_model_ckpt_path=None,
                 mmse_model_arch=None,
                 lr=5e-4,
                 weight_decay=1e-3,
                 betas=(0.9, 0.95),
                 mmse_noise_std=0.1,
                 num_flow_steps=50,
                 ema_decay=0.9999,
                 eps=0.0,
                 t_schedule='stratified_uniform',
                 *args,
                 **kwargs
                 ):
        super().__init__()
        self.save_hyperparameters(logger=False)

        if stage == 'flow':
            if conditional:
                condition_channels = 3
            else:
                condition_channels = 0
            if mmse_model_arch is None and 'colorization' in kwargs and kwargs['colorization']:
                condition_channels //= 3  # single-channel (grayscale) condition for colorization
            self.model = create_arch(arch, condition_channels)
            self.mmse_model = create_arch(mmse_model_arch, 0) if mmse_model_arch is not None else None
            if mmse_model_ckpt_path is not None:
                ckpt = torch.load(mmse_model_ckpt_path, map_location="cpu")
                if mmse_model_arch is None:
                    mmse_model_arch = ckpt['hyper_parameters']['arch']
                    self.mmse_model = create_arch(mmse_model_arch, 0)
                if 'ema' in ckpt:
                    # ema_decay doesn't affect anything here, because we are doing load_state_dict
                    mmse_ema = EMA(self.mmse_model.parameters(), decay=ema_decay)
                    mmse_ema.load_state_dict(ckpt['ema'])
                    mmse_ema.copy_to()
                elif 'params_ema' in ckpt:
                    self.mmse_model.load_state_dict(ckpt['params_ema'])
                else:
                    state_dict = ckpt['state_dict']
                    state_dict = {layer_name.replace('model.', ''): weights for layer_name, weights in
                                  state_dict.items()}
                    state_dict = {layer_name.replace('module.', ''): weights for layer_name, weights in
                                  state_dict.items()}
                    self.mmse_model.load_state_dict(state_dict)
                for param in self.mmse_model.parameters():
                    param.requires_grad = False
                self.mmse_model.eval()
        else:
            assert stage == 'mmse' or stage == 'naive_flow'
            assert not conditional
            self.model = create_arch(arch, 0)
            self.mmse_model = None
        if 'flow' in stage:
            self.fid = FrechetInceptionDistance(reset_real_features=True, normalize=True)
            self.inception_score = InceptionScore(normalize=True)

        self.ema = EMA(self.model.parameters(), decay=ema_decay) if self.ema_wanted else None
        self.test_results_path = None
        # Like test_results_path, this must be set externally before testing.
        self.num_test_flow_steps = None

    @property
    def ema_wanted(self):
        return self.hparams.ema_decay != -1

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        if self.ema_wanted:
            checkpoint['ema'] = self.ema.state_dict()
        return super().on_save_checkpoint(checkpoint)

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        if self.ema_wanted:
            self.ema.load_state_dict(checkpoint['ema'])
        return super().on_load_checkpoint(checkpoint)

    def on_before_zero_grad(self, optimizer) -> None:
        if self.ema_wanted:
            self.ema.update(self.model.parameters())
        return super().on_before_zero_grad(optimizer)

    def to(self, *args, **kwargs):
        if self.ema_wanted:
            self.ema.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    # Uses EMA's context manager to copy the EMA weights to the flow model during
    # validation, and then restore the training weights afterwards.
    @contextmanager
    def maybe_ema(self):
        ema = self.ema
        ctx = nullcontext if ema is None else ema.average_parameters
        with ctx():  # enter the context so the EMA weights are actually applied
            yield

    def forward_mmse(self, y):
        return self.model(y).clip(0, 1)

    def forward_flow(self, x_t, t, y=None):
        if self.hparams.conditional:
            if self.mmse_model is not None:
                with torch.no_grad():
                    self.mmse_model.eval()
                    condition = self.mmse_model(y).clip(0, 1)
            else:
                condition = y
            x_t = torch.cat((x_t, condition), dim=1)
        return self.model(x_t, t)

    def forward(self, x_t, t, y):
        if 'flow' in self.hparams.stage:
            return self.forward_flow(x_t, t, y)
        else:
            return self.forward_mmse(y)

    @torch.no_grad()
    def create_source_distribution_samples(self, x, y, non_noisy_z0):
        with torch.no_grad():
            if self.hparams.conditional:
                source_dist_samples = torch.randn_like(x)
            else:
                if self.hparams.stage == 'flow':
                    if non_noisy_z0 is None:
                        self.mmse_model.eval()
                        non_noisy_z0 = self.mmse_model(y).clip(0, 1)
                    source_dist_samples = non_noisy_z0 + torch.randn_like(non_noisy_z0) * self.hparams.mmse_noise_std
                else:
                    assert self.hparams.stage == 'naive_flow'
                    if non_noisy_z0 is not None:
                        source_dist_samples = non_noisy_z0
                    else:
                        source_dist_samples = y
                    if source_dist_samples.shape[1] != x.shape[1]:
                        assert source_dist_samples.shape[1] == 1  # colorization
                        source_dist_samples = source_dist_samples.expand(-1, x.shape[1], -1, -1)
                    if self.hparams.mmse_noise_std is not None:
                        source_dist_samples = source_dist_samples + torch.randn_like(
                            source_dist_samples) * self.hparams.mmse_noise_std
        return source_dist_samples

    @staticmethod
    def stratified_uniform(bs, group=0, groups=1, dtype=None, device=None):
        if groups <= 0:
            raise ValueError(f"groups must be positive, got {groups}")
        if group < 0 or group >= groups:
            raise ValueError(f"group must be in [0, {groups})")
        n = bs * groups
        offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
        u = torch.rand(bs, dtype=dtype, device=device)
        return ((offsets + u) / n).view(bs, 1, 1, 1)

    def generate_random_t(self, bs, dtype=None):
        if self.hparams.t_schedule == 'logit-normal':
            return sigmoid(torch.randn(bs, 1, 1, 1, device=self.device)) * (1.0 - self.hparams.eps) + self.hparams.eps
        elif self.hparams.t_schedule == 'uniform':
            return torch.rand(bs, 1, 1, 1, device=self.device) * (1.0 - self.hparams.eps) + self.hparams.eps
        elif self.hparams.t_schedule == 'stratified_uniform':
            return self.stratified_uniform(bs, self.trainer.global_rank, self.trainer.world_size, dtype=dtype,
                                           device=self.device) * (1.0 - self.hparams.eps) + self.hparams.eps
        else:
            raise NotImplementedError()

    def training_step(self, batch, batch_idx):
        x = batch['x']
        y = batch['y']
        non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
        if 'flow' in self.hparams.stage:
            with torch.no_grad():
                t = self.generate_random_t(x.shape[0], dtype=x.dtype)
                source_dist_samples = self.create_source_distribution_samples(x, y, non_noisy_z0)
                x_t = t * x + (1.0 - t) * source_dist_samples
            v_t = self(x_t, t.squeeze(), y)
            # The flow model regresses the velocity of the straight path from
            # the source sample to x, i.e., x - source_dist_samples.
            loss = mse_loss(v_t, x - source_dist_samples)
        else:
            xhat = self(x_t=None, t=None, y=y)
            loss = mse_loss(xhat, x)
        self.log("train/loss", loss)
        return loss

    @torch.no_grad()
    def generate_reconstructions(self, x, y, non_noisy_z0, num_flow_steps, result_device):
        with self.maybe_ema():
            if 'flow' in self.hparams.stage:
                source_dist_samples = self.create_source_distribution_samples(x, y, non_noisy_z0)

                # Euler integration of dx/dt = v(x_t, t) from t=eps to t=1.
                dt = (1.0 / num_flow_steps) * (1.0 - self.hparams.eps)
                x_t_next = source_dist_samples.clone()
                x_t_seq = [x_t_next]
                t_one = torch.ones(x.shape[0], device=self.device)
                for i in range(num_flow_steps):
                    num_t = (i / num_flow_steps) * (1.0 - self.hparams.eps) + self.hparams.eps
                    v_t_next = self(x_t=x_t_next, t=t_one * num_t, y=y).to(x_t_next.dtype)
                    x_t_next = x_t_next.clone() + v_t_next * dt
                    x_t_seq.append(x_t_next.to(result_device))

                xhat = x_t_seq[-1].clip(0, 1).to(torch.float32)
                source_dist_samples = source_dist_samples.to(result_device)
            else:
                xhat = self(x_t=None, t=None, y=y).to(torch.float32)
                x_t_seq = None
                source_dist_samples = None
            return xhat.to(result_device), x_t_seq, source_dist_samples

    def validation_step(self, batch, batch_idx):
        x = batch['x']
        y = batch['y']
        non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
        xhat, x_t_seq, source_dist_samples = self.generate_reconstructions(x, y, non_noisy_z0,
                                                                           self.hparams.num_flow_steps,
                                                                           self.device)
        x = x.to(torch.float32)
        y = y.to(torch.float32)
        self.log_dict({"val_metrics/mse": ((x - xhat) ** 2).mean()}, on_step=False, on_epoch=True, sync_dist=True,
                      batch_size=x.shape[0])

        if 'flow' in self.hparams.stage:
            self.fid.update(x, real=True)
            self.fid.update(xhat, real=False)
            self.inception_score.update(xhat)

        if batch_idx == 0:
            wandb_logger = self.logger.experiment
            wandb_logger.log({'val_images/x': [wandb.Image(to_pil_image(create_grid(x)))],
                              'val_images/y': [wandb.Image(to_pil_image(create_grid(y.clip(0, 1))))],
                              'val_images/xhat': [wandb.Image(to_pil_image(create_grid(xhat)))]})
            if 'flow' in self.hparams.stage:
                wandb_logger.log({'val_images/x_t_seq': [wandb.Image(to_pil_image(create_grid(
                    torch.cat([elem[0].unsqueeze(0).to(torch.float32) for elem in x_t_seq], dim=0).clip(0, 1),
                    num_images=len(x_t_seq))))], 'val_images/source_distribution_samples': [
                    wandb.Image(to_pil_image(create_grid(source_dist_samples.clip(0, 1).to(torch.float32))))]})
                if self.mmse_model is not None:
                    xhat_mmse = self.mmse_model(y).clip(0, 1)
                    wandb_logger.log({'val_images/xhat_mmse': [
                        wandb.Image(to_pil_image(create_grid(xhat_mmse.to(torch.float32))))]})

    def on_validation_epoch_end(self):
        if 'flow' in self.hparams.stage:
            inception_score_mean, inception_score_std = self.inception_score.compute()
            self.log_dict(
                {'val_metrics/fid': self.fid.compute(),
                 'val_metrics/inception_score_mean': inception_score_mean,
                 'val_metrics/inception_score_std': inception_score_std},
                on_epoch=True, on_step=False, sync_dist=True,
                batch_size=1)
            self.fid.reset()
            self.inception_score.reset()

    def test_step(self, batch, batch_idx):
        assert self.test_results_path is not None, "Please set test_results_path before testing."
        assert os.path.isdir(self.test_results_path), "Please make sure the test_results_path dir exists."

        def save_image_batch(images, folder, image_file_names):
            os.makedirs(folder, exist_ok=True)
            for img, file_name in zip(images, image_file_names):
                save_image(img.clip(0, 1), os.path.join(folder, file_name))

        os.makedirs(self.test_results_path, exist_ok=True)
        x = batch['x']
        y = batch['y']
        non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
        y_path = os.path.join(self.test_results_path, 'y')
        save_image_batch(y, y_path, batch['img_file_name'])

        if 'flow' in self.hparams.stage:
            source_dist_samples_to_save = None

            for num_flow_steps in self.num_test_flow_steps:
                xhat, x_t_seq, source_dist_samples = self.generate_reconstructions(x, y, non_noisy_z0, num_flow_steps,
                                                                                   torch.device("cpu"))
                xhat_path = os.path.join(self.test_results_path, f"num_flow_steps={num_flow_steps}", 'xhat')
                save_image_batch(xhat, xhat_path, batch['img_file_name'])
                if source_dist_samples_to_save is None:
                    source_dist_samples_to_save = source_dist_samples

            source_distribution_samples_path = os.path.join(self.test_results_path, 'source_distribution_samples')
            save_image_batch(source_dist_samples_to_save, source_distribution_samples_path, batch['img_file_name'])
            if self.mmse_model is not None:
                mmse_estimates = self.mmse_model(y).clip(0, 1)
                mmse_samples_path = os.path.join(self.test_results_path, 'mmse_samples')
                save_image_batch(mmse_estimates, mmse_samples_path, batch['img_file_name'])
        else:
            xhat, _, _ = self.generate_reconstructions(x, y, non_noisy_z0, None, torch.device('cpu'))
            xhat_path = os.path.join(self.test_results_path, 'xhat')
            save_image_batch(xhat, xhat_path, batch['img_file_name'])

    def configure_optimizers(self):
        # Add a learning rate scheduler here if you wish to do so.
        optimizer = AdamW(self.model.parameters(),
                          betas=self.hparams.betas,
                          eps=1e-8,
                          lr=self.hparams.lr,
                          weight_decay=self.hparams.weight_decay)
        return optimizer
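In the flow stage, the network is trained to regress the straight-path velocity: given a source sample z (here the MMSE estimate plus Gaussian noise) and a clean image x, the path is x_t = t*x + (1 - t)*z, the regression target is v = x - z, and sampling integrates dx/dt = v(x_t, t) from t = eps to t = 1. The following self-contained sketch reproduces the Euler loop from generate_reconstructions with a toy velocity network; ToyVelocity and the toy tensor shapes are illustrative assumptions, not code from this repo:

import torch

class ToyVelocity(torch.nn.Module):
    # Stand-in for the real velocity model (e.g., ImageTransformerDenoiserModelV2);
    # illustrative only, not part of this repository.
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x_t, t):
        # Broadcast the per-sample time over spatial dims and condition additively.
        return self.net(x_t + t.view(-1, 1, 1, 1))

@torch.no_grad()
def euler_sample(model, z, num_flow_steps=25, eps=0.0):
    # Mirrors generate_reconstructions: x_{t+dt} = x_t + v(x_t, t) * dt.
    dt = (1.0 / num_flow_steps) * (1.0 - eps)
    x_t = z.clone()
    t_one = torch.ones(z.shape[0])
    for i in range(num_flow_steps):
        num_t = (i / num_flow_steps) * (1.0 - eps) + eps
        x_t = x_t + model(x_t, t_one * num_t) * dt
    return x_t.clip(0, 1)

z = 0.5 + 0.1 * torch.randn(2, 3, 64, 64)  # stand-in for an MMSE output plus noise
print(euler_sample(ToyVelocity(), z).shape)  # torch.Size([2, 3, 64, 64])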