Disty0 commited on
Commit
1ad6f7a
·
1 Parent(s): fe9a8f4
Files changed (1) hide show
  1. pipeline.py +3 -9
pipeline.py CHANGED
@@ -332,9 +332,7 @@ class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
332
  self,
333
  image: Image.Image = None,
334
  prompt = "",
335
- negative_prompt = "",
336
  prompt_embeds = None,
337
- negative_prompt_embeds = None,
338
  *args,
339
  num_images_per_prompt: Optional[int] = 1,
340
  guidance_scale=4.0,
@@ -368,19 +366,15 @@ class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
368
  global_embeds = encoded.image_embeds
369
  global_embeds = global_embeds.unsqueeze(-2)
370
 
371
- if prompt_embeds is None and negative_prompt_embeds is None:
372
  prompt_embeds, negative_prompt_embeds = self.encode_prompt(
373
  prompt,
374
  self.device,
375
  num_images_per_prompt,
376
  False,
377
- negative_prompt=negative_prompt,
378
  )
379
- encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds])
380
- ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
381
- encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
382
- else:
383
- encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds])
384
  cak = dict(cond_lat=cond_lat)
385
  if hasattr(self.unet, "controlnet"):
386
  cak['control_depth'] = depth_image
 
332
  self,
333
  image: Image.Image = None,
334
  prompt = "",
 
335
  prompt_embeds = None,
 
336
  *args,
337
  num_images_per_prompt: Optional[int] = 1,
338
  guidance_scale=4.0,
 
366
  global_embeds = encoded.image_embeds
367
  global_embeds = global_embeds.unsqueeze(-2)
368
 
369
+ if prompt_embeds is None:
370
  prompt_embeds, negative_prompt_embeds = self.encode_prompt(
371
  prompt,
372
  self.device,
373
  num_images_per_prompt,
374
  False,
 
375
  )
376
+ ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
377
+ encoder_hidden_states = torch.cat([prompt_embeds, torch.zeros(prompt_embeds.shape)]) + global_embeds * ramp
 
 
 
378
  cak = dict(cond_lat=cond_lat)
379
  if hasattr(self.unet, "controlnet"):
380
  cak['control_depth'] = depth_image