alexnasa commited on
Commit
0d8bd11
·
verified ·
1 Parent(s): 58bb60e

Update pipelines/pipeline_seesr.py

Browse files
Files changed (1) hide show
  1. pipelines/pipeline_seesr.py +19 -0
pipelines/pipeline_seesr.py CHANGED
@@ -1270,6 +1270,25 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
1270
  callback(i, t, latents)
1271
 
1272
  with torch.no_grad():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1273
  # If we do sequential model offloading, let's offload unet and controlnet
1274
  # manually for max memory savings
1275
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
 
1270
  callback(i, t, latents)
1271
 
1272
  with torch.no_grad():
1273
+
1274
+ if use_KDS:
1275
+ # Final-latent selection (once!)
1276
+ # latents shape: [2*N, C, H, W]
1277
+ uncond_latents, cond_latents = latents.chunk(2, dim=0) # each [N, C, H, W]
1278
+ # 1) ensemble mean
1279
+ mean_cond = cond_latents.mean(dim=0, keepdim=True) # [1, C, H, W]
1280
+ # 2) distances
1281
+ dists = ((cond_latents - mean_cond)
1282
+ .view(cond_latents.size(0), -1)
1283
+ .pow(2)
1284
+ .sum(dim=1)) # [N]
1285
+ # 3) best index
1286
+ best_idx = dists.argmin().item()
1287
+ # 4) select that latent (and its uncond pair)
1288
+ best_uncond = uncond_latents[best_idx:best_idx+1]
1289
+ best_cond = cond_latents [best_idx:best_idx+1]
1290
+ latents = torch.cat([best_uncond, best_cond], dim=0) # [2, C, H, W]
1291
+
1292
  # If we do sequential model offloading, let's offload unet and controlnet
1293
  # manually for max memory savings
1294
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: