import copy

import numpy as np
import torch

import deepinv
import deepinv as dinv
import deepinv.optim.utils
from deepinv.optim.data_fidelity import L2
from deepinv.optim.optim_iterators import OptimIterator, fStep, gStep
from deepinv.optim.prior import PnP
from deepinv.unfolded import unfolded_builder

# Default DRUNet checkpoint loaded by ``get_unrolled_architecture`` when no
# denoiser is supplied. This looks like a cluster-specific path; override it
# through the ``pretrained`` argument elsewhere.
_DEFAULT_DRUNET_CHECKPOINT = (
    "/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/"
    "drunet_deepinv_color_finetune_22k.pth"
)


def _as_batch_tensor(value, device):
    r"""
    Convert a scalar or tensor noise parameter to a floating tensor of rank >= 1.

    Python numbers and 0-d tensors become shape ``(1,)``; tensors that already
    have a leading batch dimension are kept as they are. The result is always
    moved to ``device``, so a tensor cached by an earlier call on another device
    does not cause a device mismatch.

    :param value: ``float``, ``int`` or :class:`torch.Tensor`.
    :param device: target device.
    :return: (:class:`torch.Tensor`) floating tensor with ``ndim >= 1`` on ``device``.
    """
    value = torch.as_tensor(value, device=device)
    if not torch.is_floating_point(value):
        value = value.to(torch.get_default_dtype())
    if value.ndim == 0:
        value = value.unsqueeze(0)
    return value


class PoissonGaussianDistance(dinv.optim.Distance):
    r"""
    Variance-weighted squared :math:`\ell_2` distance for Poisson-Gaussian noise.

    .. math::

        \distance{x}{y} = \frac{1}{2}(x-y)^{\top}\Sigma_y^{-1}(x-y)

    with :math:`\Sigma_y=\text{diag}(\gamma y + \sigma^2)`, where :math:`\gamma`
    is the gain. All operations are elementwise, so ``sigma`` and ``gain`` must
    broadcast against ``y``.

    :param float sigma: Gaussian noise standard deviation. Default: 1.
    :param float gain: Poisson noise gain. Default: 0.
    """

    def __init__(self, sigma=1.0, gain=0.0):
        super().__init__()
        self.sigma = sigma
        self.gain = gain

    def _precision(self, y):
        r"""Elementwise inverse variance :math:`1 / (\sigma^2 + \gamma y)`."""
        return 1.0 / (self.sigma**2 + y * self.gain)

    def fn(self, x, y, *args, **kwargs):
        r"""
        Computes the distance :math:`\distance{x}{y}`, i.e.

        .. math::

            \distance{x}{y} = \frac{1}{2}\sum_i \frac{(x_i-y_i)^2}{\sigma^2 + \gamma y_i}

        :param torch.Tensor x: Variable :math:`x` at which the distance is computed.
        :param torch.Tensor y: Data :math:`y`.
        :return: (:class:`torch.Tensor`) distance of size `B` with `B` the size of the batch.
        """
        # The residual is weighted by the precision exactly once, so that the
        # gradient of this function is precisely ``grad`` below.
        weighted = (x - y) ** 2 * self._precision(y)
        return 0.5 * weighted.reshape(weighted.shape[0], -1).sum(dim=-1)

    def grad(self, x, y, *args, **kwargs):
        r"""
        Computes the gradient :math:`\nabla_{x}\distance{x}{y}`, i.e.

        .. math::

            \nabla_{x}\distance{x}{y} = \Sigma_y^{-1}(x-y)

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) gradient of the distance function.
        """
        return (x - y) * self._precision(y)

    def prox(self, x, y, *args, gamma=1.0, **kwargs):
        r"""
        Proximal operator of :math:`\gamma \distance{\cdot}{y}`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \distancename}(x) = \underset{u}{\text{argmin}}
            \frac{\gamma}{2}(u-y)^{\top}\Sigma_y^{-1}(u-y)+\frac{1}{2}\|u-x\|_2^2
            = \frac{x + \gamma \Sigma_y^{-1} y}{1 + \gamma \Sigma_y^{-1}}

        (the last expression is elementwise).

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.Tensor y: Data :math:`y`.
        :param float gamma: stepsize of the proximity operator.
        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
        """
        weight = gamma * self._precision(y)
        return (x + weight * y) / (1 + weight)


class PoissonGaussianDataFidelity(dinv.optim.DataFidelity):
    r"""
    Variance-weighted :math:`\ell_2` data fidelity for Poisson-Gaussian noise.

    .. math::

        f(x) = \frac{1}{2}\|\forw{x}-y\|^2_{\Sigma_y^{-1}},
        \qquad \Sigma_y = \text{diag}(\sigma^2 + \gamma y)

    where :math:`\gamma` is the gain. It can be used as a Gaussian approximation
    of the Poisson-Gaussian negative log-likelihood.

    In :meth:`prox`, ``sigma`` and ``gain`` are refreshed from the physics'
    noise model when that model exposes ``sigma`` / ``gain`` attributes.

    :param float sigma: Gaussian noise standard deviation. Default: 1.
    :param float gain: Poisson noise gain. Default: 0.
    :param int cg_max_iter: maximum number of conjugate-gradient iterations used in :meth:`prox`. Default: 3.
    :param float cg_tol: conjugate-gradient tolerance used in :meth:`prox`. Default: 1e-3.
    """

    def __init__(self, sigma=1.0, gain=0.0, cg_max_iter=3, cg_tol=1e-3):
        super().__init__()
        # NOTE(review): ``prox`` updates ``self.sigma`` / ``self.gain`` from the
        # physics, but ``self.d`` keeps the constructor values; confirm whether
        # the inherited ``fn`` / ``grad`` (which presumably go through
        # ``self.d``) should be kept in sync.
        self.d = PoissonGaussianDistance(sigma=sigma, gain=gain)
        self.gain = gain
        self.sigma = sigma
        self.cg_max_iter = cg_max_iter
        self.cg_tol = cg_tol

    def prox(self, x, y, physics, gamma=1.0, *args, **kwargs):
        r"""
        Proximal operator of :math:`\gamma f`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \datafidname}(x) = \underset{u}{\text{argmin}}
            \frac{\gamma}{2}\|Au-y\|^2_{\Sigma_y^{-1}}+\frac{1}{2}\|u-x\|_2^2

        computed approximately by conjugate gradient on the normal equations
        :math:`(A^{\top} W A + I)u = A^{\top} W y + x` with :math:`W=\gamma\Sigma_y^{-1}`.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.Tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: linear physics model.
        :param float gamma: stepsize of the proximity operator.
        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`.
        :raises NotImplementedError: if ``physics`` is not linear.
        """
        if not isinstance(physics, dinv.physics.LinearPhysics):
            raise NotImplementedError("not implemented for non-linear physics")

        # For stacked physics, y is indexable and the noise model of the last
        # sub-physics is used.
        if isinstance(physics, dinv.physics.StackedPhysics):
            device = y[0].device
            noise_model = physics[-1].noise_model
        else:
            device = y.device
            noise_model = physics.noise_model

        if hasattr(noise_model, "gain"):
            self.gain = noise_model.gain.detach().to(device)
        if hasattr(noise_model, "sigma"):
            self.sigma = noise_model.sigma.detach().to(device)

        # Normalise to (batch,) tensors so they can be broadcast per sample.
        self.sigma = _as_batch_tensor(self.sigma, device)
        self.gain = _as_batch_tensor(self.gain, device)

        # Assumes 4-D (B, C, H, W) data, as the indexing below implies.
        sigma2 = self.sigma[:, None, None, None] ** 2
        if (self.gain > 0).any():
            # Any sample with a positive gain needs the signal-dependent
            # variance; samples with zero gain reduce to sigma^2 anyway.
            norm = gamma / (sigma2 + y * self.gain[:, None, None, None])
        else:
            norm = gamma / sigma2

        def normal_op(u):
            return physics.A_adjoint(physics.A(u) * norm) + u

        rhs = physics.A_adjoint(norm * y) + x
        return deepinv.optim.utils.conjugate_gradient(
            normal_op, rhs, init=x, max_iter=self.cg_max_iter, tol=self.cg_tol
        )


class myHQSIteration(OptimIterator):
    r"""
    Single iteration of half-quadratic splitting.

    Class for a single iteration of the Half-Quadratic Splitting (HQS) algorithm
    for minimising :math:`f(x) + \lambda \regname(x)`.

    The iteration is given by

    .. math::
        \begin{equation*}
        \begin{aligned}
        u_{k} &= \operatorname{prox}_{\gamma f}(x_k) \\
        x_{k+1} &= \operatorname{prox}_{\sigma \lambda \regname}(u_k).
        \end{aligned}
        \end{equation*}

    where :math:`\gamma` and :math:`\sigma` are step-sizes. Note that this
    algorithm does not converge to a minimizer of
    :math:`f(x) + \lambda \regname(x)`, but instead to a minimizer of
    :math:`\gamma\, ^1f+\sigma \lambda \regname`, where :math:`^1f` denotes the
    Moreau envelope of :math:`f`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.g_step = mygStepHQS(**kwargs)
        self.f_step = myfStepHQS(**kwargs)
        self.requires_prox_g = True


class myfStepHQS(fStep):
    r"""
    HQS fStep module: proximal step on the data-fidelity term.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, cur_data_fidelity, cur_params, y, physics):
        r"""
        Single proximal step on the data-fidelity term :math:`f`.

        :param torch.Tensor x: Current iterate :math:`x_k`.
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm (``"stepsize"`` is read).
        :param torch.Tensor y: Input data.
        :param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
        """
        return cur_data_fidelity.prox(x, y, physics, gamma=cur_params["stepsize"])


class mygStepHQS(gStep):
    r"""
    HQS gStep module: proximal step on the prior term.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, cur_prior, cur_params):
        r"""
        Single proximal step on the prior term :math:`\lambda \regname`.

        :param torch.Tensor x: Current iterate :math:`x_k`.
        :param dict cur_prior: Class containing the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the
            algorithm (``"g_param"``, ``"gain_param"``, ``"lambda"`` and ``"stepsize"`` are read).
        """
        # ``gamma`` is forwarded for API compatibility; a prior such as the
        # ``myPnP`` built in ``get_unrolled_architecture`` ignores it.
        return cur_prior.prox(
            x,
            sigma_denoiser=cur_params["g_param"],
            gain_denoiser=cur_params["gain_param"],
            gamma=cur_params["lambda"] * cur_params["stepsize"],
        )


def get_unrolled_architecture(
    gain_param_init=1e-3,
    weight_tied=True,
    model=None,
    device="cpu",
    max_iter=8,
    pretrained=_DEFAULT_DRUNET_CHECKPOINT,
):
    r"""
    Build an unfolded HQS network with a Poisson-Gaussian data-fidelity term.

    :param float gain_param_init: initial value of the per-iteration ``gain_param`` passed to the denoiser.
    :param bool weight_tied: if ``True`` a single denoiser is shared by all
        iterations; otherwise each iteration gets its own deep copy.
    :param torch.nn.Module model: denoiser to use; if ``None`` a DRUNet is loaded from ``pretrained``.
    :param str device: device on which the model is built.
    :param int max_iter: number of unfolded iterations (layers). Default: 8.
    :param str pretrained: checkpoint passed to DRUNet when ``model`` is ``None``.
    :return: the model returned by ``unfolded_builder``, moved to ``device``.
    """
    if model is not None:
        denoiser = model.to(device)
    else:
        denoiser = dinv.models.DRUNet(pretrained=pretrained).to(device)

    class myPnP(PnP):
        r"""
        PnP prior whose proximal step calls the denoiser with a noise level and a gain.
        """

        def prox(self, x, sigma_denoiser, gain_denoiser, *args, **kwargs):
            # In eval mode, zero-pad H and W up to multiples of 8 (presumably
            # required by the denoiser's downsampling path) and crop back after.
            if not self.training:
                pad = (-x.size(-2) % 8, -x.size(-1) % 8)
                x = torch.nn.functional.pad(x, (0, pad[1], 0, pad[0]), mode="constant")
            out = self.denoiser(x, sigma=sigma_denoiser, gamma=gain_denoiser)
            if not self.training:
                # ``-0 or None`` -> None, so a zero pad keeps the full extent.
                out = out[..., : -pad[0] or None, : -pad[1] or None]
            return out

    data_fidelity = PoissonGaussianDataFidelity()

    if weight_tied:
        prior = [myPnP(denoiser=denoiser)]
    else:
        prior = [myPnP(denoiser=copy.deepcopy(denoiser)) for _ in range(max_iter)]

    def get_DPIR_params(noise_level_img, max_iter=8):
        r"""
        Default parameters for the DPIR Plug-and-Play algorithm.

        :param float noise_level_img: Noise level of the input image.
        :param int max_iter: number of iterations.
        :return: lists of denoiser noise levels and step sizes, one per iteration.
        """
        s1 = 49.0 / 255.0
        s2 = noise_level_img
        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
            np.float32
        )
        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
        lamb = 1 / 0.23
        return list(sigma_denoiser), list(lamb * stepsize)

    sigma_denoiser, stepsize = get_DPIR_params(0.05, max_iter=max_iter)
    stepsize = torch.tensor(stepsize) * (torch.tensor(sigma_denoiser) ** 2)
    gain_denoiser = [gain_param_init] * len(sigma_denoiser)

    params_algo = {
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "gain_param": gain_denoiser,
    }
    # Keys of ``params_algo`` that are learned during training.
    trainable_params = ["g_param", "gain_param", "stepsize"]

    unfolded_model = unfolded_builder(
        iteration=myHQSIteration(),
        params_algo=params_algo.copy(),
        trainable_params=trainable_params,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        prior=prior,
        device=device,
    )
    return unfolded_model.to(device)