import torch
from tqdm import tqdm
import torchvision.utils as tvu
import torchvision
import os

# Class label fed to the model whenever classifier guidance (cls_fn) is used.
# NOTE(review): this replaces any `classes` argument passed to the samplers
# below (951 is presumably an ImageNet class index) — confirm this is intended.
class_num = 951


def compute_alpha(beta, t):
    """Return the cumulative product of (1 - beta) at timesteps ``t``.

    A zero is prepended to ``beta`` and the result is indexed at ``t + 1``, so
    ``t == -1`` maps to the empty product 1.0. The output is viewed as
    ``(-1, 1, 1, 1)`` so it broadcasts against image batches.
    """
    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
    return a


def inverse_data_transform(x):
    """Map values from [-1, 1] to [0, 1], clamping anything outside that range."""
    x = (x + 1.0) / 2.0
    return torch.clamp(x, 0.0, 1.0)


def _time_pairs(config):
    """Return consecutive (current, next) index pairs of the time-travel schedule."""
    times = get_schedule_jump(
        config.time_travel.T_sampling,
        config.time_travel.travel_length,
        config.time_travel.travel_repeat,
    )
    return list(zip(times[:-1], times[1:]))


def _predict_noise(model, xt, t, at, cls_fn):
    """Predict the noise at (xt, t), applying classifier guidance if ``cls_fn`` is given."""
    if cls_fn is None:
        et = model(xt, t)
    else:
        # NOTE(review): the caller's `classes` is ignored in favour of class_num.
        classes = torch.ones(xt.size(0), dtype=torch.long, device=xt.device) * class_num
        et = model(xt, t, classes)
        et = et[:, :3]
        # The guidance term must be evaluated at the current sample xt, not at
        # the sampler's initial input.
        et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(xt, t, classes)
    # Presumably a learned-variance model (noise + variance channels): keep
    # only the first 3 channels.
    if et.size(1) == 6:
        et = et[:, :3]
    return et


def _time_travel_back(x0_pred, b, j, n, device):
    """Re-noise a previous x0 estimate to the (noisier) timestep ``j``."""
    next_t = torch.ones(n, device=device) * j
    at_next = compute_alpha(b, next_t.long())
    x0_t = x0_pred.to(device)
    return at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt()


def ddnm_diffusion(x, model, b, eta, A_funcs, y, cls_fn=None, classes=None, config=None):
    """Run DDNM sampling with a RePaint-style time-travel schedule.

    Args:
        x: starting sample of the reverse process (presumably Gaussian noise).
        model: noise predictor called as ``model(xt, t)``, or as
            ``model(xt, t, classes)`` when ``cls_fn`` is given.
        b: beta schedule tensor, consumed by ``compute_alpha``.
        eta: stochasticity weight. Fresh noise is scaled by
            ``sqrt(1 - a_next) * eta`` and the predicted noise by
            ``sqrt(1 - a_next) * sqrt(1 - eta**2)``. Assumes 0 <= eta <= 1.
        A_funcs: degradation operator exposing ``A`` and ``A_pinv``, both
            applied to tensors flattened to ``(batch, -1)``.
        y: measurement; reshaped to ``(batch, -1)``.
        cls_fn: optional classifier-guidance function ``cls_fn(xt, t, classes)``.
        classes: accepted for interface compatibility; currently unused
            (``class_num`` is used instead).
        config: must provide ``diffusion.num_diffusion_timesteps`` and
            ``time_travel.{T_sampling, travel_length, travel_repeat}``.

    Returns:
        ``([final sample], [last x0 prediction])``, both on the CPU.
    """
    device = x.device
    with torch.no_grad():
        skip = config.diffusion.num_diffusion_timesteps // config.time_travel.T_sampling
        n = x.size(0)
        x0_preds = []
        xs = [x]

        for i, j in tqdm(_time_pairs(config)):
            i, j = i * skip, j * skip
            if j < 0:
                j = -1

            if j < i:  # normal (denoising) step
                t = torch.ones(n, device=device) * i
                next_t = torch.ones(n, device=device) * j
                at = compute_alpha(b, t.long())
                at_next = compute_alpha(b, next_t.long())
                xt = xs[-1].to(device)
                et = _predict_noise(model, xt, t, at, cls_fn)

                x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

                # Correct x0_t with the measurement residual. If A_pinv is a
                # right inverse of A, this makes A(x0_t_hat) == y.
                x0_t_hat = x0_t - A_funcs.A_pinv(
                    A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1)
                ).reshape(*x0_t.size())

                c1 = (1 - at_next).sqrt() * eta
                c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)
                xt_next = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et

                x0_preds.append(x0_t.to('cpu'))
                xs.append(xt_next.to('cpu'))
            else:  # time-travel back to a noisier step
                xt_next = _time_travel_back(x0_preds[-1], b, j, n, device)
                xs.append(xt_next.to('cpu'))

    return [xs[-1]], [x0_preds[-1]]


def ddnm_plus_diffusion(x, model, b, eta, A_funcs, y, sigma_y, cls_fn=None, classes=None, config=None):
    """Run DDNM+ sampling (noisy-measurement variant) with time travel.

    The arguments are as in ``ddnm_diffusion``, with two differences:

    - ``A_funcs`` must additionally expose ``Lambda`` and ``Lambda_noise``.
      Both receive ``(sqrt(a_next), sigma_y, sigma_t, eta)`` and
      ``Lambda_noise`` also receives the flattened predicted noise.
    - ``sigma_y`` is forwarded to those two functions. It is presumably the
      measurement-noise level; confirm against the A_funcs implementation.

    Returns:
        ``([final sample], [last x0 prediction])``, both on the CPU.
    """
    device = x.device
    with torch.no_grad():
        skip = config.diffusion.num_diffusion_timesteps // config.time_travel.T_sampling
        n = x.size(0)
        x0_preds = []
        xs = [x]

        for i, j in tqdm(_time_pairs(config)):
            i, j = i * skip, j * skip
            if j < 0:
                j = -1

            if j < i:  # normal (denoising) step
                t = torch.ones(n, device=device) * i
                next_t = torch.ones(n, device=device) * j
                at = compute_alpha(b, t.long())
                at_next = compute_alpha(b, next_t.long())
                xt = xs[-1].to(device)
                et = _predict_noise(model, xt, t, at, cls_fn)

                # Eq. 12
                x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

                sigma_t = (1 - at_next).sqrt()[0, 0, 0, 0]
                sqrt_at_next = at_next.sqrt()[0, 0, 0, 0]

                # Eq. 17
                residual = A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1)
                correction = A_funcs.A_pinv(residual).reshape(x0_t.size(0), -1)
                x0_t_hat = x0_t - A_funcs.Lambda(
                    correction, sqrt_at_next, sigma_y, sigma_t, eta
                ).reshape(*x0_t.size())

                # Eq. 51
                xt_next = at_next.sqrt() * x0_t_hat + A_funcs.Lambda_noise(
                    torch.randn_like(x0_t).reshape(x0_t.size(0), -1),
                    sqrt_at_next, sigma_y, sigma_t, eta,
                    et.reshape(et.size(0), -1),
                ).reshape(*x0_t.size())

                x0_preds.append(x0_t.to('cpu'))
                xs.append(xt_next.to('cpu'))
            else:  # time-travel back to a noisier step
                xt_next = _time_travel_back(x0_preds[-1], b, j, n, device)
                xs.append(xt_next.to('cpu'))

    return [xs[-1]], [x0_preds[-1]]


# Time-travel (resampling) schedule, adapted from RePaint.
def get_schedule_jump(T_sampling, travel_length, travel_repeat):
    """Build the time-travel schedule of sampling indices.

    The schedule counts down one step at a time from ``T_sampling - 1``. At
    every multiple of ``travel_length`` below ``T_sampling - travel_length``,
    it jumps back up ``travel_length`` steps, one at a time, and does so
    ``travel_repeat - 1`` times. The schedule ends with ``-1``.
    """
    jumps = {j: travel_repeat - 1 for j in range(0, T_sampling - travel_length, travel_length)}

    t = T_sampling
    ts = []

    while t >= 1:
        t = t - 1
        ts.append(t)

        if jumps.get(t, 0) > 0:
            jumps[t] = jumps[t] - 1
            for _ in range(travel_length):
                t = t + 1
                ts.append(t)

    ts.append(-1)

    _check_times(ts, -1, T_sampling)
    return ts


def _check_times(times, t_0, T_sampling):
    """Sanity-check a schedule produced by ``get_schedule_jump``.

    Raises AssertionError on failure. These checks are skipped when Python
    runs with ``-O``.
    """
    # Check beginning: the first step must move towards t_0
    assert times[0] > times[1], (times[0], times[1])

    # Check end: the schedule terminates at -1
    assert times[-1] == -1, times[-1]

    # Steplength = 1
    for t_last, t_cur in zip(times[:-1], times[1:]):
        assert abs(t_last - t_cur) == 1, (t_last, t_cur)

    # Value range
    for t in times:
        assert t >= t_0, (t, t_0)
        assert t <= T_sampling, (t, T_sampling)