import torch
import numpy as np
import sys
import os
import tqdm
from src.flair import degradations
import torchvision

def total_variation_loss(x):
    """
    Compute the total variation loss for a batch of images.
    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W)
    Returns:
        torch.Tensor: Total variation loss
    """
    # Compute the differences between adjacent pixels
    diff_x = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])
    diff_y = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])
    
    # Sum the differences
    return torch.sum(diff_x) + torch.sum(diff_y)
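
# Illustrative usage of total_variation_loss (hypothetical tensor shape):
#   x = torch.rand(2, 3, 64, 64)   # batch of images, (B, C, H, W)
#   tv = total_variation_loss(x)   # scalar: summed absolute horizontal and vertical differences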

class VariationalPosterior:
    def __init__(self, model, config):
        """
        Args:
            model (torch.nn.Module): model to be used for inference, should have single_step, encode and decode methods
            config (dict): configuration

        """
        self.config = config
        self.model = model

        # Initialize degradation model
        try:
            degradation = getattr(degradations, config["degradation"]["name"])
        except AttributeError:
            print(f"Degradation {config['degradation']['name']} not defined.")
            sys.exit(1)

        self.forward_operator = degradation(**config["degradation"]["kwargs"])

        if "optimized_reg_weight" in config and config["optimized_reg_weight"]:
            reg_weight = np.load(config["optimized_reg_weight"])
            if "reg_weight" in config["optimized_reg_weight"]:
                self.regularizer_weight = reg_weight * config["regularizer_weight"]
            else:
                # reg_weight = reg_weight / np.nanmax(reg_weight)
                reg_weight = 1 / (reg_weight + 1e-7)
                reg_weight = reg_weight / np.nansum(reg_weight) * reg_weight.shape[0]
                if "reg-shift" in config:
                    self.regularizer_weight = reg_weight + config["reg-shift"]
                else:
                    self.regularizer_weight = reg_weight - reg_weight[-1]
                self.regularizer_weight = np.clip(self.regularizer_weight, 0, None) * config["regularizer_weight"] 
            print("loaded opt reg weight.")
        else:
            self.regularizer_weight = config["regularizer_weight"]

    def set_degradation(self):
        """(Re)builds the forward operator from the degradation settings in the config."""
        try:
            degradation = getattr(degradations, self.config["degradation"]["name"])
        except AttributeError:
            print(f"Degradation {self.config['degradation']['name']} not defined.")
            sys.exit(1)
        self.forward_operator = degradation(**self.config["degradation"]["kwargs"])
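
    # Illustrative degradation entry (hypothetical operator name and kwargs; the available
    # operators and their arguments are defined in src.flair.degradations):
    #   config["degradation"] = {"name": "GaussianBlur", "kwargs": {"sigma": 2.0}}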

    def data_term(self, latent_mu, y, optimizer_dataterm, likelihood_weight, likelihood_steps, early_stopping):
        """
        Minimizes the data term ||A(decode(latent_mu)) - y||^2 with respect to latent_mu for up to
        likelihood_steps gradient steps, stopping early once the summed squared error drops below
        early_stopping per measurement element.
        """
        for k in range(likelihood_steps):
            with torch.enable_grad():
                data_loss = torch.nn.MSELoss(reduction='sum')(self.forward_operator(self.model.decode(latent_mu), noise=False), y)
                loss = likelihood_weight * data_loss.sum()
            if data_loss < early_stopping * y.numel():
                if latent_mu.grad is not None:
                    latent_mu.grad = None
                del loss
                del data_loss
                break
            loss.backward()
            optimizer_dataterm.step()
            optimizer_dataterm.zero_grad()
            del loss
            del data_loss
        return 

    @torch.no_grad()
    def projection(self, latent_mu, y, alpha=1):
        """
        Data-consistency projection: x_proj = x_0 - A^+(A(x_0)) + A^+(y), where x_0 is the current
        decoded estimate and A^+ the pseudo-inverse of the forward operator. Returns the projected
        image together with a soft latent projection blended with factor alpha.
        """
        x_0 = self.model.decode(latent_mu)
        y_hat = self.forward_operator(x_0, noise=False)
        x_inv_hat = self.forward_operator.pseudo_inv(y_hat)
        projection = x_0 - x_inv_hat + self.forward_operator.pseudo_inv(y)
        latent_projection = self.model.encode(projection)
        # soft projection in latent space: alpha=1 is a full projection, alpha=0 keeps latent_mu
        latent_projection = latent_mu - (latent_mu - latent_projection) * alpha
        return projection, latent_projection

    def find_closest_t(self, t):
        """Returns the index of the regularizer-weight entry whose timestep is closest to t."""
        ts = torch.linspace(1, 0.0, self.regularizer_weight.shape[0], device=t.device, dtype=t.dtype)
        return torch.argmin(torch.abs(ts - t))

    def forward(self, y, kwargs):
        """
        Uses a variational approach to infer the mode of the posterior distribution given a measurement y.
        Args:
            y (torch.Tensor): measurement tensor
            kwargs (dict): additional conditioning inputs passed to the model's single_step method
        Returns:
            dict: contains "x_hat", the decoded reconstruction
        """
        
        # move any tensor-valued kwargs to the measurement's device
        for key, value in kwargs.items():
            try:
                kwargs[key] = value.to(y.device)
            except AttributeError:
                pass

        return_dict = {}
        device = y.device

        
        if "init" in self.config and self.config["init"] =="random":
            # TODO: put this in model wrapper
            shape = (
                1,
                16,
                int(self.config["resolution"]) // self.model.vae_scale_factor,
                int(self.config["resolution"]) // self.model.vae_scale_factor,
            )
            latent_mu = torch.randn(shape, device=device, dtype=y.dtype)
        else:
            x_inv = self.forward_operator.pseudo_inv(y)
            latent_mu = self.model.encode(x_inv)
        latent_mu = latent_mu.detach().clone()
        latent_mu.requires_grad = True

        start_noise = torch.randn_like(latent_mu)
        optim_noise = start_noise.detach().clone()
        timesteps = self._get_timesteps(device)

        for epoch in range(self.config["epochs"]):
            optimizer, optimizer_dataterm = self._initialize_optimizers(latent_mu)
    
            for i, t in tqdm.tqdm(enumerate(timesteps), desc="Variational Optimization", total=len(timesteps)):
                t = torch.tensor([t], device=device, dtype=latent_mu.dtype)
                kwargs["noise"] = optim_noise.detach()
                kwargs["inv_alpha"] = self.config["inv_alpha"]
                    

                eps_prediction, noise, a_t, sigma_t, v_pred = self.model.single_step(latent_mu, t, kwargs)

                # predict x_1, which is the start noise vector for DTA
                optim_noise = a_t * latent_mu + sigma_t * noise + a_t * v_pred

                reg_term = self._compute_regularization_term(eps_prediction, noise, a_t, sigma_t, t, latent_mu, v_pred)

                if self.config["likelihood_weight_mode"] == "reg_weight":
                    reg_idx = self.find_closest_t(t)
                    likelihood_weight = self.regularizer_weight[reg_idx] * self.config["likelihood_weight"]
                else:
                    likelihood_weight = self.config["likelihood_weight"]

                with torch.enable_grad():
                    # surrogate loss: its gradient w.r.t. latent_mu is the detached regularization term
                    reg_term = (reg_term.detach() * latent_mu.view(reg_term.shape[0], -1)).sum()

                reg_term.backward()
                optimizer.step()
                optimizer.zero_grad()
                
                if self.config["projection"] and t>0.7:
                    with torch.enable_grad():
                        _, latent_mu_projection = self.projection(latent_mu, y)
                        proj_loss = (latent_mu - latent_mu_projection).detach() * latent_mu
                        proj_loss = proj_loss.sum()
                    optimizer_dataterm.zero_grad()
                    proj_loss.backward()
                    optimizer_dataterm.step()
                    optimizer_dataterm.zero_grad()
                
                self.data_term(
                    latent_mu,
                    y.detach(),
                    optimizer_dataterm,
                    likelihood_weight,
                    self.config["likelihood_steps"],
                    self.config["early_stopping"]
                )
                # self.save_intermediate_results(latent_mu, i)

        x_hat = self.model.decode(latent_mu)
        return_dict.update({"x_hat": x_hat})
        return return_dict

    def _get_timesteps(self, device):
        """Builds the timestep schedule and orders it according to config["t_sampling"]."""
        timesteps = self.model.get_timesteps(self.config["n_steps"], device=device, ts_min=self.config["ts_min"])
        if self.config["t_sampling"] == "descending":
            return timesteps
        elif self.config["t_sampling"] == "ascending":
            return timesteps.flip(0)
        elif self.config["t_sampling"] == "random":
            idx = torch.randperm(len(timesteps), device=device, dtype=timesteps.dtype)
            return timesteps[idx]            
        else:
            raise ValueError(f't_sampling {self.config["t_sampling"]} not supported.')

    def _initialize_optimizers(self, latent_mu):
        """Creates separate optimizers (and optional schedulers) for the regularization and data terms."""
        # Both optimizers update the same latent but keep separate states and learning rates.
        params = [latent_mu]
        params2 = [latent_mu]
        optimizer = self._get_optimizer(self.config["optimizer"], params)
        optimizer_dataterm = self._get_optimizer(self.config["optimizer_dataterm"], params2)
        if "scheduler" in self.config:
            self.scheduler = self._get_scheduler(self.config["scheduler"], optimizer)
        if "scheduler_dataterm" in self.config:
            self.scheduler_dataterm = self._get_scheduler(self.config["scheduler_dataterm"], optimizer_dataterm)
        return optimizer, optimizer_dataterm

    def _get_optimizer(self, optimizer_config, params):
        if optimizer_config["name"] == "Adam":
            return torch.optim.Adam(params, **optimizer_config["kwargs"])
        elif optimizer_config["name"] == "SGD":
            return torch.optim.SGD(params, **optimizer_config["kwargs"])
        else:
            raise ValueError(f'optimizer {optimizer_config["name"]} not supported.')

    def _get_scheduler(self, scheduler_config, optimizer):
        if scheduler_config["name"] == "StepLR":
            return torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_config["kwargs"])
        elif scheduler_config["name"] == "CosineAnnealingLR":
            return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_config["kwargs"])
        elif scheduler_config["name"] == "LinearLR":
            return torch.optim.lr_scheduler.LinearLR(optimizer, **scheduler_config["kwargs"])
        # Add other schedulers as needed
        else:
            raise ValueError(f'scheduler {scheduler_config["name"]} not supported.')

    def _compute_regularization_term(self, eps_prediction, noise, a_t, sigma_t, t, latent_mu, v):
        """Gradient of the diffusion prior term (predicted minus sampled noise), weighted per config["lambda_func"]."""
        reg_term = (eps_prediction - noise).reshape(eps_prediction.shape[0], -1)
        # reg_term = (latent_mu-(a_t*latent_mu + sigma_t*noise - t *v)).reshape(eps_prediction.shape[0], -1)
        # 
        # reg_term /= reg_term.norm() / 1000
        if self.config["lambda_func"] == "sigma2":
            reg_term *= sigma_t / a_t
        elif self.config["lambda_func"] == "v":
            x_t = a_t*latent_mu + sigma_t*noise
            lambda_t_der = -2 * (1/(1-t) + 1/t)
            reg_term = lambda_t_der * t * reg_term / 2
            u_t = - 1 / (1-t) * x_t - t * lambda_t_der / 2 * noise
            # u_t = noise - latent_mu
            reg_term = -(u_t - v).reshape(eps_prediction.shape[0], -1)
        elif self.config["lambda_func"] != "sigma":
            raise ValueError(f'lambda_func {self.config["lambda_func"]} not supported.')

        if isinstance(self.regularizer_weight, np.ndarray):
            reg_idx = self.find_closest_t(t)
            regularizer_weight = self.regularizer_weight[reg_idx]
        else:
            regularizer_weight = self.regularizer_weight

        return reg_term * regularizer_weight
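
    # Compact reading of the options above, as implemented (not a derivation):
    #   "sigma":  g = (eps_pred - eps)
    #   "sigma2": g = (sigma_t / a_t) * (eps_pred - eps)
    #   "v":      g = -(u_t - v), with x_t = a_t*mu + sigma_t*eps,
    #             lambda'_t = -2*(1/(1-t) + 1/t), u_t = -x_t/(1-t) - t*lambda'_t/2 * eps
    # Each gradient is additionally scaled by the (per-timestep) regularizer weight.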
    
    def save_intermediate_results(self, latent_mu, i):
        """
        Saves intermediate results for debugging or visualization.
        Args:
            latent_mu (torch.Tensor): current latent representation
            i (int): current iteration index
        """
        x_hat = self.model.decode(latent_mu)
        # create directory if it does not exist
        os.makedirs("intermediate_results", exist_ok=True)
        torchvision.utils.save_image(x_hat, f"intermediate_results/x_hat_{i}.png", normalize=True, value_range=(-1, 1))
        print(f"Saved intermediate results for iteration {i}.")