Spaces:
Runtime error
Runtime error
Update stable_cascade.py
Browse files- stable_cascade.py +13 -0
stable_cascade.py
CHANGED
|
@@ -13,6 +13,7 @@ def generate_images(
|
|
| 13 |
height=1024,
|
| 14 |
width=1024,
|
| 15 |
guidance_scale=4.0,
|
|
|
|
| 16 |
num_images_per_prompt=1,
|
| 17 |
prior_inference_steps=20,
|
| 18 |
decoder_inference_steps=10
|
|
@@ -30,10 +31,12 @@ def generate_images(
|
|
| 30 |
Returns:
|
| 31 |
- List[PIL.Image]: A list of generated PIL Image objects.
|
| 32 |
"""
|
|
|
|
| 33 |
|
| 34 |
# Generate image embeddings using the prior model
|
| 35 |
prior_output = prior(
|
| 36 |
prompt=prompt,
|
|
|
|
| 37 |
height=height,
|
| 38 |
width=width,
|
| 39 |
negative_prompt=negative_prompt,
|
|
@@ -46,6 +49,7 @@ def generate_images(
|
|
| 46 |
decoder_output = decoder(
|
| 47 |
image_embeddings=prior_output.image_embeddings.half(),
|
| 48 |
prompt=prompt,
|
|
|
|
| 49 |
negative_prompt=negative_prompt,
|
| 50 |
guidance_scale=0.0, # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
|
| 51 |
output_type="pil",
|
|
@@ -70,6 +74,15 @@ def web_demo():
|
|
| 70 |
placeholder="Negative Prompt",
|
| 71 |
show_label=False,
|
| 72 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
with gr.Row():
|
| 74 |
with gr.Column():
|
| 75 |
text2image_num_images_per_prompt = gr.Slider(
|
|
|
|
| 13 |
height=1024,
|
| 14 |
width=1024,
|
| 15 |
guidance_scale=4.0,
|
| 16 |
+
seed=42,
|
| 17 |
num_images_per_prompt=1,
|
| 18 |
prior_inference_steps=20,
|
| 19 |
decoder_inference_steps=10
|
|
|
|
| 31 |
Returns:
|
| 32 |
- List[PIL.Image]: A list of generated PIL Image objects.
|
| 33 |
"""
|
| 34 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
| 35 |
|
| 36 |
# Generate image embeddings using the prior model
|
| 37 |
prior_output = prior(
|
| 38 |
prompt=prompt,
|
| 39 |
+
generator=generator,
|
| 40 |
height=height,
|
| 41 |
width=width,
|
| 42 |
negative_prompt=negative_prompt,
|
|
|
|
| 49 |
decoder_output = decoder(
|
| 50 |
image_embeddings=prior_output.image_embeddings.half(),
|
| 51 |
prompt=prompt,
|
| 52 |
+
generator=generator,
|
| 53 |
negative_prompt=negative_prompt,
|
| 54 |
guidance_scale=0.0, # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
|
| 55 |
output_type="pil",
|
|
|
|
| 74 |
placeholder="Negative Prompt",
|
| 75 |
show_label=False,
|
| 76 |
)
|
| 77 |
+
|
| 78 |
+
text2image_num_images_per_prompt = gr.Slider(
|
| 79 |
+
minimum=1,
|
| 80 |
+
maximum=1000000,
|
| 81 |
+
step=10,
|
| 82 |
+
value=42,
|
| 83 |
+
label="Seed",
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
with gr.Row():
|
| 87 |
with gr.Column():
|
| 88 |
text2image_num_images_per_prompt = gr.Slider(
|