#!/usr/bin/env python
from __future__ import annotations
import os
import random
import time
import gradio as gr
import numpy as np
import PIL.Image
import PIL.PngImagePlugin
import torch
from diffusers import AutoPipelineForText2Image
from concurrent.futures import ThreadPoolExecutor
import uuid
model_id = "Lykon/dreamshaper-xl-v2-turbo"
DESCRIPTION = '''# Fast Stable Diffusion CPU with Latent Consistency Model
Distilled from a [Dreamshaper v7](https://huggingface.co/Lykon/dreamshaper-7) fine-tune of SD v1.5.
'''
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU.</p>"
MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "768"))
DTYPE = torch.float32
# Load the pipeline once at startup. AutoPipelineForText2Image dispatches to the
# right pipeline class for the checkpoint; the model id above is an SDXL-Turbo
# checkpoint, which StableDiffusionPipeline cannot load. safety_checker=None
# disables the NSFW filter on SD 1.x checkpoints; SDXL pipelines have no safety
# checker, so the kwarg is ignored (with a warning) there.
pipe = AutoPipelineForText2Image.from_pretrained(
    model_id,
    safety_checker=None,
    torch_dtype=DTYPE,
    use_safetensors=True,
).to("cpu")
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
def save_image(img: PIL.Image.Image, profile: gr.OAuthProfile | None, metadata: dict) -> str:
    # Embed the generation parameters as PNG text chunks so they travel with the file.
    info = PIL.PngImagePlugin.PngInfo()
    for key, value in metadata.items():
        info.add_text(key, str(value))
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name, pnginfo=info)
    return unique_name
def save_images(image_array, profile: gr.OAuthProfile | None, metadata: dict) -> list[str]:
    # Save all images concurrently; saving is mostly disk I/O, so threads help.
    with ThreadPoolExecutor() as executor:
        return list(
            executor.map(lambda img: save_image(img, profile, metadata), image_array)
        )
def generate(
    prompt: str,
    seed: int = 0,
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 8.0,
    num_inference_steps: int = 4,
    num_images: int = 1,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
    profile: gr.OAuthProfile | None = None,
) -> tuple[list[str], int]:
    # Prepare the seed (optionally re-randomized per run) and seed torch's RNG.
    seed = randomize_seed_fn(seed, randomize_seed)
    torch.manual_seed(seed)
    start_time = time.time()
    # Call the pipeline with only kwargs the text-to-image pipelines support.
    outputs = pipe(
        prompt=prompt,
        negative_prompt="",  # pass an empty string, not None, to avoid a NoneType error in the UNet
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        output_type="pil",
    ).images
    latency = time.time() - start_time
    print(f"Generation took {latency:.2f} seconds")
    paths = save_images(
        outputs,
        profile,
        metadata={
            "prompt": prompt,
            "seed": seed,
            "width": width,
            "height": height,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
        },
    )
    return paths, seed
examples = [
    "A futuristic cityscape at sunset",
    "Steampunk airship over mountains",
    "Portrait of a cyborg queen, hyper-detailed",
]
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                placeholder="Enter your prompt",
                show_label=False,
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=2,  # `grid` was removed in Gradio 4; `columns` is the current kwarg
        )
    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(0, MAX_SEED, value=0, step=1, randomize=True, label="Seed")
        randomize_seed = gr.Checkbox(label="Randomize seed across runs", value=True)
        with gr.Row():
            width = gr.Slider(256, MAX_IMAGE_SIZE, value=512, step=32, label="Width")
            height = gr.Slider(256, MAX_IMAGE_SIZE, value=512, step=32, label="Height")
        with gr.Row():
            guidance_scale = gr.Slider(2.0, 14.0, value=8.0, step=0.1, label="Guidance Scale")
            num_inference_steps = gr.Slider(1, 8, value=4, step=1, label="Inference Steps")
        num_images = gr.Slider(1, 8, value=1, step=1, label="Number of Images")
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[gallery, seed],  # generate returns (paths, seed), so two outputs are needed
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )
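    # Wire the Run button and the prompt box's Enter key to generate(): a minimal
    # hookup sketch. The component order must match generate()'s positional
    # signature; the gr.Progress and gr.OAuthProfile parameters are injected by
    # Gradio itself and are not listed here.
    generate_inputs = [
        prompt,
        seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        num_images,
        randomize_seed,
    ]
    prompt.submit(fn=generate, inputs=generate_inputs, outputs=[gallery, seed])
    run_button.click(fn=generate, inputs=generate_inputs, outputs=[gallery, seed])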
demo.queue()
demo.launch()