Update pipelines/pipeline_seesr.py
pipelines/pipeline_seesr.py  +20 -0
@@ -1258,6 +1258,26 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                     alpha_prev * x0_steer
                     + sigma_prev * noise_pred_kds
                 ).detach().requires_grad_(True)
+
+                uncond_latents, cond_latents = latents.chunk(2, dim=0)  # each is [N, C, H, W]
+
+                # 1) Compute ensemble mean of the conditional latents
+                mean_cond = cond_latents.mean(dim=0, keepdim=True)  # shape [1, C, H, W]
+
+                # 2) Compute squared distances to the mean for each particle
+                #    Flatten each latent to [N, C*H*W], then sum-of-squares
+                dists = ((cond_latents - mean_cond).view(cond_latents.size(0), -1) ** 2).sum(dim=1)  # [N]
+
+                # 3) Find the index of the particle closest to the mean
+                best_idx = dists.argmin().item()
+
+                # 4) Select that one latent
+                best_latent = cond_latents[best_idx : best_idx + 1]  # shape [1, C, H, W]
+
+                # (Optional) If you need to keep classifier-free guidance structure,
+                # you can reconstruct a 2-sample batch with its uncond pair:
+                best_uncond = uncond_latents[best_idx : best_idx + 1]
+                latents = torch.cat([best_uncond, best_latent], dim=0)  # shape [2, C, H, W]

             else:
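Taken on its own, the added hunk performs a "closest to the ensemble mean" particle selection on the conditional half of a classifier-free-guidance batch. The sketch below reproduces just that selection step with randomly generated stand-in latents; the shapes (N particles, 4-channel 64x64 latents) are illustrative assumptions, not values taken from the pipeline.

```python
import torch

# Assumed, illustrative shapes: N particles of 4-channel 64x64 latents.
N, C, H, W = 4, 4, 64, 64

# A classifier-free-guidance batch stacks unconditional and conditional
# latents along dim 0: [2*N, C, H, W].
latents = torch.randn(2 * N, C, H, W)
uncond_latents, cond_latents = latents.chunk(2, dim=0)       # each [N, C, H, W]

# Ensemble mean of the conditional latents.
mean_cond = cond_latents.mean(dim=0, keepdim=True)           # [1, C, H, W]

# Squared L2 distance of every particle to that mean.
dists = ((cond_latents - mean_cond).view(N, -1) ** 2).sum(dim=1)  # [N]

# Keep the particle closest to the mean, together with its uncond pair,
# so the [uncond, cond] batch layout is preserved.
best_idx = dists.argmin().item()
best_latent = cond_latents[best_idx : best_idx + 1]          # [1, C, H, W]
best_uncond = uncond_latents[best_idx : best_idx + 1]        # [1, C, H, W]
latents = torch.cat([best_uncond, best_latent], dim=0)       # [2, C, H, W]

print(best_idx, latents.shape)
```

As the diff's own comment notes, carrying the matching unconditional latent along with the selected conditional one keeps the two-sample batch structure intact, so a later `latents.chunk(2, dim=0)` for classifier-free guidance still yields a matched uncond/cond pair.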