ovi054 committed
Commit adefe82 · verified · 1 Parent(s): 09fcd4a

Update app.py

Files changed (1)
  1. app.py +33 -31
app.py CHANGED
@@ -12,12 +12,13 @@ import spaces
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
-model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" # Using the large model
+model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 print("Loading VAE...")
+# VAE is often kept in float32 for precision during decoding
 vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
 
-print("Loading WanPipeline...")
-# low_cpu_mem_usage is often needed for meta device loading
+print("Loading WanPipeline in bfloat16...")
+# Load the main model in bfloat16 to save memory
 pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
 
 flow_shift = 1.0
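Not part of the commit, but for context: the mixed-precision split above (float32 VAE, bfloat16 transformer) can be checked in isolation. A minimal sketch, assuming diffusers with Wan support and enough memory for the 14B checkpoint:

import torch
from diffusers import AutoencoderKLWan, WanPipeline

model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# VAE stays in float32 for decode precision; the 14B transformer loads in bfloat16 to save memory.
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

# Sanity check: each component keeps its own dtype.
print("vae dtype:", next(pipe.vae.parameters()).dtype)                  # torch.float32
print("transformer dtype:", next(pipe.transformer.parameters()).dtype)  # torch.bfloat16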
@@ -28,24 +29,24 @@ print("Moving pipeline to device...")
 pipe.to(device)
 
 # --- LORA FUSING (Done ONCE at startup) ---
-# We will fuse the base LoRA permanently into the model for performance and to solve the device issue.
-# This means we cannot dynamically unload it, but it's the correct approach for a fixed setup.
-
 print("Loading and Fusing base LoRA...")
 CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
 CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
-CUSTOM_LORA_NAME = "custom_lora" # For any optional LoRA
+CUSTOM_LORA_NAME = "custom_lora"
 
 try:
     causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
-    # 1. Load the LoRA weights. They will likely be on the CPU.
-    pipe.load_lora_weights(causvid_path) # Use the default adapter name
+    pipe.load_lora_weights(causvid_path) # Loads LoRA, likely onto CPU
     print("✅ Base LoRA loaded.")
 
-    # 2. Fuse the weights into the base model. This resolves the device mismatch.
-    pipe.fuse_lora()
-    print("✅ Base LoRA fused successfully.")
+    pipe.fuse_lora() # Fuses float32 LoRA into bfloat16 model
+    print("✅ Base LoRA fused.")
 
+    # FIX for Dtype Mismatch: After fusing, some layers became float32.
+    # We must cast the entire pipeline back to bfloat16 to ensure consistency.
+    pipe.to(dtype=torch.bfloat16)
+    print("✅ Pipeline converted back to bfloat16 post-fusion.")
+
 except Exception as e:
     print(f"⚠️ Could not load or fuse the base LoRA: {e}")
 
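The dtype fix in this hunk can be verified rather than assumed: after fusing a float32 LoRA into a bfloat16 model, some fused parameters may come back as float32, and the blanket pipe.to(dtype=torch.bfloat16) restores consistency. A hedged sketch of that check, assuming pipe has been built and the base LoRA fused as shown above (the helper non_bf16_params is illustrative, not part of the commit):

import torch

def non_bf16_params(module: torch.nn.Module) -> list[str]:
    """Names of parameters whose dtype is not bfloat16."""
    return [name for name, p in module.named_parameters() if p.dtype != torch.bfloat16]

# After pipe.fuse_lora(), some fused layers may have been upcast to float32.
mismatched = non_bf16_params(pipe.transformer)
if mismatched:
    print(f"{len(mismatched)} non-bfloat16 parameters, e.g. {mismatched[:3]}")
    pipe.to(dtype=torch.bfloat16)  # the blanket cast used in the commit
    assert not non_bf16_params(pipe.transformer)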
@@ -53,28 +54,25 @@ print("Initialization complete. Gradio is starting...")
 
 @spaces.GPU()
 def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
-    # The base LoRA is already fused. We only need to handle the optional custom LoRA.
     clean_lora_id = lora_id.strip() if lora_id else ""
 
-    # We will load and unload the custom LoRA dynamically for each run
     if clean_lora_id:
         try:
-            print(f"Loading custom LoRA '{clean_lora_id}' for this run...")
-            # Load the custom LoRA. Note: We will NOT fuse this one to keep it temporary.
+            print(f"Applying custom LoRA '{clean_lora_id}' for this run...")
+            # 1. Load the temporary LoRA
            pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
 
-            # This is the critical part for dynamic LoRAs on large models.
-            # We must explicitly move the adapter to the correct device.
-            pipe.to(device, dtype=pipe.transformer.dtype) # Ensure dtype matches
+            # 2. Fuse it into the model. This is the key to avoiding device/meta errors.
+            pipe.fuse_lora(adapter_names=[CUSTOM_LORA_NAME])
 
-            pipe.set_adapters([CUSTOM_LORA_NAME], adapter_weights=[1.0])
-            print(f"✅ Custom LoRA '{CUSTOM_LORA_NAME}' activated.")
+            # 3. Ensure dtype consistency after the new fusion
+            pipe.to(dtype=torch.bfloat16)
+            print(f"✅ Custom LoRA '{CUSTOM_LORA_NAME}' fused and activated.")
+
         except Exception as e:
-            print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}. Running without it.")
-            # Ensure no adapters are active if loading failed
-            pipe.disable_lora()
+            print(f"⚠️ Failed to load or fuse custom LoRA '{clean_lora_id}': {e}. Running without it.")
+            clean_lora_id = "" # Clear the id so we don't try to clean it up
 
-    # Apply performance optimizations
     apply_cache_on_pipe(pipe)
 
     try:
@@ -91,12 +89,16 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
         image = (image * 255).astype(np.uint8)
         return Image.fromarray(image)
     finally:
-        # Clean up the dynamic LoRA if it was loaded
+        # Clean up the dynamic LoRA if it was successfully loaded and fused
         if clean_lora_id:
-            print(f"Unloading '{CUSTOM_LORA_NAME}' to clean up.")
-            pipe.unload_lora_weights(CUSTOM_LORA_NAME)
-            # It's good practice to disable all just in case.
-            pipe.disable_lora()
+            print(f"Cleaning up custom LoRA '{CUSTOM_LORA_NAME}'...")
+            # 1. Unfuse the temporary LoRA to revert the model's weights
+            pipe.unfuse_lora(adapter_names=[CUSTOM_LORA_NAME])
+
+            # 2. Unload the LoRA weights from memory
+            # FIX for TypeError: unload_lora_weights takes no adapter name
+            pipe.unload_lora_weights()
+            print("✅ Custom LoRA unfused and unloaded.")
 
 # --- Your Gradio Interface Code (no changes needed here) ---
 iface = gr.Interface(
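Taken together, the two hunks above give the per-request LoRA a full load → fuse → generate → unfuse → unload lifecycle. A minimal sketch of the same idea packaged as a context manager; temporary_lora is not part of the commit, and it assumes the diffusers LoRA API used above, where unload_lora_weights() takes no adapter name:

from contextlib import contextmanager

import torch

@contextmanager
def temporary_lora(pipe, lora_id: str, adapter_name: str = "custom_lora"):
    """Fuse a LoRA for one generation, then revert and unload it."""
    pipe.load_lora_weights(lora_id, adapter_name=adapter_name)
    pipe.fuse_lora(adapter_names=[adapter_name])
    pipe.to(dtype=torch.bfloat16)   # keep dtypes consistent after fusing
    try:
        yield pipe
    finally:
        pipe.unfuse_lora()          # revert the fused weights
        pipe.unload_lora_weights()  # drop the adapter from memory

# Hypothetical usage (repo id is illustrative):
# with temporary_lora(pipe, "some-user/some-wan-lora"):
#     out = pipe(prompt="a test prompt", width=1024, height=1024, num_inference_steps=10)

This mirrors the try/finally in generate() while keeping the cleanup attached to the fuse call that requires it.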
@@ -107,7 +109,7 @@ iface = gr.Interface(
         gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
-        gr.Textbox(label="LoRA ID (Optional, loads dynamically)"),
+        gr.Textbox(label="LoRA ID (Optional)"),
     ],
     outputs=gr.Image(label="output"),
 )
 