alexnasa committed
Commit 65bed02 · verified · 1 Parent(s): 7830888

Update pipelines/pipeline_seesr.py

Files changed (1)
  1. pipelines/pipeline_seesr.py +37 -4
pipelines/pipeline_seesr.py CHANGED

@@ -778,7 +778,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
 
     @perfcount
-    @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
@@ -808,7 +807,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         ram_encoder_hidden_states=None,
         latent_tiled_size=320,
         latent_tiled_overlap=4,
-        args=None
+        use_KDS=True,
+        args=None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -996,6 +996,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        if use_KDS:
+            latents.requires_grad_(True)
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1183,9 +1186,39 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                if use_KDS:
+
+                    # 2) Compute x₀ prediction
+                    beta_t = 1 - self.scheduler.alphas_cumprod[t]
+                    alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
+                    sigma_t = beta_t.sqrt()
+                    x0_pred = (latents - sigma_t * noise_pred) / alpha_t
+
+                    # 3) Apply KDE steering
+                    m_shift = kde_grad(x0_pred)
+                    delta_t = gamma_0 * (1 - i / (len(timesteps_tensor) - 1))
+                    x0_steer = x0_pred + delta_t * m_shift
+
+                    # 4) Recompute “noise” for DDIM step
+                    noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
+
+                    # 5) Determine prev alphas
+                    if i < len(timesteps_tensor) - 1:
+                        next_t = timesteps_tensor[i + 1]
+                        alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
+                    else:
+                        alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
+
+                    sigma_prev = (1 - alpha_prev**2).sqrt()
+
+                    # 6) Form next latent per DDIM
+                    latents = (
+                        alpha_prev * x0_steer
+                        + sigma_prev * noise_pred_kds
+                    ).detach().requires_grad_(True)
+                else:
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0].requires_grad_(False)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
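The new branch sidesteps self.scheduler.step() and performs a deterministic (eta = 0) DDIM update around a KDE-steered x₀ estimate; this is also why the commit drops @torch.no_grad(), so that latents.requires_grad_(True) can take effect. A minimal standalone sketch of the same update, with the schedule tensors and hooks passed in explicitly (variable names follow the diff, but the function wrapper itself is illustrative, not code from this repository):

import torch

def kds_ddim_step(latents, noise_pred, i, timesteps, alphas_cumprod,
                  final_alpha_cumprod, kde_grad, gamma_0):
    """Deterministic DDIM step around a KDE-steered x0 estimate (sketch)."""
    t = timesteps[i]
    alpha_t = alphas_cumprod[t].sqrt()
    sigma_t = (1 - alphas_cumprod[t]).sqrt()

    # Predict x0 from the current latent and the model's noise estimate.
    x0_pred = (latents - sigma_t * noise_pred) / alpha_t

    # Steer x0 along the KDE mean-shift direction; the step size decays
    # linearly from gamma_0 to zero over the sampling trajectory.
    delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
    x0_steer = x0_pred + delta_t * kde_grad(x0_pred)

    # Re-derive the noise that is consistent with the steered x0.
    noise_kds = (latents - alpha_t * x0_steer) / sigma_t

    # DDIM transition to the next (less noisy) timestep.
    if i < len(timesteps) - 1:
        alpha_prev = alphas_cumprod[timesteps[i + 1]].sqrt()
    else:
        alpha_prev = final_alpha_cumprod.sqrt()
    sigma_prev = (1 - alpha_prev ** 2).sqrt()

    return alpha_prev * x0_steer + sigma_prev * noise_kds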
 
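kde_grad is called but not defined anywhere in this diff, so it presumably lives elsewhere in the file or repository. For orientation only, a mean-shift-style gradient of a kernel density estimate over a batch of particles could look like the sketch below; the RBF kernel, the flattening, and the bandwidth h are assumptions, not the commit's implementation, and the shift is only non-trivial when several latents are sampled in one batch:

import torch

def kde_grad(x0_pred, h=0.5):
    # Hypothetical mean-shift direction under an RBF kernel density
    # estimate across the batch dimension (assumed, not from this commit).
    n = x0_pred.shape[0]
    flat = x0_pred.reshape(n, -1)                 # (N, D) particle matrix
    d2 = torch.cdist(flat, flat).pow(2)           # pairwise squared distances
    w = torch.softmax(-d2 / (2 * h ** 2), dim=1)  # normalized kernel weights
    mean = w @ flat                               # kernel-weighted neighbor mean
    return (mean - flat).reshape_as(x0_pred)      # points toward higher density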
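On the caller's side the steering is opt-in, and on by default, through the new keyword. A hypothetical invocation follows; apart from use_KDS, latent_tiled_size, and latent_tiled_overlap, which appear in the signature above, the argument names are illustrative:

images = pipeline(
    prompt="a detailed photo",   # illustrative argument
    image=low_res_image,         # illustrative argument
    num_inference_steps=50,      # illustrative argument
    latent_tiled_size=320,
    latent_tiled_overlap=4,
    use_KDS=True,  # KDE-steered DDIM updates; False falls back to scheduler.step()
)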