naonauno committed on
Commit
84ab83e
·
verified ·
1 Parent(s): 3ba4f56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import cv2
5
- from diffusers import StableDiffusionPipeline
6
  from model import UNet2DConditionModelEx
7
  from pipeline import StableDiffusionControlLoraV3Pipeline
8
  from PIL import Image
@@ -14,7 +14,7 @@ login(token=os.environ.get("HF_TOKEN"))
14
 
15
  # Initialize the models
16
  base_model = "runwayml/stable-diffusion-v1-5"
17
- dtype = torch.float32
18
 
19
  # Load the custom UNet
20
  unet = UNet2DConditionModelEx.from_pretrained(
@@ -33,16 +33,17 @@ pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
33
  torch_dtype=dtype
34
  )
35
 
 
 
 
36
  # Load the ControlLoRA weights
37
  pipe.load_lora_weights(
38
  "models",
39
  weight_name="40kHalf.safetensors"
40
  )
41
 
42
- # Enable CPU offload
43
  pipe.enable_model_cpu_offload()
44
-
45
- # Enable memory efficient attention if available
46
  if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
47
  pipe.enable_xformers_memory_efficient_attention()
48
 
 
2
  import torch
3
  import numpy as np
4
  import cv2
5
+ from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
6
  from model import UNet2DConditionModelEx
7
  from pipeline import StableDiffusionControlLoraV3Pipeline
8
  from PIL import Image
 
14
 
15
  # Initialize the models
16
  base_model = "runwayml/stable-diffusion-v1-5"
17
+ dtype = torch.float16 # A100 works better with float16
18
 
19
  # Load the custom UNet
20
  unet = UNet2DConditionModelEx.from_pretrained(
 
33
  torch_dtype=dtype
34
  )
35
 
36
+ # Use a faster scheduler
37
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
38
+
39
  # Load the ControlLoRA weights
40
  pipe.load_lora_weights(
41
  "models",
42
  weight_name="40kHalf.safetensors"
43
  )
44
 
45
+ # Enable optimizations
46
  pipe.enable_model_cpu_offload()
 
 
47
  if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
48
  pipe.enable_xformers_memory_efficient_attention()
49