Spaces:
Sleeping
Sleeping
File size: 5,936 Bytes
5c4b5eb d9c0a5a 5c4b5eb 26cfd4b 5c4b5eb 26cfd4b 5c4b5eb aeccac3 a2efe11 aeccac3 a2efe11 5c4b5eb 26cfd4b f867da7 26cfd4b cedad44 26cfd4b cedad44 26cfd4b cedad44 542ea3b 26cfd4b d9c0a5a 26cfd4b cedad44 26cfd4b cedad44 542ea3b cedad44 26cfd4b cedad44 aeccac3 5c4b5eb e07cbb0 8bfad22 5c4b5eb cedad44 26cfd4b 5c4b5eb a2efe11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import time
#import spaces
import gradio as gr
import torch
import diffusers
from utils import patch_attention_proc, remove_patch
import math
import numpy as np
from PIL import Image
from threading import Semaphore
# Globals
# CSS override so the demo title renders centered in the Blocks layout.
css = """
h1 {
text-align: center;
display: block;
}
"""
# Pick the best available accelerator; falls back to CPU so the app
# still starts on machines without CUDA or Apple MPS.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Pipeline
# Shared DreamShaper SD pipeline, loaded once in fp16 on the chosen device.
# NOTE(review): fp16 on "cpu" is unusual and may be slow/unsupported — confirm.
pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to(device, torch.float16)
# Euler scheduler swapped in over the checkpoint's default scheduler config.
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# Safety checker disabled for this benchmark demo.
pipe.safety_checker = None
semaphore = Semaphore() # for preventing collisions of two simultaneous button presses
#@spaces.GPU
def generate_baseline(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Run the unpatched (baseline) pipeline once and time it.

    Args:
        prompt: text prompt for generation.
        seed: integer seed fed to ``torch.manual_seed`` for reproducibility.
        steps: number of denoising steps.
        height_width: square output resolution in pixels.
        negative_prompt: text the sampler is steered away from.
        guidance_scale: classifier-free guidance weight.
        method: unused on the baseline path; kept so both generate
            functions share one signature and the same Gradio input list.

    Returns:
        tuple: (PIL image, runtime summary string).
    """
    # NOTE: the original computed a full token_merge_args dict here that the
    # baseline path never used (dead code, with a latent NameError for
    # resolutions outside 1024/1536/2048); it has been removed.
    #
    # `with` guarantees the semaphore is released even if the pipeline
    # raises; a bare acquire()/release() pair would deadlock the app
    # after a single failed generation.
    with semaphore:
        torch.manual_seed(seed)
        start_time_base = time.time()
        # Ensure no token-merging patch from a previous "Generate" run
        # is still attached to the shared pipeline.
        remove_patch(pipe)
        base_img = pipe(prompt,
                        num_inference_steps=steps, height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
        end_time_base = time.time()
        result = f"Baseline Runtime: {end_time_base-start_time_base:.2f} sec"
    return base_img, result
##@spaces.GPU
def generate_merged(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Patch the pipeline with token merging (ToDo or ToMe), generate, and time it.

    Args:
        prompt: text prompt for generation.
        seed: integer seed fed to ``torch.manual_seed`` for reproducibility.
        steps: number of denoising steps.
        height_width: square output resolution in pixels (1024/1536/2048).
        negative_prompt: text the sampler is steered away from.
        guidance_scale: classifier-free guidance weight.
        method: "todo" (downsample keys/values) or anything else for
            ToMe-style similarity merging of all tokens.

    Returns:
        tuple: (PIL image, runtime summary string).
    """
    merge_method = "downsample" if method == "todo" else "similarity"
    merge_tokens = "keys/values" if method == "todo" else "all"
    # Defaults guard against a resolution outside the three dropdown
    # choices: the original left downsample_factor_level_2/ratio_level_2
    # undefined in that case, raising NameError when building the dict.
    downsample_factor = 2
    ratio = 0.38
    downsample_factor_level_2 = 1
    ratio_level_2 = 0.0
    # Higher resolutions merge more aggressively (larger factor/ratio).
    if height_width == 1024:
        downsample_factor = 2
        ratio = 0.75
    elif height_width == 1536:
        downsample_factor = 3
        ratio = 0.89
    elif height_width == 2048:
        downsample_factor = 4
        ratio = 0.9375
    token_merge_args = {
        "ratio": ratio,
        "merge_tokens": merge_tokens,
        "merge_method": merge_method,
        "downsample_method": "nearest",
        "downsample_factor": downsample_factor,
        "timestep_threshold_switch": 0.0,
        "timestep_threshold_stop": 0.0,
        "downsample_factor_level_2": downsample_factor_level_2,
        "ratio_level_2": ratio_level_2,
    }
    # `with` guarantees the semaphore is released even if patching or the
    # pipeline raises; a bare acquire()/release() pair would deadlock the
    # app after a single failed generation.
    with semaphore:
        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
        torch.manual_seed(seed)
        start_time_merge = time.time()
        merged_img = pipe(prompt,
                          num_inference_steps=steps, height=height_width, width=height_width,
                          negative_prompt=negative_prompt,
                          guidance_scale=guidance_scale).images[0]
        end_time_merge = time.time()
        result = f"{'ToDo' if method == 'todo' else 'ToMe'} Runtime: {end_time_merge-start_time_merge:.2f} sec"
    return merged_img, result
# --- Gradio UI: side-by-side baseline vs. token-merged generation ----------
with gr.Blocks(css=css) as demo:
    gr.Markdown("# ToDo: Token Downsampling for Efficient Generation of High-Resolution Images")
    # Shared text inputs feeding both generation paths.
    prompt = gr.Textbox(interactive=True, label="prompt")
    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
    # Merge method and target resolution.
    with gr.Row():
        method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose Your Desired Method (Default: todo)")
        height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose Your Desired Height/Width (Default: 1024)")
    # Sampler knobs.
    with gr.Row():
        guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
        steps = gr.Number(label="steps", value=20, precision=0)
        seed = gr.Number(label="seed", value=1, precision=0)
    # Both click handlers take the same input list; each column shows the
    # runtime string above its generated image.
    shared_inputs = [prompt, seed, steps, height_width, negative_prompt,
                     guidance_scale, method]
    with gr.Row():
        with gr.Column():
            base_result = gr.Textbox(label="Baseline Runtime")
            base_image = gr.Image(label="baseline_image", type="pil", interactive=False)
            baseline_btn = gr.Button("Generate Baseline")
            baseline_btn.click(generate_baseline, inputs=shared_inputs,
                               outputs=[base_image, base_result])
        with gr.Column():
            output_result = gr.Textbox(label="Runtime")
            output_image = gr.Image(label="image", type="pil", interactive=False)
            merged_btn = gr.Button("Generate")
            merged_btn.click(generate_merged, inputs=shared_inputs,
                             outputs=[output_image, output_result])
demo.launch(share=True)
|