Update pipelines/pipeline_seesr.py
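Moves the KDS best-particle selection (the conditional latent closest to the ensemble mean) out of the per-step update so it runs once, under the final torch.no_grad() block, and moves the callback / progress-bar update along with it.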
pipelines/pipeline_seesr.py  CHANGED  +25 -27
@@ -1258,39 +1258,37 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                        alpha_prev * x0_steer
                        + sigma_prev * noise_pred_kds
                    ).detach().requires_grad_(True)
-
-                    uncond_latents, cond_latents = latents.chunk(2, dim=0)  # each is [N, C, H, W]
-
-                    # 1) Compute ensemble mean of the conditional latents
-                    mean_cond = cond_latents.mean(dim=0, keepdim=True)  # shape [1, C, H, W]
-
-                    # 2) Compute squared distances to the mean for each particle
-                    # Flatten each latent to [N, C*H*W], then sum-of-squares
-                    dists = ((cond_latents - mean_cond).view(cond_latents.size(0), -1) ** 2).sum(dim=1)  # [N]
-
-                    # 3) Find the index of the particle closest to the mean
-                    best_idx = dists.argmin().item()
-
-                    # 4) Select that one latent
-                    best_latent = cond_latents[best_idx : best_idx + 1]  # shape [1, C, H, W]
-
-                    # (Optional) If you need to keep classifier-free guidance structure,
-                    # you can reconstruct a 2-sample batch with its uncond pair:
-                    best_uncond = uncond_latents[best_idx : best_idx + 1]
-                    latents = torch.cat([best_uncond, best_latent], dim=0)  # shape [2, C, H, W]
-
                else:

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
-
        with torch.no_grad():
+            if use_KDS:
+                # Final-latent selection (once!)
+                # latents shape: [2*N, C, H, W]
+                uncond_latents, cond_latents = latents.chunk(2, dim=0)  # each [N, C, H, W]
+                # 1) ensemble mean
+                mean_cond = cond_latents.mean(dim=0, keepdim=True)  # [1, C, H, W]
+                # 2) distances
+                dists = ((cond_latents - mean_cond)
+                         .view(cond_latents.size(0), -1)
+                         .pow(2)
+                         .sum(dim=1))  # [N]
+                # 3) best index
+                best_idx = dists.argmin().item()
+                # 4) select that latent (and its uncond pair)
+                best_uncond = uncond_latents[best_idx:best_idx+1]
+                best_cond = cond_latents[best_idx:best_idx+1]
+                latents = torch.cat([best_uncond, best_cond], dim=0)  # [2, C, H, W]
+
+            # call the callback, if provided
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                progress_bar.update()
+                if callback is not None and i % callback_steps == 0:
+                    callback(i, t, latents)
+
+
            # If we do sequential model offloading, let's offload unet and controlnet
            # manually for max memory savings
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
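
For reference, the selection that the new if use_KDS: block performs can be read as the standalone sketch below. It mirrors the added lines (split the classifier-free-guidance batch, take the conditional particle closest to the ensemble mean, rebuild a 2-sample batch with its unconditional pair). The function name select_best_particle and the toy shapes in the usage example are illustrative assumptions, not part of pipeline_seesr.py.

# Minimal sketch of the final-latent selection added in this commit.
# Assumes latents is a CFG batch of shape [2*N, C, H, W]
# (first N unconditional, last N conditional), as in the diff above.
import torch


def select_best_particle(latents: torch.Tensor) -> torch.Tensor:
    """Keep only the conditional particle closest to the ensemble mean,
    together with its unconditional pair, preserving the [2, C, H, W] CFG layout."""
    uncond_latents, cond_latents = latents.chunk(2, dim=0)  # each [N, C, H, W]

    # 1) ensemble mean of the conditional particles
    mean_cond = cond_latents.mean(dim=0, keepdim=True)      # [1, C, H, W]

    # 2) squared L2 distance of each particle to the mean
    dists = ((cond_latents - mean_cond)
             .view(cond_latents.size(0), -1)
             .pow(2)
             .sum(dim=1))                                    # [N]

    # 3) index of the particle closest to the mean
    best_idx = dists.argmin().item()

    # 4) rebuild a 2-sample batch with the matching uncond latent
    best_uncond = uncond_latents[best_idx:best_idx + 1]
    best_cond = cond_latents[best_idx:best_idx + 1]
    return torch.cat([best_uncond, best_cond], dim=0)        # [2, C, H, W]


# Toy usage: 4 particles of a 4x64x64 latent, duplicated for CFG.
if __name__ == "__main__":
    particles = torch.randn(8, 4, 64, 64)    # [2*N, C, H, W] with N = 4
    selected = select_best_particle(particles)
    print(selected.shape)                     # torch.Size([2, 4, 64, 64])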