ovi054 committed
Commit 09fcd4a · verified · 1 Parent(s): 9757f67

Update app.py

Files changed (1): app.py (+64 −86)
app.py CHANGED
@@ -1,97 +1,82 @@
  import torch
- from diffusers import UniPCMultistepScheduler, FlowMatchEulerDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler
- from diffusers import WanPipeline, AutoencoderKLWan  # Use Wan-specific VAE
- # from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+ from diffusers import UniPCMultistepScheduler
+ from diffusers import WanPipeline, AutoencoderKLWan
  from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
- from diffusers.models import UNetSpatioTemporalConditionModel
- from transformers import T5EncoderModel, T5Tokenizer
  from huggingface_hub import hf_hub_download
-
  from PIL import Image
  import numpy as np
-
  import gradio as gr
  import spaces

+ # --- INITIAL SETUP ---
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")

- model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+ model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"  # Using the large model
+ print("Loading VAE...")
  vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+
+ print("Loading WanPipeline...")
+ # low_cpu_mem_usage is often needed for meta device loading
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
- flow_shift = 1.0  # 5.0; 1.0 for image, 5.0 for 720P, 3.0 for 480P
+
+ flow_shift = 1.0
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)

+ # Move the base pipeline to the GPU. This is where meta-device loading happens.
+ print("Moving pipeline to device...")
  pipe.to(device)

- # Configure DDIMScheduler with a beta schedule
- # pipe.scheduler = DDIMScheduler.from_config(
- #     pipe.scheduler.config,
- #     beta_start=0.00085,  # Starting beta value
- #     beta_end=0.012,  # Ending beta value
- #     beta_schedule="linear",  # Linear beta schedule (other options: "scaled_linear", "squaredcos_cap_v2")
- #     num_train_timesteps=1000,  # Number of timesteps
- #     flow_shift=flow_shift
- # )
-
+ # --- LORA FUSING (Done ONCE at startup) ---
+ # We will fuse the base LoRA permanently into the model for performance and to solve the device issue.
+ # This means we cannot dynamically unload it, but it's the correct approach for a fixed setup.

- # Configure FlowMatchEulerDiscreteScheduler
- # pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
- #     pipe.scheduler.config,
- #     flow_shift=flow_shift  # Retain flow_shift for WanPipeline compatibility
- # )
-
- # --- LoRA State Management ---
- # Define unique names for our adapters
- DEFAULT_LORA_NAME = "causvid_lora"
- CUSTOM_LORA_NAME = "custom_lora"
- # Track which custom LoRA is currently loaded to avoid reloading
- CURRENTLY_LOADED_CUSTOM_LORA = None
-
- # Load the default base LoRA ONCE at startup
- print("Loading base LoRA...")
+ print("Loading and Fusing base LoRA...")
  CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
  CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
- # try:
- #     causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
- #     pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
- #     print(f"✅ Default LoRA '{DEFAULT_LORA_NAME}' loaded successfully.")
- # except Exception as e:
- #     print(f"⚠️ Default LoRA could not be loaded: {e}")
- #     DEFAULT_LORA_NAME = None
-
- # print("Initialization complete. Gradio is starting...")
-
+ CUSTOM_LORA_NAME = "custom_lora"  # For any optional LoRA

+ try:
+     causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
+     # 1. Load the LoRA weights. They will likely be on the CPU.
+     pipe.load_lora_weights(causvid_path)  # Use the default adapter name
+     print("✅ Base LoRA loaded.")
+
+     # 2. Fuse the weights into the base model. This resolves the device mismatch.
+     pipe.fuse_lora()
+     print("✅ Base LoRA fused successfully.")
+
+ except Exception as e:
+     print(f"⚠️ Could not load or fuse the base LoRA: {e}")
+
+ print("Initialization complete. Gradio is starting...")

  @spaces.GPU()
  def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
-     # if lora_id and lora_id.strip() != "":
-     #     pipe.unload_lora_weights()
-     #     pipe.load_lora_weights(lora_id.strip())
-
+     # The base LoRA is already fused. We only need to handle the optional custom LoRA.
      clean_lora_id = lora_id.strip() if lora_id else ""
-     print("Loading base LoRA for this run...")
-     causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
-     pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
-
-     # If a custom LoRA is provided, load it as well.
+
+     # We will load and unload the custom LoRA dynamically for each run
      if clean_lora_id:
-         print(f"Loading custom LoRA '{clean_lora_id}' for this run...")
-         pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
-         # If a custom LoRA is present, activate both.
-         pipe.set_adapters([DEFAULT_LORA_NAME, CUSTOM_LORA_NAME], adapter_weights=[1.0, 1.0])
-     else:
-         # If no custom LoRA, just activate the base one.
-         print("Activating base LoRA only.")
-         pipe.set_adapters([DEFAULT_LORA_NAME], adapter_weights=[1.0])
-
-     pipe.to(device)
-     # pipe.to("cuda")
-     # apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
-     apply_cache_on_pipe(
-         pipe,
-         # residual_diff_threshold=0.2,
-     )
+         try:
+             print(f"Loading custom LoRA '{clean_lora_id}' for this run...")
+             # Load the custom LoRA. Note: We will NOT fuse this one to keep it temporary.
+             pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
+
+             # This is the critical part for dynamic LoRAs on large models.
+             # We must explicitly move the adapter to the correct device.
+             pipe.to(device, dtype=pipe.transformer.dtype)  # Ensure dtype matches
+
+             pipe.set_adapters([CUSTOM_LORA_NAME], adapter_weights=[1.0])
+             print(f"✅ Custom LoRA '{CUSTOM_LORA_NAME}' activated.")
+         except Exception as e:
+             print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}. Running without it.")
+             # Ensure no adapters are active if loading failed
+             pipe.disable_lora()
+
+     # Apply performance optimizations
+     apply_cache_on_pipe(pipe)
+
      try:
          output = pipe(
              prompt=prompt,
@@ -100,36 +85,29 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
              width=width,
              num_frames=1,
              num_inference_steps=num_inference_steps,
-             guidance_scale=1.0,  # 5.0
+             guidance_scale=1.0,
          )
          image = output.frames[0][0]
          image = (image * 255).astype(np.uint8)
          return Image.fromarray(image)
      finally:
-         # if lora_id and lora_id.strip() != "":
-         #     pass
-         #     pipe.unload_lora_weights()
-         # if clean_lora_id:
-         #     print(f"Unloading '{CUSTOM_LORA_NAME}' from this run.")
-         #     pipe.unload_lora_weights(CUSTOM_LORA_NAME)
-
-         # # Always disable all active LoRAs to reset the state.
-         # pipe.disable_lora()
-         print("Unloading all LoRAs to clean up.")
-         pipe.unload_lora_weights()
-
-
+         # Clean up the dynamic LoRA if it was loaded
+         if clean_lora_id:
+             print(f"Unloading '{CUSTOM_LORA_NAME}' to clean up.")
+             pipe.unload_lora_weights(CUSTOM_LORA_NAME)
+         # It's good practice to disable all just in case.
+         pipe.disable_lora()
+
+ # --- Your Gradio Interface Code (no changes needed here) ---
  iface = gr.Interface(
      fn=generate,
      inputs=[
          gr.Textbox(label="Input prompt"),
-     ],
-     additional_inputs = [
          gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
          gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
          gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
          gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
-         gr.Textbox(label="LoRA ID"),
+         gr.Textbox(label="LoRA ID (Optional, loads dynamically)"),
      ],
      outputs=gr.Image(label="output"),
  )
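Note on the fuse-at-startup pattern introduced above: fuse_lora() folds the adapter deltas into the base weights, so inference runs at base-model speed and there is no separate adapter left to keep on the right device. A minimal sketch of the pattern, assuming a recent diffusers release (lora_scale and unfuse_lora() are standard diffusers LoRA APIs, but verify against the installed version):

    # Sketch: fuse a LoRA once at startup. Assumes `pipe` is a diffusers
    # pipeline with LoRA support, like the WanPipeline built above.
    pipe.load_lora_weights(causvid_path)   # adds LoRA modules (often on CPU)
    pipe.fuse_lora(lora_scale=1.0)         # fold the deltas into the base weights
    # The merge is reversible; to swap the base LoRA later:
    # pipe.unfuse_lora()

The trade-off the commit accepts: once fused, the base LoRA's strength cannot be varied per request, in exchange for speed and simpler device handling.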
 
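The para_attn call is also simplified to apply_cache_on_pipe(pipe). The tunable visible in the removed code is residual_diff_threshold: first-block cache reuses the previous step's output when the first transformer block's residual changes by less than the threshold, so larger values skip more work at some quality cost. A sketch with the knob made explicit, reusing the 0.2 value from the old commented-out call (semantics per para_attn's first-block-cache adapter; verify against the installed version):

    # Sketch: first-block cache with an explicit threshold.
    # 0.2 comes from the previously commented-out call.
    apply_cache_on_pipe(pipe, residual_diff_threshold=0.2)

Note also that generate() re-applies the cache on every request; applying it once at startup may be cleaner unless apply_cache_on_pipe is known to be idempotent.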
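Finally, the frame post-processing assumes frames come back as float arrays in [0, 1] (the usual diffusers convention for numpy output); values even slightly outside that range wrap around in the uint8 cast. A defensive variant of the conversion, using plain NumPy:

    # Sketch: clamp before converting, so out-of-range floats
    # saturate instead of wrapping around.
    image = output.frames[0][0]
    image = (np.clip(image, 0.0, 1.0) * 255).round().astype(np.uint8)
    return Image.fromarray(image)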