alexnasa committed on
Commit
f4b4884
·
verified ·
1 Parent(s): c8b1497

Update pipelines/pipeline_seesr.py

Browse files
Files changed (1) hide show
  1. pipelines/pipeline_seesr.py +38 -57
pipelines/pipeline_seesr.py CHANGED
@@ -99,7 +99,9 @@ EXAMPLE_DOC_STRING = """
99
  def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
100
  # x0: (N, C, H, W) in float32
101
  N, C, H, W = x0.shape
102
- patches = unfold(x0, kernel_size=patch_size, stride=patch_size) # (N, C*ps*ps, M)
 
 
103
  P, M = patches.shape[1], patches.shape[2]
104
  p_i = patches.unsqueeze(1) # (N,1,P,M)
105
  p_j = patches.unsqueeze(0) # (1,N,P,M)
@@ -111,15 +113,13 @@ def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
111
  num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
112
  denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
113
  mshift = num / denom # (N,P,M)
114
-
115
  # fold back
116
  grad = fold(
117
- mshift / bandwidth**2,
118
- output_size=(H, W),
119
- kernel_size=patch_size,
120
- stride=patch_size
121
- ) # (N, C, H, W)
122
-
123
  return grad
124
 
125
  class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
@@ -835,8 +835,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
835
  num_particles: Optional[int] = 4,
836
  gamma_0: Optional[float] = 0.1, # base steering strength
837
  use_KDS = True,
838
- bandwidth = 0.1,
839
  patch_size = 16,
 
840
  args=None,
841
  ):
842
  r"""
@@ -1050,9 +1050,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
1050
  for i, t in enumerate(timesteps):
1051
  with torch.no_grad():
1052
  # skip this iteration if the timestep is larger than start_steps
1053
- # if t > start_steps:
1054
- # print(f'pass {t} steps.')
1055
- # continue
1056
 
1057
  # expand the latents if we are doing classifier free guidance
1058
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -1189,7 +1189,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
1189
  cond_list = []
1190
  img_list = []
1191
 
1192
- noise_preds.append(model_out)
1193
 
1194
  # Stitch noise predictions for all tiles
1195
  noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
@@ -1226,69 +1226,50 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
1226
 
1227
  if use_KDS:
1228
 
1229
- # 2) Compute x₀ prediction for all particles
1230
  beta_t = 1 - self.scheduler.alphas_cumprod[t]
1231
  alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
1232
  sigma_t = beta_t.sqrt()
1233
- x0_pred = (latents - sigma_t * noise_pred) / alpha_t # shape [2N, C, H, W]
1234
-
1235
- # split into unconditional vs. conditional
1236
- x0_uncond, x0_cond = x0_pred.chunk(2, dim=0) # each [N, C, H, W]
 
 
 
 
 
1237
 
1238
- # 3) Apply KDE steering *only* on the conditional batch
1239
- m_shift_cond = kde_grad(x0_cond, patch_size=patch_size, bandwidth=bandwidth) # [N, C, H, W]
1240
- delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
1241
- x0_cond_steer = x0_cond + delta_t * m_shift_cond # steered conditional
1242
 
1243
- # 4) Recombine the latents: leave uncond untouched, use steered cond
1244
- x0_steer = torch.cat([x0_uncond, x0_cond_steer], dim=0) # [2N, C, H, W]
1245
-
1246
- # 5) Recompute “noise” for DDIM step
1247
  noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
1248
 
1249
- # 6) Determine prev alphas and form next latent per DDIM
1250
  if i < len(timesteps) - 1:
1251
- next_t = timesteps[i + 1]
1252
- alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
1253
  else:
1254
- alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
 
1255
  sigma_prev = (1 - alpha_prev**2).sqrt()
1256
 
 
1257
  latents = (
1258
- alpha_prev * x0_steer
1259
- + sigma_prev * noise_pred_kds
1260
  ).detach().requires_grad_(True)
1261
  else:
1262
 
1263
  # compute the previous noisy sample x_t -> x_t-1
1264
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1265
 
1266
- with torch.no_grad():
1267
- if use_KDS:
1268
- # Final-latent selection (once!)
1269
- # latents shape: [2*N, C, H, W]
1270
- uncond_latents, cond_latents = latents.chunk(2, dim=0) # each [N, C, H, W]
1271
- # 1) ensemble mean
1272
- mean_cond = cond_latents.mean(dim=0, keepdim=True) # [1, C, H, W]
1273
- # 2) distances
1274
- dists = ((cond_latents - mean_cond)
1275
- .view(cond_latents.size(0), -1)
1276
- .pow(2)
1277
- .sum(dim=1)) # [N]
1278
- # 3) best index
1279
- best_idx = dists.argmin().item()
1280
- # 4) select that latent (and its uncond pair)
1281
- best_uncond = uncond_latents[best_idx:best_idx+1]
1282
- best_cond = cond_latents [best_idx:best_idx+1]
1283
- latents = torch.cat([best_uncond, best_cond], dim=0) # [2, C, H, W]
1284
-
1285
- # call the callback, if provided
1286
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1287
- progress_bar.update()
1288
- if callback is not None and i % callback_steps == 0:
1289
- callback(i, t, latents)
1290
 
1291
-
1292
  # If we do sequential model offloading, let's offload unet and controlnet
1293
  # manually for max memory savings
1294
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
 
99
  def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
100
  # x0: (N, C, H, W) in float32
101
  N, C, H, W = x0.shape
102
+ patches = unfold(
103
+ x0, kernel_size=patch_size, stride=patch_size//2
104
+ ) # (N, C*ps*ps, M)
105
  P, M = patches.shape[1], patches.shape[2]
106
  p_i = patches.unsqueeze(1) # (N,1,P,M)
107
  p_j = patches.unsqueeze(0) # (1,N,P,M)
 
113
  num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
114
  denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
115
  mshift = num / denom # (N,P,M)
 
116
  # fold back
117
  grad = fold(
118
+ mshift / bandwidth**2,
119
+ output_size=(H, W),
120
+ kernel_size=patch_size,
121
+ stride=patch_size//2
122
+ ) # (N, C, H, W)
 
123
  return grad
124
 
125
  class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
 
835
  num_particles: Optional[int] = 4,
836
  gamma_0: Optional[float] = 0.1, # base steering strength
837
  use_KDS = True,
 
838
  patch_size = 16,
839
+ bandwidth = 0.1,
840
  args=None,
841
  ):
842
  r"""
 
1050
  for i, t in enumerate(timesteps):
1051
  with torch.no_grad():
1052
  # skip this iteration if the timestep is larger than start_steps
1053
+ if t > start_steps:
1054
+ print(f'pass {t} steps.')
1055
+ continue
1056
 
1057
  # expand the latents if we are doing classifier free guidance
1058
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
1189
  cond_list = []
1190
  img_list = []
1191
 
1192
+ noise_preds.append(model_out)
1193
 
1194
  # Stitch noise predictions for all tiles
1195
  noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
 
1226
 
1227
  if use_KDS:
1228
 
1229
+ # 2) Compute x₀ prediction
1230
  beta_t = 1 - self.scheduler.alphas_cumprod[t]
1231
  alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
1232
  sigma_t = beta_t.sqrt()
1233
+ x0_pred = (latents - sigma_t * noise_pred) / alpha_t
1234
+
1235
+ # 3) Apply KDE steering
1236
+ m_shift = kde_grad(x0_pred, patch_size=patch_size, bandwidth=bandwidth)
1237
+ delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
1238
+ x0_steer = x0_pred + delta_t * m_shift
1239
+ # frac = i / (len(timesteps) - 1)
1240
+ # delta_t = 0.0 if frac < 0.3 else 0.3
1241
+ # x0_steer = x0_pred + delta_t * gamma_0 * m_shift
1242
 
 
 
 
 
1243
 
1244
+ # 4) Recompute “noise” for DDIM step
 
 
 
1245
  noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
1246
 
1247
+ # 5) Determine prev alphas
1248
  if i < len(timesteps) - 1:
1249
+ next_t = timesteps[i + 1]
1250
+ alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
1251
  else:
1252
+ alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
1253
+
1254
  sigma_prev = (1 - alpha_prev**2).sqrt()
1255
 
1256
+ # 6) Form next latent per DDIM
1257
  latents = (
1258
+ alpha_prev * x0_steer
1259
+ + sigma_prev * noise_pred_kds
1260
  ).detach().requires_grad_(True)
1261
  else:
1262
 
1263
  # compute the previous noisy sample x_t -> x_t-1
1264
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1265
 
1266
+ # call the callback, if provided
1267
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1268
+ progress_bar.update()
1269
+ if callback is not None and i % callback_steps == 0:
1270
+ callback(i, t, latents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1271
 
1272
+ with torch.no_grad():
1273
  # If we do sequential model offloading, let's offload unet and controlnet
1274
  # manually for max memory savings
1275
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: