benchmark / app.py
patrickvonplaten's picture
up
5e0a11c
raw
history blame
1.42 kB
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
import time
# --- Benchmark configuration -------------------------------------------------
# Flip to True to wrap the UNet with torch.compile before benchmarking.
USE_TORCH_COMPILE = False
dtype = torch.float16
device = torch.device("cuda:0")

# --- Model setup -------------------------------------------------------------
# Load the fp16 variant of the SDXL base checkpoint and move it to the
# configured device. This runs once, at import time.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=dtype,
)
pipeline.to(device)

if USE_TORCH_COMPILE:
    # Only the UNet is compiled; the rest of the pipeline is left as loaded.
    pipeline.unet = torch.compile(
        pipeline.unet,
        mode="reduce-overhead",
        fullgraph=True,
    )
def generate(prompt_len: int, num_images_per_prompt: int = 1):
    """Time one text-to-image run of the module-level ``pipeline``.

    A synthetic prompt made of ``prompt_len`` repetitions of ``"a"`` is fed
    to the pipeline for a fixed 40 inference steps. The total elapsed time
    and the average time per step are printed to stdout.

    Args:
        prompt_len: Number of ``"a"`` characters in the synthetic prompt.
        num_images_per_prompt: Number of images requested in the single
            pipeline call.

    Returns:
        None. Results are only printed.
    """
    # Values may arrive from UI widgets as floats (assumption -- confirm
    # against the installed Gradio version); str repetition requires an int.
    prompt_len = int(prompt_len)
    num_images_per_prompt = int(num_images_per_prompt)

    prompt = "a" * prompt_len
    num_inference_steps = 40

    # CUDA kernels are launched asynchronously: drain any pending work so it
    # is not attributed to this run, and again before stopping the clock.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.synchronize()

    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock changes.
    start_time = time.perf_counter()
    pipeline(
        prompt,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
    )
    if cuda_available:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time

    print(f"For {num_inference_steps} steps", elapsed)
    print("Avg per step", elapsed / num_inference_steps)
# Minimal UI: two sliders select the benchmark parameters and a button
# triggers one timed run of `generate` (results are printed to stdout).
with gr.Blocks(css="style.css") as demo:
    batch_size = gr.Slider(
        label="Batch size",
        # At least one image must be requested per run; 0 is not a
        # meaningful batch size for the benchmark.
        minimum=1,
        maximum=8,
        step=1,
        value=1,
    )
    prompt_len = gr.Slider(
        label="Prompt len",
        minimum=1,
        maximum=77,
        step=20,
        value=1,
    )
    # NOTE(review): `.style()` is deprecated/removed in newer Gradio
    # releases -- confirm the pinned Gradio version before upgrading.
    btn = gr.Button("Benchmark!").style(
        margin=False,
        rounded=(False, True, True, False),
        full_width=False,
    )

    # Component values are passed to `fn` positionally, so the order here
    # must match generate(prompt_len, num_images_per_prompt).
    btn.click(fn=generate, inputs=[prompt_len, batch_size])

demo.launch(share=True)