Upload SUPIR_model.py
Browse files- SUPIR/models/SUPIR_model.py +195 -0
 
    	
        SUPIR/models/SUPIR_model.py
    ADDED
    
    | 
         @@ -0,0 +1,195 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            from sgm.models.diffusion import DiffusionEngine
         
     | 
| 3 | 
         
            +
            from sgm.util import instantiate_from_config
         
     | 
| 4 | 
         
            +
            import copy
         
     | 
| 5 | 
         
            +
            from sgm.modules.distributions.distributions import DiagonalGaussianDistribution
         
     | 
| 6 | 
         
            +
            import random
         
     | 
| 7 | 
         
            +
            from SUPIR.utils.colorfix import wavelet_reconstruction, adaptive_instance_normalization
         
     | 
| 8 | 
         
            +
            from pytorch_lightning import seed_everything
         
     | 
| 9 | 
         
            +
            from torch.nn.functional import interpolate
         
     | 
| 10 | 
         
            +
            from SUPIR.utils.tilevae import VAEHook
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            class SUPIRModel(DiffusionEngine):
                '''
                DiffusionEngine subclass that adds a control model, a second ("denoise")
                VAE encoder and a batch restoration/sampling entry point.

                NOTE(review): `self.model`, `self.first_stage_model`, `self.scale_factor`,
                `self.denoiser` and `self.conditioner` are not defined in this file; they are
                presumably provided by `DiffusionEngine` -- confirm against the base class.
                '''

                def __init__(self, control_stage_config, ae_dtype='fp32', diffusion_dtype='fp32', p_p='', n_p='', *args, **kwargs):
                    '''
                    Build the base engine, attach the control model and set up dtypes.

                    control_stage_config: config passed to `instantiate_from_config` to build
                        the control model, which is then handed to `self.model.load_control_model`.
                    ae_dtype: 'fp32' or 'bf16' for the autoencoder autocast dtype
                        ('fp16' is rejected with RuntimeError).
                    diffusion_dtype: 'fp32', 'fp16' or 'bf16'; stored on `self.model.dtype`.
                    p_p / n_p: default strings used by `batchify_sample` when its own
                        `p_p` / `n_p` arguments are left at 'default'.
                    kwargs: forwarded to the base class; must contain 'sampler_config'.
                    '''
                    super().__init__(*args, **kwargs)
                    control_model = instantiate_from_config(control_stage_config)
                    self.model.load_control_model(control_model)
                    # Independent copy of the encoder; it is called from encode_first_stage_with_denoise.
                    self.first_stage_model.denoise_encoder = copy.deepcopy(self.first_stage_model.encoder)
                    self.sampler_config = kwargs['sampler_config']

                    dtype_by_name = {
                        'fp32': torch.float32,
                        'fp16': torch.float16,
                        'bf16': torch.bfloat16,
                    }
                    valid_names = tuple(dtype_by_name)
                    assert (ae_dtype in valid_names) and (diffusion_dtype in valid_names), (
                        f'ae_dtype and diffusion_dtype must each be one of {valid_names}, '
                        f'got {ae_dtype!r} and {diffusion_dtype!r}'
                    )
                    if ae_dtype == 'fp16':
                        raise RuntimeError('fp16 cause NaN in AE')

                    self.ae_dtype = dtype_by_name[ae_dtype]
                    self.model.dtype = dtype_by_name[diffusion_dtype]

                    self.p_p = p_p
                    self.n_p = n_p

                @torch.no_grad()
                def encode_first_stage(self, x):
                    '''
                    Encode `x` with `self.first_stage_model.encode` under CUDA autocast
                    (`self.ae_dtype`) and multiply the result by `self.scale_factor`.
                    '''
                    with torch.autocast("cuda", dtype=self.ae_dtype):
                        z = self.first_stage_model.encode(x)
                    return self.scale_factor * z

                @torch.no_grad()
                def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
                    '''
                    Encode `x` through the denoise encoder instead of the regular encoder.

                    use_sample: if True draw `posterior.sample()`, otherwise use `posterior.mode()`.
                    is_stage1: if True use `first_stage_model.denoise_encoder_s1`. That attribute
                        is not created in this class, so it must be attached by the caller
                        beforehand; otherwise `first_stage_model.denoise_encoder` is used.

                    Returns the latent multiplied by `self.scale_factor`.
                    '''
                    vae = self.first_stage_model
                    with torch.autocast("cuda", dtype=self.ae_dtype):
                        encoder = vae.denoise_encoder_s1 if is_stage1 else vae.denoise_encoder
                        moments = vae.quant_conv(encoder(x))
                        posterior = DiagonalGaussianDistribution(moments)
                        z = posterior.sample() if use_sample else posterior.mode()
                    return self.scale_factor * z

                @torch.no_grad()
                def decode_first_stage(self, z):
                    '''
                    Undo the `self.scale_factor` scaling, decode with
                    `self.first_stage_model.decode` under CUDA autocast, and return the
                    result converted with `.float()`.
                    '''
                    z = 1.0 / self.scale_factor * z
                    with torch.autocast("cuda", dtype=self.ae_dtype):
                        out = self.first_stage_model.decode(z)
                    return out.float()

                @torch.no_grad()
                def batchify_denoise(self, x, is_stage1=False):
                    '''
                    [N, C, H, W], [-1, 1], RGB

                    Round-trip `x` through the denoise encoder (posterior mode, no sampling)
                    and the decoder, returning the decoded result.
                    '''
                    latent = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
                    return self.decode_first_stage(latent)

                @torch.no_grad()
                def batchify_sample(self, x, p, p_p='default', n_p='default', num_steps=100, restoration_scale=4.0, s_churn=0, s_noise=1.003, cfg_scale=4.0, seed=-1,
                                    num_samples=1, control_scale=1, color_fix_type='None', use_linear_CFG=False, use_linear_control_scale=False,
                                    cfg_scale_start=1.0, control_scale_start=0.0, **kwargs):
                    '''
                    [N, C, H, W], [-1, 1], RGB
                    (NOTE(review): the 4-D layout is inferred from `x.repeat(N, 1, 1, 1)` below
                    and from the (1, 3, 512, 512) example at the bottom of this file -- confirm.)

                    Run the full sampling pipeline on a batch and return the decoded samples.

                    x: input image batch.
                    p: prompts, one entry per image; an entry may itself be a list of prompts
                        (see `prepare_condition`).
                    p_p / n_p: strings passed to `prepare_condition`; 'default' falls back to
                        `self.p_p` / `self.n_p`.
                    num_steps, restoration_scale, s_churn, s_noise: written into
                        `self.sampler_config.params` (as `num_steps`, `restore_cfg`,
                        `s_churn`, `s_noise`) before the sampler is instantiated.
                    cfg_scale / cfg_scale_start / use_linear_CFG: the guider's `scale_min` is
                        always `cfg_scale`; its `scale` is `cfg_scale_start` when
                        `use_linear_CFG` is set and `cfg_scale` otherwise.
                    seed: -1 picks a random seed in [0, 65535]; the seed is applied with
                        `seed_everything`.
                    num_samples: if > 1 the batch must contain exactly one image, which is
                        repeated `num_samples` times together with its prompt.
                    control_scale, use_linear_control_scale, control_scale_start: forwarded
                        to the sampler call.
                    color_fix_type: 'Wavelet', 'AdaIn' or 'None'; selects the post-processing
                        applied against the stage-1 (denoise-encoded then decoded) image.
                    kwargs: forwarded to `self.denoiser` on every denoiser call.

                    Side effects: mutates `self.sampler_config` in place and replaces
                    `self.sampler`.
                    '''
                    assert len(x) == len(p), f'got {len(x)} images but {len(p)} prompts'
                    assert color_fix_type in ['Wavelet', 'AdaIn', 'None'], (
                        f'unsupported color_fix_type: {color_fix_type!r}'
                    )

                    N = len(x)
                    if num_samples > 1:
                        assert N == 1, 'num_samples > 1 requires a batch containing exactly one image'
                        N = num_samples
                        x = x.repeat(N, 1, 1, 1)
                        p = p * N

                    if p_p == 'default':
                        p_p = self.p_p
                    if n_p == 'default':
                        n_p = self.n_p

                    # Push the per-call settings into the stored config, then rebuild the sampler.
                    sampler_params = self.sampler_config.params
                    sampler_params.num_steps = num_steps
                    guider_params = sampler_params.guider_config.params
                    guider_params.scale_min = cfg_scale
                    guider_params.scale = cfg_scale_start if use_linear_CFG else cfg_scale
                    sampler_params.restore_cfg = restoration_scale
                    sampler_params.s_churn = s_churn
                    sampler_params.s_noise = s_noise
                    self.sampler = instantiate_from_config(self.sampler_config)

                    if seed == -1:
                        seed = random.randint(0, 65535)
                    seed_everything(seed)

                    # Stage 1: denoise-encode -> decode -> re-encode with the regular encoder.
                    _z = self.encode_first_stage_with_denoise(x, use_sample=False)
                    x_stage1 = self.decode_first_stage(_z)
                    z_stage1 = self.encode_first_stage(x_stage1)

                    c, uc = self.prepare_condition(_z, p, p_p, n_p, N)

                    # Parameter names are kept as in the original callable in case the
                    # sampler passes any of them by keyword.
                    denoiser = lambda input, sigma, c, control_scale: self.denoiser(
                        self.model, input, sigma, c, control_scale, **kwargs
                    )

                    # randn_like already allocates on _z's device with _z's dtype.
                    noised_z = torch.randn_like(_z)

                    _samples = self.sampler(denoiser, noised_z, cond=c, uc=uc, x_center=z_stage1, control_scale=control_scale,
                                            use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start)
                    samples = self.decode_first_stage(_samples)
                    if color_fix_type == 'Wavelet':
                        samples = wavelet_reconstruction(samples, x_stage1)
                    elif color_fix_type == 'AdaIn':
                        samples = adaptive_instance_normalization(samples, x_stage1)
                    return samples

                def init_tile_vae(self, encoder_tile_size=512, decoder_tile_size=64):
                    '''
                    Replace the `forward` of the denoise encoder, the encoder and the decoder
                    with `VAEHook` instances, keeping the previous forward on each module as
                    `original_forward`.

                    encoder_tile_size: tile size passed to the hooks of both encoders.
                    decoder_tile_size: tile size passed to the decoder hook.

                    The method may be called again (e.g. with different tile sizes): the
                    un-hooked forward is saved only the first time, so `original_forward`
                    never ends up pointing at a previously installed hook.
                    '''
                    vae = self.first_stage_model
                    targets = (
                        (vae.denoise_encoder, encoder_tile_size, False),
                        (vae.encoder, encoder_tile_size, False),
                        (vae.decoder, decoder_tile_size, True),
                    )
                    # Save the real forwards first, and only if they have not been saved yet.
                    for net, _tile_size, _is_decoder in targets:
                        if not hasattr(net, 'original_forward'):
                            net.original_forward = net.forward
                    for net, tile_size, is_decoder in targets:
                        net.forward = VAEHook(
                            net, tile_size, is_decoder=is_decoder, fast_decoder=False,
                            fast_encoder=False, color_fix=False, to_gpu=True)

                def prepare_condition(self, _z, p, p_p, n_p, N):
                    '''
                    Build the conditional / unconditional inputs for the sampler.

                    _z: latent stored under the 'control' key; its device is used for the
                        size / crop / score tensors.
                    p: list of prompts. If `p[0]` is itself a list, it is treated as a list of
                        per-tile prompts and `p` must have length 1.
                    p_p: string appended to every prompt.
                    n_p: string used as the 'txt' entry of the unconditional batch.
                    N: number of rows the size / crop / score tensors are repeated to.

                    Returns `(c, uc)` as produced by
                    `self.conditioner.get_unconditional_conditioning`; in the per-tile case
                    `c` is a list with one conditioning per tile prompt.
                    '''
                    device = _z.device
                    batch = {
                        'original_size_as_tuple': torch.tensor([1024, 1024]).repeat(N, 1).to(device),
                        'crop_coords_top_left': torch.tensor([0, 0]).repeat(N, 1).to(device),
                        'target_size_as_tuple': torch.tensor([1024, 1024]).repeat(N, 1).to(device),
                        'aesthetic_score': torch.tensor([9.0]).repeat(N, 1).to(device),
                        'control': _z,
                    }

                    batch_uc = copy.deepcopy(batch)
                    batch_uc['txt'] = [n_p for _ in p]

                    if not isinstance(p[0], list):
                        batch['txt'] = [''.join([prompt, p_p]) for prompt in p]
                        with torch.autocast("cuda", dtype=self.ae_dtype):
                            c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
                    else:
                        assert len(p) == 1, 'Support bs=1 only for local prompt conditioning.'
                        p_tiles = p[0]
                        c = []
                        for i, p_tile in enumerate(p_tiles):
                            batch['txt'] = [''.join([p_tile, p_p])]
                            with torch.autocast("cuda", dtype=self.ae_dtype):
                                if i == 0:
                                    # The unconditional branch is only computed for the first tile.
                                    _c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
                                else:
                                    _c, _ = self.conditioner.get_unconditional_conditioning(batch, None)
                            c.append(_c)
                    return c, uc
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
            if __name__ == '__main__':
                # Manual smoke test: build the model from the paper config, load both
                # checkpoints non-strictly, and run one sampling pass on random input.
                from SUPIR.util import create_model, load_state_dict

                model = create_model('../../options/dev/SUPIR_paper_version.yaml')

                SDXL_CKPT = '/opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors'
                SUPIR_CKPT = '/opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-paper.ckpt'
                # Load order matters: SDXL first, then SUPIR on top of it.
                for ckpt_path in (SDXL_CKPT, SUPIR_CKPT):
                    model.load_state_dict(load_state_dict(ckpt_path), strict=False)
                model = model.cuda()

                x = torch.randn(1, 3, 512, 512).cuda()
                p = ['a professional, detailed, high-quality photo']
                samples = model.batchify_sample(
                    x, p,
                    num_steps=50,
                    restoration_scale=4.0,
                    s_churn=0,
                    cfg_scale=4.0,
                    seed=-1,
                    num_samples=1,
                )
         
     |