Spaces:
Runtime error
Runtime error
Update pipeline.py
Browse files- pipeline.py +6 -1
pipeline.py
CHANGED
|
@@ -42,6 +42,7 @@ if is_torch_xla_available():
|
|
| 42 |
else:
|
| 43 |
XLA_AVAILABLE = False
|
| 44 |
|
|
|
|
| 45 |
|
| 46 |
# Constants for shift calculation
|
| 47 |
BASE_SEQ_LEN = 256
|
|
@@ -188,6 +189,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 188 |
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 189 |
)
|
| 190 |
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
|
|
|
| 191 |
|
| 192 |
# Use pooled output of CLIPTextModel
|
| 193 |
prompt_embeds = prompt_embeds.pooler_output
|
|
@@ -196,8 +198,11 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 196 |
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 197 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
| 198 |
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
-
return prompt_embeds
|
| 201 |
|
| 202 |
def encode_prompt(
|
| 203 |
self,
|
|
|
|
| 42 |
else:
|
| 43 |
XLA_AVAILABLE = False
|
| 44 |
|
| 45 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 46 |
|
| 47 |
# Constants for shift calculation
|
| 48 |
BASE_SEQ_LEN = 256
|
|
|
|
| 189 |
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 190 |
)
|
| 191 |
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 192 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 193 |
|
| 194 |
# Use pooled output of CLIPTextModel
|
| 195 |
prompt_embeds = prompt_embeds.pooler_output
|
|
|
|
| 198 |
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 199 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
| 200 |
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 201 |
+
|
| 202 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 203 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 204 |
|
| 205 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 206 |
|
| 207 |
def encode_prompt(
|
| 208 |
self,
|