ovi054 committed
Commit bf7c515 · verified · 1 Parent(s): 4294cf6

Update app.py

Files changed (1): app.py (+62, -46)
app.py CHANGED
@@ -1,68 +1,82 @@
 import torch
-from diffusers import UniPCMultistepScheduler
-from diffusers import WanPipeline, AutoencoderKLWan
+from diffusers import UniPCMultistepScheduler, FlowMatchEulerDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler
+from diffusers import WanPipeline, AutoencoderKLWan  # Use Wan-specific VAE
+# from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
 from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+from diffusers.models import UNetSpatioTemporalConditionModel
+from transformers import T5EncoderModel, T5Tokenizer
 from huggingface_hub import hf_hub_download
+
 from PIL import Image
 import numpy as np
+
 import gradio as gr
 import spaces
-import gc

-# --- INITIAL SETUP ---
 device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using device: {device}")

 model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+flow_shift = 1.0  # 1.0 for image, 5.0 for 720P, 3.0 for 480P
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+
+pipe.to(device)
+
+# Configure DDIMScheduler with a beta schedule
+# pipe.scheduler = DDIMScheduler.from_config(
+#     pipe.scheduler.config,
+#     beta_start=0.00085,        # Starting beta value
+#     beta_end=0.012,            # Ending beta value
+#     beta_schedule="linear",    # Linear beta schedule (other options: "scaled_linear", "squaredcos_cap_v2")
+#     num_train_timesteps=1000,  # Number of timesteps
+#     flow_shift=flow_shift
+# )
+
+
+# Configure FlowMatchEulerDiscreteScheduler
+# pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
+#     pipe.scheduler.config,
+#     flow_shift=flow_shift  # Retain flow_shift for WanPipeline compatibility
+# )

-# --- PRE-FUSION STRATEGY ---
-# 1. Load everything on the CPU first. Do NOT use device_map or torch_dtype here yet.
-#    This prevents ZeroGPU/accelerate from taking over too early.
-print("Loading all components on CPU for pre-fusion...")
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae")
-pipe = WanPipeline.from_pretrained(model_id, vae=vae)
+# --- LoRA State Management ---
+# Define unique names for our adapters
+DEFAULT_LORA_NAME = "causvid_lora"
+CUSTOM_LORA_NAME = "custom_lora"
+# Track which custom LoRA is currently loaded to avoid reloading
+CURRENTLY_LOADED_CUSTOM_LORA = None

-# 2. Load and FUSE the LoRA while the entire pipeline is still on the CPU.
-print("Loading and fusing base LoRA on CPU...")
+# Load the default base LoRA ONCE at startup
+print("Loading base LoRA...")
 CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
-CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
+CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors"
 try:
     causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
-    pipe.load_lora_weights(causvid_path)
-    print("✅ Base LoRA loaded on CPU.")
-
-    pipe.fuse_lora()
-    print("✅ Base LoRA fused on CPU.")
-
-    # After fusing, we must unload the original LoRA weights to save memory
-    pipe.unload_lora_weights()
-    print("✅ Original LoRA weights unloaded.")
-
+    pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
+    print(f"✅ Default LoRA '{DEFAULT_LORA_NAME}' loaded successfully.")
 except Exception as e:
-    print(f"⚠️ Could not pre-fuse the LoRA: {e}. The model will run without it.")
+    print(f"⚠️ Default LoRA could not be loaded: {e}")
+    DEFAULT_LORA_NAME = None

-# 3. NOW, move the final, fused model to the target device and set the dtype.
-#    ZeroGPU will now take over a single, coherent model.
-print("Moving the final, fused pipeline to device and setting dtype...")
-pipe.to(device=device, dtype=torch.bfloat16)
+# print("Initialization complete. Gradio is starting...")

-# 4. Configure the scheduler
-flow_shift = 1.0
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
-
-# Clean up memory before starting the server
-del vae
-gc.collect()
-torch.cuda.empty_cache()

-print("Initialization complete. Gradio is starting...")

 @spaces.GPU()
-def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, progress=gr.Progress(track_tqdm=True)):
-    # The LoRA is permanently fused. No dynamic loading is possible or needed.
-
-    apply_cache_on_pipe(pipe)
-
+def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
+    # if lora_id and lora_id.strip() != "":
+    #     pipe.unload_lora_weights()
+    #     pipe.load_lora_weights(lora_id.strip())
+
+
+
+    #pipe.to("cuda")
+    # apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
+    apply_cache_on_pipe(
+        pipe,
+        # residual_diff_threshold=0.2,
+    )
     try:
         output = pipe(
             prompt=prompt,
@@ -71,24 +85,26 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
             width=width,
             num_frames=1,
             num_inference_steps=num_inference_steps,
-            guidance_scale=1.0,
+            guidance_scale=1.0,  # 5.0
         )
         image = output.frames[0][0]
         image = (image * 255).astype(np.uint8)
         return Image.fromarray(image)
     finally:
-        # No cleanup needed as the state is static
         pass

-# Interface is simplified as the custom LoRA option is removed for stability.
+
 iface = gr.Interface(
     fn=generate,
     inputs=[
         gr.Textbox(label="Input prompt"),
+    ],
+    additional_inputs = [
         gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
         gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
+        gr.Textbox(label="LoRA ID"),
     ],
     outputs=gr.Image(label="output"),
 )
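
Note: the commit reserves the adapter bookkeeping (DEFAULT_LORA_NAME, CUSTOM_LORA_NAME, CURRENTLY_LOADED_CUSTOM_LORA) and adds a "LoRA ID" textbox, but the swap logic inside generate() is left commented out. What follows is a minimal sketch, not part of the commit, of how that state could drive the multi-adapter LoRA API that diffusers exposes on LoRA-capable pipelines (load_lora_weights with adapter_name, set_adapters, delete_adapters). The helper name sync_custom_lora and the assumption that the textbox holds a Hub repo id with diffusers-loadable LoRA weights are illustrative, not taken from app.py.

def sync_custom_lora(lora_id):
    """Load, reuse, or drop the custom adapter so repeat calls skip reloading."""
    # Hypothetical helper, not in the commit; relies on the globals defined above.
    global CURRENTLY_LOADED_CUSTOM_LORA
    lora_id = (lora_id or "").strip() or None

    if lora_id == CURRENTLY_LOADED_CUSTOM_LORA:
        return  # requested state is already active

    # Drop the previously loaded custom adapter before switching
    if CURRENTLY_LOADED_CUSTOM_LORA is not None:
        pipe.delete_adapters(CUSTOM_LORA_NAME)
        CURRENTLY_LOADED_CUSTOM_LORA = None

    active, weights = [], []
    if DEFAULT_LORA_NAME is not None:
        active.append(DEFAULT_LORA_NAME)
        weights.append(1.0)

    if lora_id is not None:
        # Assumes lora_id names a Hub repo containing LoRA weights
        pipe.load_lora_weights(lora_id, adapter_name=CUSTOM_LORA_NAME)
        CURRENTLY_LOADED_CUSTOM_LORA = lora_id
        active.append(CUSTOM_LORA_NAME)
        weights.append(1.0)

    if active:
        # Activate the chosen adapters; the default CausVid LoRA stays on
        pipe.set_adapters(active, adapter_weights=weights)

In this sketch, generate() would call sync_custom_lora(lora_id) before apply_cache_on_pipe(pipe), keeping the default CausVid adapter active alongside a custom one, which mirrors the two adapter names the commit reserves.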
 
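For reference on the interface change: gr.Interface renders additional_inputs in a collapsible accordion below the main controls and appends their values after the positional inputs when calling fn, so the resulting call order matches the new signature, generate(prompt, negative_prompt, width, height, num_inference_steps, lora_id).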