from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
from compel import Compel, ReturnedEmbeddingsType
import torch
import os
from PIL import Image
import numpy as np
import gradio as gr
import psutil
from sfast.compilers.stable_diffusion_pipeline_compiler import (
    compile,
    CompilationConfig,
)

SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Check if MPS is available (macOS on Apple Silicon: M1/M2/M3 chips)
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
# torch_dtype = torch.float16
torch_dtype = torch.bfloat16

print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"device: {device}")

if mps_available:
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

if SAFETY_CHECKER == "True":
    pipe = DiffusionPipeline.from_pretrained(model_id)
else:
    pipe = DiffusionPipeline.from_pretrained(model_id, safety_checker=None)

pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Load the LCM LoRA
pipe.load_lora_weights(
    "latent-consistency/lcm-lora-sdxl",
    use_auth_token=HF_TOKEN,
)

if device.type != "mps":
    pipe.unet.to(memory_format=torch.channels_last)

# Cast to the target dtype on torch_device first, then move to the final device
# (on MPS the weights are cast in float32 on CPU before the move)
pipe.to(device=torch_device, dtype=torch_dtype).to(device)

# Compile the pipeline with stable-fast
config = CompilationConfig.Default()
config.enable_xformers = False
config.enable_triton = False
config.enable_cuda_graph = False
pipe = compile(pipe, config=config)

compel_proc = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)


def predict(
    prompt,
    guidance,
    steps,
    seed=1231231,
    randomize_bt=False,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_bt:
        seed = np.random.randint(0, 2**32 - 1)
    generator = torch.manual_seed(seed)
    prompt_embeds, pooled_prompt_embeds = compel_proc(prompt)
    results = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=guidance,
        width=512,
        height=512,
        # original_inference_steps=params.lcm_steps,
        output_type="pil",
    )
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        raise gr.Error("NSFW content detected.")
    return results.images[0], seed


css = """
#container{
    margin: 0 auto;
    max-width: 40rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """# SDXL in 4 steps with Latent Consistency LoRAs
            SDXL is loaded with an LCM-LoRA, giving it the superpower of doing inference in as few as 4 steps. [Learn more on our blog](https://huggingface.co/blog/lcm_lora) or [technical report](https://huggingface.co/papers/2311.05556).
""", elem_id="intro", ) with gr.Row(): with gr.Row(): prompt = gr.Textbox( placeholder="Insert your prompt here:", scale=5, container=False ) generate_bt = gr.Button("Generate", scale=1) image = gr.Image(type="filepath") with gr.Accordion("Advanced options", open=False): guidance = gr.Slider( label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001 ) steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1) with gr.Row(): seed = gr.Slider( randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1, scale=5, ) with gr.Group(): randomize_bt = gr.Checkbox(label="Randomize", value=False) random_seed = gr.Textbox(show_label=False) with gr.Accordion("Run with diffusers"): gr.Markdown( """## Running LCM-LoRAs it with `diffusers` ```bash pip install diffusers==0.23.0 ``` ```py from diffusers import DiffusionPipeline, LCMScheduler pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda") pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") #yes, it's a normal LoRA results = pipe( prompt="The spirit of a tamagotchi wandering in the city of Vienna", num_inference_steps=4, guidance_scale=0.0, ) results.images[0] ``` """ ) inputs = [prompt, guidance, steps, seed, randomize_bt] generate_bt.click(fn=predict, inputs=inputs, outputs=[image, random_seed]) demo.queue(api_open=False) demo.launch(show_api=False)