Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import torch | |
| from ml_collections.config_flags import config_flags | |
| from sde.config import get_config | |
| from sde import ddpm, ncsnv2, ncsnpp # need to import to trigger its registry | |
| from sde import utils as mutils | |
| from sde.ema import ExponentialMovingAverage | |
| from adapt import ScoreAdapter | |
# Compute device for the adapter. Prefer the GPU when one is present, but fall
# back to the CPU so that model construction and checkpoint loading (which maps
# tensors onto this device via ``map_location``) still work on CPU-only hosts
# instead of failing with a CUDA error.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def restore_checkpoint(ckpt_dir, state, device):
    """Load a checkpoint file into the objects held by *state*, in place.

    Args:
        ckpt_dir: Path to the checkpoint *file* written by ``save_checkpoint``.
            The name is historical; the value is passed straight to
            ``torch.load``.
        state: Dict holding at least ``'model'`` and ``'ema'`` objects that
            expose ``load_state_dict``. An ``'optimizer'`` entry is optional.
            It is restored only when the checkpoint also contains optimizer
            state.
        device: ``map_location`` used when deserializing the saved tensors.

    Returns:
        The same *state* dict, with ``'step'`` replaced by the saved step.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    loaded_state = torch.load(ckpt_dir, map_location=device)
    # The optimizer is optional. Inference-only callers (e.g. VESDE) build a
    # state without one, so restore it only when both sides actually have it.
    if 'optimizer' in state and 'optimizer' in loaded_state:
        state['optimizer'].load_state_dict(loaded_state['optimizer'])
    # strict=False: parameter keys missing from, or unexpected in, the
    # checkpoint are ignored instead of raising.
    state['model'].load_state_dict(loaded_state['model'], strict=False)
    state['ema'].load_state_dict(loaded_state['ema'])
    state['step'] = loaded_state['step']
    return state
def save_checkpoint(ckpt_dir, state):
    """Serialize the training state in *state* to the file *ckpt_dir*.

    The optimizer, model and EMA objects are stored via their ``state_dict()``
    (in that order), followed by the plain ``'step'`` value. All four keys must
    be present in *state*; a missing one raises ``KeyError`` before anything
    is written.
    """
    payload = {
        component: state[component].state_dict()
        for component in ('optimizer', 'model', 'ema')
    }
    payload['step'] = state['step']
    torch.save(payload, ckpt_dir)
class VESDE(ScoreAdapter):
    """Score adapter around a pretrained score_sde network.

    Construction does the following:
    - builds the network from ``get_config()``;
    - restores weights from ``<checkpoint_root>/sde/checkpoint_127.pth``;
    - copies the EMA weights into the network;
    - switches the network to eval mode and keeps the unwrapped
      ``.module`` of the ``DataParallel`` container returned by
      ``mutils.create_model``.
    """

    def __init__(self):
        cfg = get_config()
        cfg.device = device
        checkpoint_path = self.checkpoint_root() / "sde" / 'checkpoint_127.pth'

        net = mutils.create_model(cfg)
        ema = ExponentialMovingAverage(net.parameters(), decay=cfg.model.ema_rate)

        # One sample is (channels, side, side): the config stores a single
        # square image size.
        self._data_shape = (
            cfg.data.num_channels,
            cfg.data.image_size,
            cfg.data.image_size,
        )
        # ``denoise`` conditions the network on σ/2, so the smallest σ this
        # adapter exposes is twice the model's own ``sigma_min``.
        self._σ_min = float(cfg.model.sigma_min * 2)

        # restore_checkpoint mutates the model/EMA in place; the returned dict
        # itself is not needed afterwards.
        restore_checkpoint(
            checkpoint_path, dict(model=net, ema=ema, step=0), device=cfg.device
        )
        ema.copy_to(net.parameters())
        net.eval()

        self.model = net.module  # drop the DataParallel wrapper
        self._device = device

    def data_shape(self):
        """Return the per-sample shape ``(channels, height, width)``."""
        return self._data_shape

    def σ_min(self):
        """Return the smallest noise level σ supported by this adapter.

        NOTE(review): this is a plain method, not a property. Confirm that
        this matches how ``ScoreAdapter`` declares ``σ_min`` and how callers
        read it.
        """
        return self._σ_min

    def denoise(self, xs, σ):
        """Return the denoised estimate ``xs + σ * model(xs, σ/2)``.

        The network is conditioned on σ/2 (see Karras eqn. 212-215 for the
        1/2 σ correction). Its forward function has been modified elsewhere;
        consult the model source for what it returns.
        NOTE(review): relies on ``ScoreAdapter`` exposing ``self.device``,
        presumably backed by ``self._device`` set in ``__init__``. Confirm.
        """
        half_sigma = (0.5 * σ) * torch.ones(xs.shape[0], device=self.device)
        return xs + σ * self.model(xs, half_sigma)

    def unet_is_cond(self):
        """Always ``False``: the network is called with only ``(xs, σ/2)``."""
        return False

    def use_cls_guidance(self):
        """Always ``False``: no classifier guidance is used by this adapter."""
        return False

    def snap_t_to_nearest_tick(self, t):
        """Defer unchanged to the parent class implementation."""
        return super().snap_t_to_nearest_tick(t)