alexnasa commited on
Commit
d45e23c
·
verified ·
1 Parent(s): 1fe35e7

Update pipelines/pipeline_seesr.py

Browse files
Files changed (1) hide show
  1. pipelines/pipeline_seesr.py +20 -0
pipelines/pipeline_seesr.py CHANGED
@@ -1258,6 +1258,26 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
1258
  alpha_prev * x0_steer
1259
  + sigma_prev * noise_pred_kds
1260
  ).detach().requires_grad_(True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1261
 
1262
  else:
1263
 
 
1258
  alpha_prev * x0_steer
1259
  + sigma_prev * noise_pred_kds
1260
  ).detach().requires_grad_(True)
1261
+
1262
+ uncond_latents, cond_latents = latents.chunk(2, dim=0) # each is [N, C, H, W]
1263
+
1264
+ # 1) Compute ensemble mean of the conditional latents
1265
+ mean_cond = cond_latents.mean(dim=0, keepdim=True) # shape [1, C, H, W]
1266
+
1267
+ # 2) Compute squared distances to the mean for each particle
1268
+ # Flatten each latent to [N, C*H*W], then sum-of-squares
1269
+ dists = ((cond_latents - mean_cond).view(cond_latents.size(0), -1) ** 2).sum(dim=1) # [N]
1270
+
1271
+ # 3) Find the index of the particle closest to the mean
1272
+ best_idx = dists.argmin().item()
1273
+
1274
+ # 4) Select that one latent
1275
+ best_latent = cond_latents[best_idx : best_idx + 1] # shape [1, C, H, W]
1276
+
1277
+ # (Optional) If you need to keep classifier-free guidance structure,
1278
+ # you can reconstruct a 2-sample batch with its uncond pair:
1279
+ best_uncond = uncond_latents[best_idx : best_idx + 1]
1280
+ latents = torch.cat([best_uncond, best_latent], dim=0) # shape [2, C, H, W]
1281
 
1282
  else:
1283