ovi054 committed on
Commit 4166d00 · verified · 1 Parent(s): adefe82

Update app.py
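This commit drops the startup-time fuse_lora() flow and instead manages LoRA adapters per request: each LoRA is loaded under a named adapter, activated with set_adapters, and torn down in a finally block so the next run starts clean. A minimal sketch of that lifecycle, condensed from the diff below (the helper name run_with_adapters and the trimmed error handling are illustrative, not part of the commit):

def run_with_adapters(pipe, base_lora_path, custom_lora_id, prompt, **kwargs):
    # Start from a clean slate, then register each LoRA under its own adapter name.
    pipe.disable_lora()
    active, weights = [], []

    if base_lora_path:  # always-on CausVid LoRA downloaded at startup
        pipe.load_lora_weights(base_lora_path, adapter_name="causvid_lora")
        active.append("causvid_lora")
        weights.append(1.0)

    if custom_lora_id:  # optional per-request LoRA supplied through the UI
        pipe.load_lora_weights(custom_lora_id, adapter_name="custom_lora")
        active.append("custom_lora")
        weights.append(1.0)

    if active:
        # Activate the adapters for this run instead of fusing them into the base weights.
        pipe.set_adapters(active, weights)

    try:
        return pipe(prompt=prompt, **kwargs)
    finally:
        # Detach and drop the adapters so the following request starts clean.
        pipe.disable_lora()
        for name in ("causvid_lora", "custom_lora"):
            pipe.transformer.peft_config.pop(name, None)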

Files changed (1)
app.py +74 -41
app.py CHANGED
@@ -1,5 +1,5 @@
  import torch
- from diffusers import UniPCMultistepScheduler
+ from diffusers import UniPCMultistepScheduler
  from diffusers import WanPipeline, AutoencoderKLWan
  from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
  from huggingface_hub import hf_hub_download
@@ -14,65 +14,97 @@ print(f"Using device: {device}")

  model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
  print("Loading VAE...")
- # VAE is often kept in float32 for precision during decoding
  vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)

  print("Loading WanPipeline in bfloat16...")
- # Load the main model in bfloat16 to save memory
+ # This will use ZeroGPU/accelerate with meta devices
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

  flow_shift = 1.0
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)

- # Move the base pipeline to the GPU. This is where meta-device loading happens.
- print("Moving pipeline to device...")
+ # Move the base pipeline to the GPU. ZeroGPU will manage this.
+ print("Moving pipeline to device (ZeroGPU will handle offloading)...")
  pipe.to(device)

- # --- LORA FUSING (Done ONCE at startup) ---
- print("Loading and Fusing base LoRA...")
+ # --- LORA SETUP ---
+ # We will NOT fuse anything. Everything will be handled dynamically.
  CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
  CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
+ BASE_LORA_NAME = "causvid_lora"
  CUSTOM_LORA_NAME = "custom_lora"

+ print("Downloading base LoRA...")
  try:
      causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
-     pipe.load_lora_weights(causvid_path)  # Loads LoRA, likely onto CPU
-     print("✅ Base LoRA loaded.")
-
-     pipe.fuse_lora()  # Fuses float32 LoRA into bfloat16 model
-     print("✅ Base LoRA fused.")
-
-     # FIX for Dtype Mismatch: After fusing, some layers became float32.
-     # We must cast the entire pipeline back to bfloat16 to ensure consistency.
-     pipe.to(dtype=torch.bfloat16)
-     print("✅ Pipeline converted back to bfloat16 post-fusion.")
-
+     print("✅ Base LoRA downloaded.")
  except Exception as e:
-     print(f"⚠️ Could not load or fuse the base LoRA: {e}")
+     causvid_path = None
+     print(f"⚠️ Could not download base LoRA: {e}")

  print("Initialization complete. Gradio is starting...")

+ def move_adapter_to_device(pipe, adapter_name, device):
+     """
+     Surgically moves only the parameters of a specific LoRA adapter to the target device.
+     This avoids touching the base model's meta tensors.
+     """
+     print(f"Moving adapter '{adapter_name}' to {device}...")
+     for param in pipe.transformer.parameters():
+         if hasattr(param, "adapter_name") and param.adapter_name == adapter_name:
+             param.data = param.data.to(device, non_blocking=True)
+             if param.grad is not None:
+                 param.grad.data = param.grad.data.to(device, non_blocking=True)
+     print(f"✅ Adapter '{adapter_name}' moved.")
+
  @spaces.GPU()
  def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
-     clean_lora_id = lora_id.strip() if lora_id else ""

+     # --- DYNAMIC LORA MANAGEMENT FOR EACH RUN ---
+     # Start with a clean slate by disabling any active adapters from previous runs
+     pipe.disable_lora()
+
+     active_adapters = []
+     adapter_weights = []
+
+     # 1. Load the Base LoRA
+     if causvid_path:
+         try:
+             # We load it for every run to ensure a clean state
+             print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
+             pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME)
+
+             # THE CRITICAL FIX: Move only this adapter's weights to the GPU
+             move_adapter_to_device(pipe, BASE_LORA_NAME, device)
+
+             active_adapters.append(BASE_LORA_NAME)
+             adapter_weights.append(1.0)
+         except Exception as e:
+             print(f"⚠️ Failed to load base LoRA: {e}")
+
+     # 2. Load the Custom LoRA if provided
+     clean_lora_id = lora_id.strip() if lora_id else ""
      if clean_lora_id:
          try:
-             print(f"Applying custom LoRA '{clean_lora_id}' for this run...")
-             # 1. Load the temporary LoRA
+             print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
              pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)

-             # 2. Fuse it into the model. This is the key to avoiding device/meta errors.
-             pipe.fuse_lora(adapter_names=[CUSTOM_LORA_NAME])
-
-             # 3. Ensure dtype consistency after the new fusion
-             pipe.to(dtype=torch.bfloat16)
-             print(f"✅ Custom LoRA '{CUSTOM_LORA_NAME}' fused and activated.")
+             # THE CRITICAL FIX: Move only this adapter's weights to the GPU
+             move_adapter_to_device(pipe, CUSTOM_LORA_NAME, device)

+             active_adapters.append(CUSTOM_LORA_NAME)
+             adapter_weights.append(1.0)
          except Exception as e:
-             print(f"⚠️ Failed to load or fuse custom LoRA '{clean_lora_id}': {e}. Running without it.")
-             clean_lora_id = ""  # Clear the id so we don't try to clean it up
+             print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
+             # If it fails, delete the adapter config to prevent issues
+             if CUSTOM_LORA_NAME in pipe.transformer.peft_config:
+                 del pipe.transformer.peft_config[CUSTOM_LORA_NAME]

+     # 3. Activate the successfully loaded adapters
+     if active_adapters:
+         print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
+         pipe.set_adapters(active_adapters, adapter_weights)
+
      apply_cache_on_pipe(pipe)

      try:
@@ -89,18 +121,19 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
          image = (image * 255).astype(np.uint8)
          return Image.fromarray(image)
      finally:
-         # Clean up the dynamic LoRA if it was successfully loaded and fused
-         if clean_lora_id:
-             print(f"Cleaning up custom LoRA '{CUSTOM_LORA_NAME}'...")
-             # 1. Unfuse the temporary LoRA to revert the model's weights
-             pipe.unfuse_lora(adapter_names=[CUSTOM_LORA_NAME])
-
-             # 2. Unload the LoRA weights from memory
-             # FIX for TypeError: unload_lora_weights takes no adapter name
-             pipe.unload_lora_weights()
-             print("✅ Custom LoRA unfused and unloaded.")
+         # --- PROPER CLEANUP ---
+         print("Cleaning up LoRAs for this run...")
+         # Disable adapters to stop them from being used
+         pipe.disable_lora()
+
+         # Delete the LoRA configs from the model to truly unload them
+         if BASE_LORA_NAME in pipe.transformer.peft_config:
+             del pipe.transformer.peft_config[BASE_LORA_NAME]
+         if CUSTOM_LORA_NAME in pipe.transformer.peft_config:
+             del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
+         print("✅ LoRAs cleaned up.")
+

- # --- Your Gradio Interface Code (no changes needed here) ---
  iface = gr.Interface(
      fn=generate,
      inputs=[