Update src/flux/generate.py
Browse files- src/flux/generate.py +2 -1
src/flux/generate.py
CHANGED
@@ -515,6 +515,7 @@ def generate(
|
|
515 |
def generate_from_test_sample(
|
516 |
test_sample, pipe, config,
|
517 |
num_images=1,
|
|
|
518 |
vae_skip_iter: str = None,
|
519 |
target_height: int = None,
|
520 |
target_width: int = None,
|
@@ -708,7 +709,6 @@ def generate_from_test_sample(
|
|
708 |
return delta_emb, delta_emb_pblock, delta_emb_mask, \
|
709 |
text_cond_mask, delta_start_ends, condition_latents, condition_ids
|
710 |
|
711 |
-
num_inference_steps = 8 # FIXME: harcoded here
|
712 |
num_channels_latents = pipe.transformer.config.in_channels // 4
|
713 |
|
714 |
# set timesteps
|
@@ -801,6 +801,7 @@ def generate_from_test_sample(
|
|
801 |
result_img = generate(
|
802 |
pipe,
|
803 |
prompt=prompt,
|
|
|
804 |
max_sequence_length=max_length,
|
805 |
vae_conditions=conditions,
|
806 |
generator=generator,
|
|
|
515 |
def generate_from_test_sample(
|
516 |
test_sample, pipe, config,
|
517 |
num_images=1,
|
518 |
+
num_inference_steps = 8,
|
519 |
vae_skip_iter: str = None,
|
520 |
target_height: int = None,
|
521 |
target_width: int = None,
|
|
|
709 |
return delta_emb, delta_emb_pblock, delta_emb_mask, \
|
710 |
text_cond_mask, delta_start_ends, condition_latents, condition_ids
|
711 |
|
|
|
712 |
num_channels_latents = pipe.transformer.config.in_channels // 4
|
713 |
|
714 |
# set timesteps
|
|
|
801 |
result_img = generate(
|
802 |
pipe,
|
803 |
prompt=prompt,
|
804 |
+
num_inference_steps=num_inference_steps,
|
805 |
max_sequence_length=max_length,
|
806 |
vae_conditions=conditions,
|
807 |
generator=generator,
|