import torch
import numpy as np
import sys
import os

import tqdm
import torchvision

from src.flair import degradations


def total_variation_loss(x):
    """
    Compute the total variation loss for a batch of images.

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W)

    Returns:
        torch.Tensor: Total variation loss
    """
    # Compute the differences between adjacent pixels
    diff_x = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])
    diff_y = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])
    # Sum the differences
    return torch.sum(diff_x) + torch.sum(diff_y)


class VariationalPosterior:
    def __init__(self, model, config):
        """
        Args:
            model (torch.nn.Module): model used for inference; must provide single_step, encode and decode methods
            config (dict): configuration
        """
        self.config = config
        self.model = model

        # Initialize the degradation (forward) operator from the config
        try:
            degradation = getattr(degradations, config["degradation"]["name"])
        except AttributeError:
            print(f"Degradation {config['degradation']['name']} not defined.")
            sys.exit(1)
        self.forward_operator = degradation(**config["degradation"]["kwargs"])

        if "optimized_reg_weight" in config and config["optimized_reg_weight"]:
            # Per-timestep regularizer weights loaded from file
            reg_weight = np.load(config["optimized_reg_weight"])
            if "reg_weight" in config["optimized_reg_weight"]:
                self.regularizer_weight = reg_weight * config["regularizer_weight"]
            else:
                # Invert, normalize so the weights sum to the number of timesteps,
                # then shift (or anchor to the last weight) and clip to be non-negative
                # reg_weight = reg_weight / np.nanmax(reg_weight)
                reg_weight = 1 / (reg_weight + 1e-7)
                reg_weight = reg_weight / np.nansum(reg_weight) * reg_weight.shape[0]
                if "reg-shift" in config:
                    self.regularizer_weight = reg_weight + config["reg-shift"]
                else:
                    self.regularizer_weight = reg_weight - reg_weight[-1]
                self.regularizer_weight = np.clip(self.regularizer_weight, 0, None) * config["regularizer_weight"]
            print("Loaded optimized regularizer weights.")
        else:
            self.regularizer_weight = config["regularizer_weight"]

    def set_degradation(self):
        try:
            degradation = getattr(degradations, self.config["degradation"]["name"])
        except AttributeError:
            print(f"Degradation {self.config['degradation']['name']} not defined.")
            sys.exit(1)
        self.forward_operator = degradation(**self.config["degradation"]["kwargs"])

    def data_term(self, latent_mu, y, optimizer_dataterm, likelihood_weight, likelihood_steps, early_stopping):
        """
        Performs data-term optimization over several steps with early stopping.
        """
        for k in range(likelihood_steps):
            with torch.enable_grad():
                data_loss = torch.nn.MSELoss(reduction='sum')(
                    self.forward_operator(self.model.decode(latent_mu), noise=False), y
                )
                loss = likelihood_weight * data_loss.sum()
            # Stop early once the per-element residual is small enough
            if data_loss < early_stopping * y.numel():
                if latent_mu.grad is not None:
                    latent_mu.grad = None
                del loss
                del data_loss
                break
            loss.backward()
            optimizer_dataterm.step()
            optimizer_dataterm.zero_grad()
            del loss
            del data_loss
        return

    @torch.no_grad()
    def projection(self, latent_mu, y, alpha=1):
        x_0 = self.model.decode(latent_mu)
        y_hat = self.forward_operator(x_0, noise=False)
        x_inv_hat = self.forward_operator.pseudo_inv(y_hat)
        projection = x_0 - x_inv_hat + self.forward_operator.pseudo_inv(y)
        latent_projection = self.model.encode(projection)
        # soft projection in latent space
        latent_projection = latent_mu - (latent_mu - latent_projection) * alpha
        return projection, latent_projection

    def find_closest_t(self, t):
        # Map a continuous timestep t to the nearest entry of the regularizer-weight
        # schedule, which runs from t = 1 down to t = 0
        ts = torch.linspace(1, 0.0, self.regularizer_weight.shape[0], device=t.device, dtype=t.dtype)
        return torch.argmin(torch.abs(ts - t))

    def forward(self, y, kwargs):
        """
        Uses a variational approach to infer the mode of the posterior distribution given a measurement y.
        Args:
            y (torch.Tensor): measurement tensor

        Returns:
            torch.Tensor: estimated mu
        """
        # Move all auxiliary tensors in kwargs to the measurement's device
        for key, value in kwargs.items():
            try:
                kwargs[key] = value.to(y.device)
            except AttributeError:
                pass
        return_dict = {}
        device = y.device

        if "init" in self.config and self.config["init"] == "random":
            # TODO: put this in model wrapper
            shape = (
                1,
                16,
                int(self.config["resolution"]) // self.model.vae_scale_factor,
                int(self.config["resolution"]) // self.model.vae_scale_factor,
            )
            latent_mu = torch.randn(shape, device=device, dtype=y.dtype)
        else:
            # Initialize from the pseudo-inverse of the measurement
            x_inv = self.forward_operator.pseudo_inv(y)
            latent_mu = self.model.encode(x_inv)

        latent_mu = latent_mu.detach().clone()
        latent_mu.requires_grad = True

        start_noise = torch.randn_like(latent_mu)
        optim_noise = start_noise.detach().clone()
        timesteps = self._get_timesteps(device)

        for epoch in range(self.config["epochs"]):
            optimizer, optimizer_dataterm = self._initialize_optimizers(latent_mu)
            for i, t in tqdm.tqdm(enumerate(timesteps), desc="Variational Optimization", total=len(timesteps)):
                t = torch.tensor([t], device=device, dtype=latent_mu.dtype)
                kwargs["noise"] = optim_noise.detach()
                kwargs["inv_alpha"] = self.config["inv_alpha"]
                eps_prediction, noise, a_t, sigma_t, v_pred = self.model.single_step(latent_mu, t, kwargs)

                # predict x1, which is the start noise vector for DTA
                optim_noise = a_t * latent_mu + sigma_t * noise + a_t * v_pred

                reg_term = self._compute_regularization_term(eps_prediction, noise, a_t, sigma_t, t, latent_mu, v_pred)

                if self.config["likelihood_weight_mode"] == "reg_weight":
                    reg_idx = self.find_closest_t(t)
                    likelihood_weight = self.regularizer_weight[reg_idx] * self.config["likelihood_weight"]
                else:
                    likelihood_weight = self.config["likelihood_weight"]

                # Regularization step: the detached regularization direction acts as a
                # fixed gradient on latent_mu through this surrogate inner product
                with torch.enable_grad():
                    reg_term = (reg_term.detach() * latent_mu.view(reg_term.shape[0], -1)).sum()
                    reg_term.backward()
                optimizer.step()
                optimizer.zero_grad()

                # Optional soft projection toward measurement consistency,
                # applied only at high-noise timesteps (t > 0.7)
                if self.config["projection"] and t > 0.7:
                    with torch.enable_grad():
                        _, latent_mu_projection = self.projection(latent_mu, y)
                        proj_loss = (latent_mu - latent_mu_projection).detach() * latent_mu
                        proj_loss = proj_loss.sum()
                        optimizer_dataterm.zero_grad()
                        proj_loss.backward()
                        optimizer_dataterm.step()
                        optimizer_dataterm.zero_grad()

                self.data_term(
                    latent_mu,
                    y.detach(),
                    optimizer_dataterm,
                    likelihood_weight,
                    self.config["likelihood_steps"],
                    self.config["early_stopping"],
                )
                # self.save_intermediate_results(latent_mu, i)

        x_hat = self.model.decode(latent_mu)
        return_dict.update({"x_hat": x_hat})
        return return_dict

    def _get_timesteps(self, device):
        timesteps = self.model.get_timesteps(self.config["n_steps"], device=device, ts_min=self.config["ts_min"])
        if self.config["t_sampling"] == "descending":
            return timesteps
        elif self.config["t_sampling"] == "ascending":
            return timesteps.flip(0)
        elif self.config["t_sampling"] == "random":
            # Shuffle indices must stay integer (randperm's default int64) to index the timesteps
            idx = torch.randperm(len(timesteps), device=device)
            return timesteps[idx]
        else:
            raise ValueError(f't_sampling {self.config["t_sampling"]} not supported.')

    def _initialize_optimizers(self, latent_mu):
        params = [latent_mu]
        params2 = [latent_mu]
        optimizer = self._get_optimizer(self.config["optimizer"], params)
        optimizer_dataterm = self._get_optimizer(self.config["optimizer_dataterm"], params2)
        if "scheduler" in self.config:
            self.scheduler = self._get_scheduler(self.config["scheduler"], optimizer)
        if "scheduler_dataterm" in self.config:
            self.scheduler_dataterm = self._get_scheduler(self.config["scheduler_dataterm"], optimizer_dataterm)
        return optimizer, optimizer_dataterm
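
    # Illustrative sketch of the optimizer / scheduler configuration consumed by the two
    # factory methods below. The "name"/"kwargs" layout and the supported names follow the
    # code; the concrete hyperparameter values are hypothetical placeholders.
    #
    #   "optimizer":          {"name": "Adam", "kwargs": {"lr": 1e-1}},
    #   "optimizer_dataterm": {"name": "SGD",  "kwargs": {"lr": 1e-2, "momentum": 0.9}},
    #   "scheduler":          {"name": "CosineAnnealingLR", "kwargs": {"T_max": 100}},
    #   "scheduler_dataterm": {"name": "StepLR", "kwargs": {"step_size": 10, "gamma": 0.5}},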
    def _get_optimizer(self, optimizer_config, params):
        if optimizer_config["name"] == "Adam":
            return torch.optim.Adam(params, **optimizer_config["kwargs"])
        elif optimizer_config["name"] == "SGD":
            return torch.optim.SGD(params, **optimizer_config["kwargs"])
        else:
            raise ValueError(f'optimizer {optimizer_config["name"]} not supported.')

    def _get_scheduler(self, scheduler_config, optimizer):
        if scheduler_config["name"] == "StepLR":
            return torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_config["kwargs"])
        elif scheduler_config["name"] == "CosineAnnealingLR":
            return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_config["kwargs"])
        elif scheduler_config["name"] == "LinearLR":
            return torch.optim.lr_scheduler.LinearLR(optimizer, **scheduler_config["kwargs"])
        # Add other schedulers as needed
        else:
            raise ValueError(f'scheduler {scheduler_config["name"]} not supported.')

    def _compute_regularization_term(self, eps_prediction, noise, a_t, sigma_t, t, latent_mu, v):
        reg_term = (eps_prediction - noise).reshape(eps_prediction.shape[0], -1)
        # reg_term = (latent_mu - (a_t * latent_mu + sigma_t * noise - t * v)).reshape(eps_prediction.shape[0], -1)
        # reg_term /= reg_term.norm() / 1000
        if self.config["lambda_func"] == "sigma2":
            reg_term *= sigma_t / a_t
        elif self.config["lambda_func"] == "v":
            x_t = a_t * latent_mu + sigma_t * noise
            # dlambda/dt for the rectified-flow schedule a_t = 1 - t, sigma_t = t,
            # with lambda_t = log(a_t^2 / sigma_t^2)
            lambda_t_der = -2 * (1 / (1 - t) + 1 / t)
            reg_term = lambda_t_der * t * reg_term / 2
            # Under the same schedule, u_t simplifies to noise - latent_mu,
            # i.e. the conditional flow-matching velocity
            u_t = -1 / (1 - t) * x_t - t * lambda_t_der / 2 * noise
            # u_t = noise - latent_mu
            reg_term = -(u_t - v).reshape(eps_prediction.shape[0], -1)
        elif self.config["lambda_func"] != "sigma":
            raise ValueError(f'lambda_func {self.config["lambda_func"]} not supported.')
        if isinstance(self.regularizer_weight, np.ndarray):
            reg_idx = self.find_closest_t(t)
            regularizer_weight = self.regularizer_weight[reg_idx]
        else:
            regularizer_weight = self.regularizer_weight
        return reg_term * regularizer_weight

    def save_intermediate_results(self, latent_mu, i):
        """
        Saves intermediate results for debugging or visualization.

        Args:
            latent_mu (torch.Tensor): current latent representation
            i (int): current iteration index
        """
        x_hat = self.model.decode(latent_mu)
        # create directory if it does not exist
        os.makedirs("intermediate_results", exist_ok=True)
        torchvision.utils.save_image(x_hat, f"intermediate_results/x_hat_{i}.png", normalize=True, value_range=(-1, 1))
        print(f"Saved intermediate results for iteration {i}.")
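

# Minimal usage sketch of how VariationalPosterior can be driven. It assumes a model
# wrapper exposing encode, decode, single_step and get_timesteps, and a degradation
# registered in src.flair.degradations; the import path, degradation name and all
# hyperparameter values below are hypothetical placeholders rather than tested settings.
if __name__ == "__main__":
    from src.flair.model_wrapper import ModelWrapper  # hypothetical import path

    config = {
        "degradation": {"name": "GaussianBlur", "kwargs": {}},  # hypothetical operator name
        "regularizer_weight": 1.0,
        "likelihood_weight": 1.0,
        "likelihood_weight_mode": "constant",
        "likelihood_steps": 3,
        "early_stopping": 1e-4,
        "projection": False,
        "inv_alpha": 0.5,
        "epochs": 1,
        "n_steps": 28,
        "ts_min": 0.05,
        "t_sampling": "descending",
        "lambda_func": "sigma",
        "resolution": 1024,
        "optimizer": {"name": "Adam", "kwargs": {"lr": 1e-1}},
        "optimizer_dataterm": {"name": "SGD", "kwargs": {"lr": 1e-2}},
    }

    model = ModelWrapper()  # hypothetical constructor; any wrapper with the required methods works
    posterior = VariationalPosterior(model, config)

    y = torch.randn(1, 3, 1024, 1024)  # stand-in measurement tensor
    result = posterior.forward(y, kwargs={})
    torchvision.utils.save_image(result["x_hat"], "x_hat.png", normalize=True, value_range=(-1, 1))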