# benchmark/app.py — last commit a094439 ("Update app.py") by patrickvonplaten, 1.24 kB.
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
import time
# Benchmark configuration and one-time model setup; runs at import time.

# Set to True to wrap the UNet in torch.compile before benchmarking.
USE_TORCH_COMPILE = False
# Half-precision tensors on the first CUDA device.
dtype = torch.float16
device = torch.device("cuda:0")

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Load the checkpoint's "fp16" weight variant directly in `dtype`.
pipeline = AutoPipelineForText2Image.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    variant="fp16",
)
pipeline.to(device)

if USE_TORCH_COMPILE:
    # Only the UNet is compiled; see the torch.compile docs for what
    # mode="reduce-overhead" and fullgraph=True imply.
    pipeline.unet = torch.compile(
        pipeline.unet,
        mode="reduce-overhead",
        fullgraph=True,
    )
def generate(num_images_per_prompt: int = 1):
    """Time one text-to-image call of the global ``pipeline`` and print it.

    Args:
        num_images_per_prompt: Batch size, i.e. how many images the pipeline
            produces for the single fixed prompt. Coerced to ``int`` because
            the value arrives from a UI slider.

    Raises:
        ValueError: If ``num_images_per_prompt`` is less than 1.
    """
    num_images_per_prompt = int(num_images_per_prompt)
    if num_images_per_prompt < 1:
        raise ValueError(
            f"num_images_per_prompt must be >= 1, got {num_images_per_prompt}"
        )

    # Fixed dummy prompt so every run does identical work.
    # NOTE(review): presumably meant to fill a 77-token text context; 77
    # characters of "a" may tokenize to a different token count — confirm.
    prompt = 77 * "a"
    num_inference_steps = 40

    # CUDA work is launched asynchronously, so synchronize on both sides of
    # the timed region to measure the GPU work itself. perf_counter is
    # monotonic and high-resolution, unlike the wall-clock time.time().
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    # The generated images are discarded; only the elapsed time matters.
    pipeline(
        prompt,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
    )
    torch.cuda.synchronize()
    end_time = time.perf_counter()

    elapsed = end_time - start_time
    print(f"For {num_inference_steps} steps", elapsed)
    # This is the whole pipeline call divided by the step count, so it also
    # includes any per-call work outside the denoising steps.
    print("Avg per step", elapsed / num_inference_steps)
# Minimal benchmark UI: choose a batch size and click the button to run one
# timed generation. Timings are printed to the server's stdout, not the page.
with gr.Blocks(css="style.css") as demo:
    # Passed to generate() as num_images_per_prompt. Minimum is 1 because a
    # batch of zero images is not a valid pipeline input.
    batch_size = gr.Slider(
        label="Batch size",
        minimum=1,
        maximum=16,
        step=1,
        value=1,
    )
    # NOTE(review): Component.style() and its margin/rounded arguments are
    # deprecated in later Gradio 3.x releases and removed in Gradio 4 —
    # confirm against the gradio version this Space pins.
    btn = gr.Button("Benchmark!").style(
        margin=False,
        rounded=(False, True, True, False),
        full_width=False,
    )
    # No outputs are wired up: generate() only prints its timings.
    btn.click(fn=generate, inputs=[batch_size])

demo.launch()