Disty0
commited on
Commit
·
1ad6f7a
1
Parent(s):
fe9a8f4
Test
Browse files- pipeline.py +3 -9
pipeline.py
CHANGED
@@ -332,9 +332,7 @@ class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
|
|
332 |
self,
|
333 |
image: Image.Image = None,
|
334 |
prompt = "",
|
335 |
-
negative_prompt = "",
|
336 |
prompt_embeds = None,
|
337 |
-
negative_prompt_embeds = None,
|
338 |
*args,
|
339 |
num_images_per_prompt: Optional[int] = 1,
|
340 |
guidance_scale=4.0,
|
@@ -368,19 +366,15 @@ class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
|
|
368 |
global_embeds = encoded.image_embeds
|
369 |
global_embeds = global_embeds.unsqueeze(-2)
|
370 |
|
371 |
-
if prompt_embeds is None
|
372 |
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
373 |
prompt,
|
374 |
self.device,
|
375 |
num_images_per_prompt,
|
376 |
False,
|
377 |
-
negative_prompt=negative_prompt,
|
378 |
)
|
379 |
-
|
380 |
-
|
381 |
-
encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
|
382 |
-
else:
|
383 |
-
encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds])
|
384 |
cak = dict(cond_lat=cond_lat)
|
385 |
if hasattr(self.unet, "controlnet"):
|
386 |
cak['control_depth'] = depth_image
|
|
|
332 |
self,
|
333 |
image: Image.Image = None,
|
334 |
prompt = "",
|
|
|
335 |
prompt_embeds = None,
|
|
|
336 |
*args,
|
337 |
num_images_per_prompt: Optional[int] = 1,
|
338 |
guidance_scale=4.0,
|
|
|
366 |
global_embeds = encoded.image_embeds
|
367 |
global_embeds = global_embeds.unsqueeze(-2)
|
368 |
|
369 |
+
if prompt_embeds is None:
|
370 |
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
371 |
prompt,
|
372 |
self.device,
|
373 |
num_images_per_prompt,
|
374 |
False,
|
|
|
375 |
)
|
376 |
+
ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
|
377 |
+
encoder_hidden_states = torch.cat([prompt_embeds, torch.zeros(prompt_embeds.shape)]) + global_embeds * ramp
|
|
|
|
|
|
|
378 |
cak = dict(cond_lat=cond_lat)
|
379 |
if hasattr(self.unet, "controlnet"):
|
380 |
cak['control_depth'] = depth_image
|