alexnasa committed on
Commit
897276f
·
verified ·
1 Parent(s): 51fc527

Update pipelines/pipeline_seesr.py

Browse files
Files changed (1) hide show
  1. pipelines/pipeline_seesr.py +20 -19
pipelines/pipeline_seesr.py CHANGED
@@ -1225,38 +1225,39 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
1225
 
1226
  if use_KDS:
1227
 
1228
- # 2) Compute x₀ prediction
1229
  beta_t = 1 - self.scheduler.alphas_cumprod[t]
1230
  alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
1231
  sigma_t = beta_t.sqrt()
1232
- x0_pred = (latents - sigma_t * noise_pred) / alpha_t
1233
-
1234
- # 3) Apply KDE steering
1235
- m_shift = kde_grad(x0_pred, bandwidth=bandwidth)
1236
- delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
1237
- x0_steer = x0_pred + delta_t * m_shift
1238
- # frac = i / (len(timesteps) - 1)
1239
- # delta_t = 0.0 if frac < 0.3 else 0.3
1240
- # x0_steer = x0_pred + delta_t * gamma_0 * m_shift
1241
 
 
 
1242
 
1243
- # 4) Recompute “noise” for DDIM step
 
 
 
 
 
 
 
 
1244
  noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
1245
 
1246
- # 5) Determine prev alphas
1247
  if i < len(timesteps) - 1:
1248
- next_t = timesteps[i + 1]
1249
- alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
1250
  else:
1251
- alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
1252
-
1253
  sigma_prev = (1 - alpha_prev**2).sqrt()
1254
 
1255
- # 6) Form next latent per DDIM
1256
  latents = (
1257
- alpha_prev * x0_steer
1258
- + sigma_prev * noise_pred_kds
1259
  ).detach().requires_grad_(True)
 
1260
  else:
1261
 
1262
  # compute the previous noisy sample x_t -> x_t-1
 
1225
 
1226
  if use_KDS:
1227
 
1228
+ # 2) Compute x₀ prediction for all particles
1229
  beta_t = 1 - self.scheduler.alphas_cumprod[t]
1230
  alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
1231
  sigma_t = beta_t.sqrt()
1232
+ x0_pred = (latents - sigma_t * noise_pred) / alpha_t # shape [2N, C, H, W]
 
 
 
 
 
 
 
 
1233
 
1234
+ # — split into unconditional vs. conditional
1235
+ x0_uncond, x0_cond = x0_pred.chunk(2, dim=0) # each [N, C, H, W]
1236
 
1237
+ # 3) Apply KDE steering *only* on the conditional batch
1238
+ m_shift_cond = kde_grad(x0_cond, bandwidth=bandwidth) # [N, C, H, W]
1239
+ delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
1240
+ x0_cond_steer = x0_cond + delta_t * m_shift_cond # steered conditional
1241
+
1242
+ # 4) Recombine the latents: leave uncond untouched, use steered cond
1243
+ x0_steer = torch.cat([x0_uncond, x0_cond_steer], dim=0) # [2N, C, H, W]
1244
+
1245
+ # 5) Recompute “noise” for DDIM step
1246
  noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
1247
 
1248
+ # 6) Determine prev alphas and form next latent per DDIM
1249
  if i < len(timesteps) - 1:
1250
+ next_t = timesteps[i + 1]
1251
+ alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
1252
  else:
1253
+ alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
 
1254
  sigma_prev = (1 - alpha_prev**2).sqrt()
1255
 
 
1256
  latents = (
1257
+ alpha_prev * x0_steer
1258
+ + sigma_prev * noise_pred_kds
1259
  ).detach().requires_grad_(True)
1260
+
1261
  else:
1262
 
1263
  # compute the previous noisy sample x_t -> x_t-1