Spaces:
Paused
Paused
Update clip_slider_pipeline.py
Browse files- clip_slider_pipeline.py +5 -4
clip_slider_pipeline.py
CHANGED
|
@@ -18,7 +18,7 @@ class CLIPSlider:
|
|
| 18 |
):
|
| 19 |
|
| 20 |
self.device = device
|
| 21 |
-
self.pipe = sd_pipe.to(self.device)
|
| 22 |
self.iterations = iterations
|
| 23 |
if target_word != "" or opposite != "":
|
| 24 |
self.avg_diff = self.find_latent_direction(target_word, opposite)
|
|
@@ -280,13 +280,14 @@ class CLIPSliderXL(CLIPSlider):
|
|
| 280 |
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
| 281 |
prompt_embeds_list.append(prompt_embeds)
|
| 282 |
|
| 283 |
-
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
| 284 |
-
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
| 285 |
end_time = time.time()
|
|
|
|
| 286 |
print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
|
| 287 |
torch.manual_seed(seed)
|
| 288 |
start_time = time.time()
|
| 289 |
-
image = self.pipe(prompt_embeds=prompt_embeds
|
| 290 |
**pipeline_kwargs).images[0]
|
| 291 |
end_time = time.time()
|
| 292 |
print(f"generation time - pipe: {end_time - start_time:.2f} ms")
|
|
|
|
| 18 |
):
|
| 19 |
|
| 20 |
self.device = device
|
| 21 |
+
self.pipe = sd_pipe.to(self.device, torch.float16)
|
| 22 |
self.iterations = iterations
|
| 23 |
if target_word != "" or opposite != "":
|
| 24 |
self.avg_diff = self.find_latent_direction(target_word, opposite)
|
|
|
|
| 280 |
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
| 281 |
prompt_embeds_list.append(prompt_embeds)
|
| 282 |
|
| 283 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(torch.float16)
|
| 284 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1).to(torch.float16)
|
| 285 |
end_time = time.time()
|
| 286 |
+
print("prompt_embeds", prompt_embeds.dtype)
|
| 287 |
print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
|
| 288 |
torch.manual_seed(seed)
|
| 289 |
start_time = time.time()
|
| 290 |
+
image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|
| 291 |
**pipeline_kwargs).images[0]
|
| 292 |
end_time = time.time()
|
| 293 |
print(f"generation time - pipe: {end_time - start_time:.2f} ms")
|