alexanz committed
Commit a9322d5 · verified · 1 Parent(s): d12f069

add distilled

Files changed (1)
  1. app.py +39 -6
app.py CHANGED
@@ -6,7 +6,7 @@ from rembg import remove
 
 # import spaces #[uncomment to use ZeroGPU]
 from peft import PeftModel
-from diffusers import DiffusionPipeline, StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
+from diffusers import DiffusionPipeline, StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, AutoencoderTiny, DDIMScheduler
 from diffusers.utils import load_image
 import torch
 
@@ -26,7 +26,8 @@ MAX_IMAGE_SIZE = 512
 
 
 # @spaces.GPU #[uncomment to use ZeroGPU]
-def load_model(model_id, lora_strength, use_controlnet=False, control_mode="edge_detection", use_ip_adapter=False, control_strength_ip=0.0):
+def load_model(model_id, lora_strength, use_controlnet=False, control_mode="edge_detection", use_ip_adapter=False, control_strength_ip=0.0,
+               acceleration_mode=None):
     global pipe
     if pipe is not None:
         del pipe
@@ -75,8 +76,25 @@ def load_model(model_id, lora_strength, use_controlnet=False, control_mode="edge
         )
         pipe.unet = PeftModel.from_pretrained(pipe.unet, model_id, scaling=lora_strength, torch_dtype=torch_dtype)
     else:
-        pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
-        pipe.unet = PeftModel.from_pretrained(pipe.unet, model_id, scaling=lora_strength)
+        if acceleration_mode is None:
+            pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
+            pipe.unet = PeftModel.from_pretrained(pipe.unet, model_id, scaling=lora_strength)
+        elif acceleration_mode == "distilled":
+            pipe = StableDiffusionPipeline.from_pretrained(
+                "nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
+            )
+        elif acceleration_mode == "distilled + tiny":
+            pipe = StableDiffusionPipeline.from_pretrained(
+                "nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
+            )
+            pipe.vae = AutoencoderTiny.from_pretrained(
+                "sayakpaul/taesd-diffusers", torch_dtype=torch.float16, use_safetensors=True,
+            )
+        elif acceleration_mode == "DDIM":
+            pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
+            pipe.scheduler = DDIMScheduler.from_config(
+                pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
+            )
 
     if use_ip_adapter:
         pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
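Taken outside the Gradio app, the new acceleration branches boil down to swapping the checkpoint, the VAE, or the scheduler before inference. Below is a minimal standalone sketch of those branches; the model IDs and scheduler options come from the diff above, while the device handling, LoRA omission, prompt, and step count are illustrative assumptions, not part of the commit.

# Standalone sketch of the acceleration branches added above; the app's
# LoRA/PEFT wiring is omitted, and the prompt/step count are placeholders.
import torch
from diffusers import StableDiffusionPipeline, AutoencoderTiny, DDIMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

acceleration_mode = "distilled + tiny"  # one of None, "distilled", "distilled + tiny", "DDIM"

if acceleration_mode in ("distilled", "distilled + tiny"):
    # BK-SDM: a compressed, knowledge-distilled SD 1.5 checkpoint
    pipe = StableDiffusionPipeline.from_pretrained(
        "nota-ai/bk-sdm-small", torch_dtype=dtype, use_safetensors=True
    )
    if acceleration_mode == "distilled + tiny":
        # swap in the tiny autoencoder (TAESD) for a faster decode
        pipe.vae = AutoencoderTiny.from_pretrained(
            "sayakpaul/taesd-diffusers", torch_dtype=dtype, use_safetensors=True
        )
elif acceleration_mode == "DDIM":
    # base checkpoint with a zero-terminal-SNR, trailing-spaced DDIM scheduler
    pipe = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=dtype
    )
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
    )
else:
    # no acceleration: plain SD 1.5
    pipe = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=dtype
    )

pipe = pipe.to(device)
image = pipe("a photo of a ceramic mug", num_inference_steps=25).images[0]
image.save("sample.png")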
@@ -110,6 +128,7 @@ def infer(
     control_strength_ip,
     control_image_ip,
     use_rmbg,
+    acceleration_mode,
     progress=gr.Progress(track_tqdm=True),
 ):
     load_status = load_model(
@@ -118,7 +137,8 @@ def infer(
         use_controlnet,
         control_mode,
         use_ip_adapter,
-        control_strength_ip
+        control_strength_ip,
+        acceleration_mode
     )
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -278,6 +298,12 @@ with gr.Blocks(css=css) as demo:
 
                 use_rmbg = gr.Checkbox(label="Delete background?", value=False)
 
+                use_acceleration = gr.Checkbox(label="Use accelerate model? (only for 1.5 SD!)", value=False)
+                with gr.Accordion("Acceleration Settings", open=True, visible=False) as acceleration_settings:
+                    acceleration_mode = gr.Dropdown(label="Acceleration mode",
+                                                    choices=["distilled", "distilled + tiny", "DDIM"],
+                                                    value=None)
+
         gr.Examples(examples=examples, inputs=[prompt])
 
     gr.on(
@@ -301,7 +327,8 @@ with gr.Blocks(css=css) as demo:
             use_ip_adapter,
             control_strength_ip,
             control_image_ip,
-            use_rmbg
+            use_rmbg,
+            acceleration_mode
         ],
         outputs=[result, seed, model_status],
     )
@@ -323,5 +350,11 @@ with gr.Blocks(css=css) as demo:
         inputs=[use_rmbg]
     )
 
+    use_acceleration.change(
+        fn=lambda x: gr.update(visible=x, value=None),
+        inputs=[use_acceleration],
+        outputs=[acceleration_settings]
+    )
+
 if __name__ == "__main__":
     demo.launch()
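The last hunk wires the new checkbox to the accordion with the usual Gradio pattern of returning gr.update() from a .change handler. Here is a minimal self-contained sketch of that pattern, with hypothetical component names (show_extra, extra_settings, mode) standing in for the app's use_acceleration, acceleration_settings, and acceleration_mode:

# Minimal sketch of a checkbox-driven visibility toggle; names are illustrative.
import gradio as gr

with gr.Blocks() as demo:
    show_extra = gr.Checkbox(label="Show extra settings?", value=False)
    # hidden until the checkbox is ticked, mirroring acceleration_settings above
    with gr.Accordion("Extra Settings", open=True, visible=False) as extra_settings:
        mode = gr.Dropdown(label="Mode", choices=["fast", "faster"], value=None)

    # toggling the checkbox shows or hides the accordion
    show_extra.change(
        fn=lambda checked: gr.update(visible=checked),
        inputs=[show_extra],
        outputs=[extra_settings],
    )

if __name__ == "__main__":
    demo.launch()

In the app itself, the same handler drives acceleration_settings from use_acceleration, so the acceleration dropdown is only visible when the acceleration checkbox is enabled.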
 