alexnasa committed
Commit a37dbfb · verified · 1 Parent(s): 8a56135

Update pipelines/pipeline_seesr.py

Files changed (1)
  1. pipelines/pipeline_seesr.py +25 -27
pipelines/pipeline_seesr.py CHANGED
@@ -1258,39 +1258,37 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                         alpha_prev * x0_steer
                         + sigma_prev * noise_pred_kds
                     ).detach().requires_grad_(True)
-
-                    uncond_latents, cond_latents = latents.chunk(2, dim=0)  # each is [N, C, H, W]
-
-                    # 1) Compute ensemble mean of the conditional latents
-                    mean_cond = cond_latents.mean(dim=0, keepdim=True)  # shape [1, C, H, W]
-
-                    # 2) Compute squared distances to the mean for each particle
-                    #    Flatten each latent to [N, C*H*W], then sum-of-squares
-                    dists = ((cond_latents - mean_cond).view(cond_latents.size(0), -1) ** 2).sum(dim=1)  # [N]
-
-                    # 3) Find the index of the particle closest to the mean
-                    best_idx = dists.argmin().item()
-
-                    # 4) Select that one latent
-                    best_latent = cond_latents[best_idx : best_idx + 1]  # shape [1, C, H, W]
-
-                    # (Optional) If you need to keep classifier-free guidance structure,
-                    # you can reconstruct a 2-sample batch with its uncond pair:
-                    best_uncond = uncond_latents[best_idx : best_idx + 1]
-                    latents = torch.cat([best_uncond, best_latent], dim=0)  # shape [2, C, H, W]
-
                 else:
 
                     # compute the previous noisy sample x_t -> x_t-1
                     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
-
         with torch.no_grad():
+            if use_KDS:
+                # Final-latent selection (once!)
+                # latents shape: [2*N, C, H, W]
+                uncond_latents, cond_latents = latents.chunk(2, dim=0)  # each [N, C, H, W]
+                # 1) ensemble mean
+                mean_cond = cond_latents.mean(dim=0, keepdim=True)  # [1, C, H, W]
+                # 2) distances
+                dists = ((cond_latents - mean_cond)
+                         .view(cond_latents.size(0), -1)
+                         .pow(2)
+                         .sum(dim=1))  # [N]
+                # 3) best index
+                best_idx = dists.argmin().item()
+                # 4) select that latent (and its uncond pair)
+                best_uncond = uncond_latents[best_idx:best_idx + 1]
+                best_cond = cond_latents[best_idx:best_idx + 1]
+                latents = torch.cat([best_uncond, best_cond], dim=0)  # [2, C, H, W]
+
+            # call the callback, if provided
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                progress_bar.update()
+                if callback is not None and i % callback_steps == 0:
+                    callback(i, t, latents)
+
+
             # If we do sequential model offloading, let's offload unet and controlnet
             # manually for max memory savings
             if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
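
For reference, the selection that the added block performs once after denoising can be read as a standalone helper. The sketch below is illustrative only and is not code from the pipeline: the function name select_particle_closest_to_mean and the toy shapes in the example are invented here, and it assumes the same [2*N, C, H, W] classifier-free-guidance layout (uncond half first, cond half second) that the added comments describe.

import torch

def select_particle_closest_to_mean(latents: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch (not part of the pipeline): given KDS particles stacked
    # as [2*N, C, H, W] with the uncond half first and the cond half second,
    # keep the cond particle nearest the ensemble mean together with its uncond
    # pair, and return a [2, C, H, W] batch.
    uncond_latents, cond_latents = latents.chunk(2, dim=0)   # each [N, C, H, W]
    mean_cond = cond_latents.mean(dim=0, keepdim=True)       # [1, C, H, W]
    dists = ((cond_latents - mean_cond)
             .view(cond_latents.size(0), -1)
             .pow(2)
             .sum(dim=1))                                    # [N]
    best_idx = dists.argmin().item()
    best_uncond = uncond_latents[best_idx:best_idx + 1]
    best_cond = cond_latents[best_idx:best_idx + 1]
    return torch.cat([best_uncond, best_cond], dim=0)        # [2, C, H, W]

if __name__ == "__main__":
    # Toy check: N = 4 particles of [4, 64, 64] latents, stacked as a CFG batch.
    dummy = torch.randn(8, 4, 64, 64)
    print(select_particle_closest_to_mean(dummy).shape)  # torch.Size([2, 4, 64, 64])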