Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import spaces # [uncomment to use ZeroGPU] | |
| from diffusers import DiffusionPipeline | |
| import torch | |
# Prefer the GPU (in half precision) when one is visible; otherwise fall back
# to full-precision inference on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Replace with the model you would like to use.
model_repo_id = "stabilityai/sdxl-turbo"

# Load the pipeline once at import time and move it onto the selected device.
pipe = DiffusionPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch_dtype,
).to(device)

# Module-level limits (not referenced by the visible code in this file).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# NOTE(review): ``spaces`` is imported at the top of the file for ZeroGPU; on
# ZeroGPU hardware this function presumably needs the ``@spaces.GPU``
# decorator — confirm the target hardware before deploying.
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    """Generate one image from each of two pipelines using the same seed.

    Each pipeline gets its own freshly seeded generator, so both start from
    the same RNG state and their outputs can be compared side by side.

    Args:
        prompt: Text prompt passed to both pipelines.
        seed: Seed value; coerced with ``int()`` because Gradio number
            widgets may deliver floats.
        steps: Number of inference steps; coerced with ``int()`` for the
            same reason.
        pipe: Pipeline producing the "original" image.
        pruned_pipe: Pipeline producing the EcoDiff (pruned) image.

    Returns:
        Tuple ``(original_image, ecodiff_image)``: the first entry of each
        pipeline result's ``.images``.
    """
    # Mirror the module's device selection instead of hard-coding "cuda",
    # which crashes on CPU-only hardware.
    generator_device = "cuda" if torch.cuda.is_available() else "cpu"
    seed = int(seed)
    steps = int(steps)

    def _run(pipeline):
        # Fresh generator per call so both pipelines see identical noise.
        generator = torch.Generator(generator_device).manual_seed(seed)
        result = pipeline(prompt=prompt, generator=generator, num_inference_steps=steps)
        return result.images[0]

    original_image = _run(pipe)
    ecodiff_image = _run(pruned_pipe)
    return original_image, ecodiff_image
# Sample prompts. The gr.Examples widget in the UI currently lists its own
# prompts inline, so this list is not referenced by the visible code.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
# Custom stylesheet handed to gr.Blocks: centers and caps the width of any
# element with elem_id="col-container".
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 640px;\n"
    "}\n"
)
# Markdown/HTML banner rendered at the top of the demo: title plus badge links
# to the paper, the model repository and the source code.
# Fix: the paper badge label previously read "ariXv".
header = """
# 🌱 Text-to-Image Generation with EcoDiff Pruned SD-XL (20% Pruning Ratio)
# Under Construction!!!
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/YaNgZhAnG-V5/EcoDiff"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
# Gradio UI: prompt/seed/steps controls, example prompts, and two side-by-side
# image outputs (original vs. EcoDiff).
with gr.Blocks(css=css) as demo:
    gr.Markdown(header)
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=50,
            step=1,
            scale=1,
        )
    generate_btn = gr.Button("Generate Images")
    gr.Examples(
        examples=[
            "A clock tower floating in a sea of clouds",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
            "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
        ],
        inputs=[prompt],
    )
    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")

    def _generate_for_ui(prompt_text, seed_value, num_steps):
        """Forward the UI values to generate_images with the loaded pipelines."""
        # Event inputs must be Gradio components; the diffusers pipelines are
        # not, so they are bound here instead of being listed in ``inputs``.
        # NOTE(review): the same ``pipe`` serves both outputs until a pruned
        # pipeline is actually loaded.
        return generate_images(prompt_text, seed_value, num_steps, pipe, pipe)

    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=_generate_for_ui,
        inputs=[prompt, seed, steps],
        outputs=[original_output, ecodiff_output],
        # Keep the endpoint name the original handler would have exposed.
        api_name="generate_images",
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()