Spaces:
Running
on
Zero
Running
on
Zero
Update pipelines/pipeline_seesr.py
Browse files- pipelines/pipeline_seesr.py +19 -0
pipelines/pipeline_seesr.py
CHANGED
|
@@ -1270,6 +1270,25 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
| 1270 |
callback(i, t, latents)
|
| 1271 |
|
| 1272 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1273 |
# If we do sequential model offloading, let's offload unet and controlnet
|
| 1274 |
# manually for max memory savings
|
| 1275 |
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
|
|
|
| 1270 |
callback(i, t, latents)
|
| 1271 |
|
| 1272 |
with torch.no_grad():
|
| 1273 |
+
|
| 1274 |
+
if use_KDS:
|
| 1275 |
+
# Final-latent selection (once!)
|
| 1276 |
+
# latents shape: [2*N, C, H, W]
|
| 1277 |
+
uncond_latents, cond_latents = latents.chunk(2, dim=0) # each [N, C, H, W]
|
| 1278 |
+
# 1) ensemble mean
|
| 1279 |
+
mean_cond = cond_latents.mean(dim=0, keepdim=True) # [1, C, H, W]
|
| 1280 |
+
# 2) distances
|
| 1281 |
+
dists = ((cond_latents - mean_cond)
|
| 1282 |
+
.view(cond_latents.size(0), -1)
|
| 1283 |
+
.pow(2)
|
| 1284 |
+
.sum(dim=1)) # [N]
|
| 1285 |
+
# 3) best index
|
| 1286 |
+
best_idx = dists.argmin().item()
|
| 1287 |
+
# 4) select that latent (and its uncond pair)
|
| 1288 |
+
best_uncond = uncond_latents[best_idx:best_idx+1]
|
| 1289 |
+
best_cond = cond_latents [best_idx:best_idx+1]
|
| 1290 |
+
latents = torch.cat([best_uncond, best_cond], dim=0) # [2, C, H, W]
|
| 1291 |
+
|
| 1292 |
# If we do sequential model offloading, let's offload unet and controlnet
|
| 1293 |
# manually for max memory savings
|
| 1294 |
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|