from __future__ import annotations

import ast
import os
import random
import time
import uuid
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import gradio_user_history as gr_user_history
import numpy as np
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from huggingface_hub import HfApi, snapshot_download
#DESCRIPTION = '''# Fast Stable Diffusion CPU with Latent Consistency Model
#Distilled from [Dreamshaper v7](https://huggingface.co/Lykon/dreamshaper-7), a fine-tune of SD v1-5.
#'''
#if not torch.cuda.is_available():
#    DESCRIPTION += "\n<p>Running on CPU.</p>"
MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "768"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
DTYPE = torch.float32 # torch.float16 works as well, but pictures seem to be a bit worse
api = HfApi()
executor = ThreadPoolExecutor()
model_cache = {}
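# executor runs repo downloads on background threads; model_cache maps a
# repo id to its local snapshot path so repeated loads return instantly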
#custom pipeline (loaded for experimentation; generate() below uses `pipe`)
model_id = "Lykon/dreamshaper-xl-v2-turbo"
custom_pipe = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="latent_consistency_txt2img", custom_revision="main")
#1st / main pipeline: LCM-distilled Dreamshaper v7 on CPU, NSFW filter disabled
pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_txt2img", custom_revision="main")
pipe.to(torch_device="cpu", torch_dtype=DTYPE)
pipe.safety_checker = None
# Disabled: this reload overwrote the LCM pipe, and StableDiffusionPipeline
# does not accept the lcm_origin_steps kwarg that generate() passes
#pipe = StableDiffusionPipeline.from_pretrained(
#    model_id, safety_checker=None, torch_dtype=DTYPE, use_safetensors=True).to("cpu")
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
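# e.g. randomize_seed_fn(42, False) returns 42 unchanged, while
# randomize_seed_fn(42, True) returns a fresh value in [0, MAX_SEED]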
def save_image(img, profile: gr.OAuthProfile | None, metadata: dict):
    # Write the image to a uniquely named file and record it in the user's history
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    gr_user_history.save_image(label=metadata["prompt"], image=img, profile=profile, metadata=metadata)
    return unique_name
def save_images(image_array, profile: gr.OAuthProfile | None, metadata: dict):
    # Save all images concurrently; saving is I/O-bound, so threads help
    with ThreadPoolExecutor() as pool:
        paths = list(pool.map(save_image, image_array, [profile] * len(image_array), [metadata] * len(image_array)))
    return paths
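# pool.map preserves input order, so saving three images returns their paths
# as ["<uuid1>.png", "<uuid2>.png", "<uuid3>.png"] in submission order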
def generate(
    prompt: str,
    seed: int = 0,
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 8.0,
    num_inference_steps: int = 4,
    num_images: int = 1,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
    profile: gr.OAuthProfile | None = None,
) -> tuple[list[str], int]:
    # prepare seed
    seed = randomize_seed_fn(seed, randomize_seed)
    torch.manual_seed(seed)
    start_time = time.time()
    # Call the pipeline with only the kwargs the latent_consistency_txt2img
    # custom pipeline supports (it takes no negative_prompt argument)
    outputs = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        lcm_origin_steps=50,
        output_type="pil",
    ).images
    latency = time.time() - start_time
    print(f"Generation took {latency:.2f} seconds")
    paths = save_images(
        outputs,
        profile,
        metadata={
            "prompt": prompt,
            "seed": seed,
            "width": width,
            "height": height,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
        },
    )
    return paths, seed
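# Example call outside the UI (hypothetical prompt; profile defaults to None):
#   paths, used_seed = generate("a watercolor fox", num_images=2, randomize_seed=True)
# returns the saved file paths plus the seed that was actually used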
def validate_and_list_models(hfuser):
    try:
        models = api.list_models(author=hfuser)
        return [model.modelId for model in models if model.pipeline_tag == "text-to-image"]
    except Exception:
        return []
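# e.g. validate_and_list_models("SimianLuo") returns that author's
# text-to-image repo ids, or [] if the Hub API call fails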
def parse_user_model_dict(user_model_dict_str):
    try:
        data = ast.literal_eval(user_model_dict_str)
        if isinstance(data, dict) and all(isinstance(v, list) for v in data.values()):
            return data
        return {}
    except Exception:
        return {}
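# ast.literal_eval accepts only Python literals, so arbitrary code pasted into
# the textbox is rejected; e.g. "{'user': ['m1', 'm2']}" -> {'user': ['m1', 'm2']}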
def load_model(model_id):
    if model_id in model_cache:
        return f"{model_id} loaded from cache"
    try:
        path = snapshot_download(repo_id=model_id, cache_dir="model_cache", token=os.getenv("HF_TOKEN"))
        model_cache[model_id] = path
        return f"{model_id} loaded successfully"
    except Exception as e:
        return f"{model_id} failed to load: {str(e)}"
def run_models(models, parallel):
    if parallel:
        futures = [executor.submit(load_model, m) for m in models]
        return [f.result() for f in futures]
    else:
        return [load_model(m) for m in models]
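# With parallel=True the downloads overlap on the shared executor's threads;
# with parallel=False they run one at a time. Either way the status strings
# come back in the same order as the selected models.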
#with gr.Blocks(css="style.css") as demo:
with gr.Blocks() as demo:
    with gr.Row():
        gr.HTML(
            """
            <p id="project-links" align="center">
                <a href='https://huggingface.co/spaces/charliebaby2023/Fast_Stable_diffusion_CPU/edit/main/app_demo.py'>Edit this app_demo.py file</a>
            </p>
            <p>This Space loads the Lykon/dreamshaper-xl-v2-turbo custom pipeline and generates with LCM Dreamshaper v7.</p>
            <p>Fast Stable Diffusion, CPU</p>
            """
        )
    with gr.Column(scale=1):
        with gr.Row():
            hfuser_input = gr.Textbox(label="Hugging Face Username")
            hfuser_models = gr.Dropdown(label="Models from User", choices=["Choose A Model"], value="Choose A Model", multiselect=True, visible=False)
            user_model_dict = gr.Textbox(visible=False, label="Dict Input (e.g., {'username': ['model1', 'model2']})")
        with gr.Row():
            run_btn = gr.Button("Load Models")
    with gr.Column(scale=3):
        with gr.Row():
            parallel_toggle = gr.Checkbox(label="Load in Parallel", value=True)
        with gr.Row():
            output = gr.Textbox(label="Output", lines=3)
    def update_models(hfuser):
        if hfuser:
            models = validate_and_list_models(hfuser)
            label = f"Models found: {len(models)}"
            if len(models) > 0:
                return gr.update(choices=models, label=label, visible=True)
            else:
                return gr.update(choices=models, label=label, visible=False)
        # Empty username: hide the dropdown (choices must be a list, not a string)
        return gr.update(choices=[], label="", visible=False)
    def update_from_dict(dict_str):
        parsed = parse_user_model_dict(dict_str)
        if not parsed:
            return gr.update(), gr.update()
        hfuser = next(iter(parsed))
        models = parsed[hfuser]
        label = f"Models found: {len(models)}"
        return gr.update(value=hfuser), gr.update(choices=models, value=models, label=label)

    hfuser_input.change(update_models, hfuser_input, hfuser_models)
    user_model_dict.change(update_from_dict, user_model_dict, [hfuser_input, hfuser_models])
    run_btn.click(run_models, [hfuser_models, parallel_toggle], output)
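    # Flow: typing a username repopulates the dropdown via update_models, while
    # pasting a dict such as {'SimianLuo': ['LCM_Dreamshaper_v7']} fills both
    # the username box and the dropdown via update_from_dict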
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
            run_button = gr.Button("Run", scale=0)
        # `grid` was removed from gr.Gallery; `columns` is the current kwarg
        result = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=2)
    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
        randomize_seed = gr.Checkbox(label="Randomize seed across runs", value=True)
        with gr.Row():
            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
        with gr.Row():
            guidance_scale = gr.Slider(label="Guidance scale for base", minimum=2, maximum=14, step=0.1, value=8.0)
            num_inference_steps = gr.Slider(label="Number of inference steps for base", minimum=1, maximum=8, step=1, value=4)
        with gr.Row():
            num_images = gr.Slider(label="Number of images", minimum=1, maximum=8, step=1, value=1, visible=True)
with gr.Accordion("Past generations", open=False):
gr_user_history.render()
gr.on( triggers=[ prompt.submit, run_button.click, ],
fn=generate,
inputs=[prompt,seed,width,height,guidance_scale,num_inference_steps,num_images,randomize_seed ],
outputs=[result, seed],
api_name="run",
)
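    # gr.on binds both triggers (Enter in the prompt box and the Run button)
    # to a single generate() call, registered under api_name "run"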
if __name__ == "__main__":
    demo.queue(api_open=False)
    demo.launch()