Spaces:
Running
on
Zero
Running
on
Zero
Update pipelines/pipeline_seesr.py
Browse files- pipelines/pipeline_seesr.py +19 -0
pipelines/pipeline_seesr.py
CHANGED
@@ -1270,6 +1270,25 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
1270 |
callback(i, t, latents)
|
1271 |
|
1272 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1273 |
# If we do sequential model offloading, let's offload unet and controlnet
|
1274 |
# manually for max memory savings
|
1275 |
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
|
|
1270 |
callback(i, t, latents)
|
1271 |
|
1272 |
with torch.no_grad():
|
1273 |
+
|
1274 |
+
if use_KDS:
|
1275 |
+
# Final-latent selection (once!)
|
1276 |
+
# latents shape: [2*N, C, H, W]
|
1277 |
+
uncond_latents, cond_latents = latents.chunk(2, dim=0) # each [N, C, H, W]
|
1278 |
+
# 1) ensemble mean
|
1279 |
+
mean_cond = cond_latents.mean(dim=0, keepdim=True) # [1, C, H, W]
|
1280 |
+
# 2) distances
|
1281 |
+
dists = ((cond_latents - mean_cond)
|
1282 |
+
.view(cond_latents.size(0), -1)
|
1283 |
+
.pow(2)
|
1284 |
+
.sum(dim=1)) # [N]
|
1285 |
+
# 3) best index
|
1286 |
+
best_idx = dists.argmin().item()
|
1287 |
+
# 4) select that latent (and its uncond pair)
|
1288 |
+
best_uncond = uncond_latents[best_idx:best_idx+1]
|
1289 |
+
best_cond = cond_latents [best_idx:best_idx+1]
|
1290 |
+
latents = torch.cat([best_uncond, best_cond], dim=0) # [2, C, H, W]
|
1291 |
+
|
1292 |
# If we do sequential model offloading, let's offload unet and controlnet
|
1293 |
# manually for max memory savings
|
1294 |
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|