ovi054 committed
Commit ed01c6d · verified · 1 Parent(s): bf7c515

Update app.py

Files changed (1)
  1. app.py +84 -73
app.py CHANGED
@@ -1,112 +1,123 @@
  import torch
- from diffusers import UniPCMultistepScheduler, FlowMatchEulerDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler
- from diffusers import WanPipeline, AutoencoderKLWan # Use Wan-specific VAE
- # from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
  from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
- from diffusers.models import UNetSpatioTemporalConditionModel
- from transformers import T5EncoderModel, T5Tokenizer
  from huggingface_hub import hf_hub_download
-
  from PIL import Image
  import numpy as np
-
  import gradio as gr
  import spaces

  device = "cuda" if torch.cuda.is_available() else "cpu"

  model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
  vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
- flow_shift = 1.0 # 5.0; 1.0 for image, 5.0 for 720P, 3.0 for 480P
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)

  pipe.to(device)

- # Configure DDIMScheduler with a beta schedule
- # pipe.scheduler = DDIMScheduler.from_config(
- #     pipe.scheduler.config,
- #     beta_start=0.00085, # Starting beta value
- #     beta_end=0.012, # Ending beta value
- #     beta_schedule="linear", # Linear beta schedule (other options: "scaled_linear", "squaredcos_cap_v2")
- #     num_train_timesteps=1000, # Number of timesteps
- #     flow_shift=flow_shift
- # )
-
-
- # Configure FlowMatchEulerDiscreteScheduler
- # pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
- #     pipe.scheduler.config,
- #     flow_shift=flow_shift # Retain flow_shift for WanPipeline compatibility
- # )
-
- # --- LoRA State Management ---
- # Define unique names for our adapters
  DEFAULT_LORA_NAME = "causvid_lora"
  CUSTOM_LORA_NAME = "custom_lora"
- # Track which custom LoRA is currently loaded to avoid reloading
- CURRENTLY_LOADED_CUSTOM_LORA = None
-
- # Load the default base LoRA ONCE at startup
- print("Loading base LoRA...")
- CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
- CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors"
- try:
-     causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
-     pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
-     print(f"✅ Default LoRA '{DEFAULT_LORA_NAME}' loaded successfully.")
- except Exception as e:
-     print(f"⚠️ Default LoRA could not be loaded: {e}")
-     DEFAULT_LORA_NAME = None
-
- # print("Initialization complete. Gradio is starting...")
-


  @spaces.GPU()
- def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
-     # if lora_id and lora_id.strip() != "":
-     #     pipe.unload_lora_weights()
-     #     pipe.load_lora_weights(lora_id.strip())
-
-
-
-     # pipe.to("cuda")
-     # apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
-     apply_cache_on_pipe(
-         pipe,
-         # residual_diff_threshold=0.2,
      )
      try:
-         output = pipe(
-             prompt=prompt,
-             negative_prompt=negative_prompt,
-             height=height,
-             width=width,
-             num_frames=1,
-             num_inference_steps=num_inference_steps,
-             guidance_scale=1.0, # 5.0
-         )
-         image = output.frames[0][0]
-         image = (image * 255).astype(np.uint8)
-         return Image.fromarray(image)
      finally:
-         pass


  iface = gr.Interface(
-     fn=generate,
      inputs=[
          gr.Textbox(label="Input prompt"),
-     ],
-     additional_inputs = [
          gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
          gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
          gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
          gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
-         gr.Textbox(label="LoRA ID"),
      ],
      outputs=gr.Image(label="output"),
  )

  iface.launch()
 
  import torch
+ from diffusers import UniPCMultistepScheduler
+ from diffusers import WanPipeline, AutoencoderKLWan
  from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
  from huggingface_hub import hf_hub_download
  from PIL import Image
  import numpy as np
  import gradio as gr
  import spaces
+ import gc

+ # --- INITIAL SETUP ---
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")

  model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+ print("Loading VAE...")
  vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+
+ print("Loading WanPipeline in bfloat16...")
+ # This will use ZeroGPU/accelerate with meta devices
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+
+ flow_shift = 1.0
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)

+ # Move the base pipeline to the GPU. ZeroGPU will manage this.
+ # This is the critical step that puts the model into a sharded state.
+ print("Moving pipeline to device (ZeroGPU will handle offloading)...")
  pipe.to(device)

+ # --- LORA SETUP ---
+ CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
+ CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
  DEFAULT_LORA_NAME = "causvid_lora"
  CUSTOM_LORA_NAME = "custom_lora"

+ print("Initialization complete. Gradio is starting...")

+ # The decorated function that will run on the GPU. It only does inference.
  @spaces.GPU()
+ def generate(prompt, negative_prompt, width, height, num_inference_steps):
+     print("--- Inside generate() [GPU function] ---")
+
+     # The `pipe` object should already be configured with LoRAs by `call_infer`.
+     # This function's only job is to run the pipeline.
+
+     apply_cache_on_pipe(pipe)
+
+     print("Running inference...")
+     output = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         height=height,
+         width=width,
+         num_frames=1,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=1.0,
      )
+
+     image = output.frames[0][0]
+     image = (image * 255).astype(np.uint8)
+     return Image.fromarray(image)
+
+ # The wrapper function that the Gradio UI calls. It handles LoRA logic.
+ def call_infer(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
+     print("--- Inside call_infer() [CPU function] ---")
+
      try:
+         # This section attempts to load LoRAs dynamically into the ZeroGPU-managed model.
+         # This is the expected point of failure.
+         clean_lora_id = lora_id.strip() if lora_id else ""
+         print("Loading base LoRA for this run...")
+         causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
+         pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
+
+         # If a custom LoRA is provided, load it as well.
+         if clean_lora_id:
+             print(f"Loading custom LoRA '{clean_lora_id}' for this run...")
+             pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
+             # If a custom LoRA is present, activate both.
+             print("Activating both LoRAs...")
+             pipe.set_adapters([DEFAULT_LORA_NAME, CUSTOM_LORA_NAME], adapter_weights=[1.0, 1.0])
+         else:
+             # If no custom LoRA, just activate the base one.
+             print("Activating base LoRA only.")
+             pipe.set_adapters([DEFAULT_LORA_NAME], adapter_weights=[1.0])
+
+         print("LoRA setup complete. Calling the GPU function...")
+         # Now, call the decorated function to perform the actual generation.
+         return generate(prompt, negative_prompt, width, height, num_inference_steps)
+
+     except Exception as e:
+         print(f"ERROR DURING INFERENCE SETUP: {e}")
+         raise gr.Error(f"Failed during LoRA loading or inference: {e}")
+
      finally:
+         # --- CLEANUP ---
+         # This will run after `generate` has finished.
+         print("Unloading all LoRAs to clean up...")
+         pipe.unload_lora_weights()
+         gc.collect()
+         torch.cuda.empty_cache()
+         print("Cleanup complete.")


+ # The interface is now pointed at the `call_infer` wrapper function.
  iface = gr.Interface(
+     fn=call_infer,
      inputs=[
          gr.Textbox(label="Input prompt"),
          gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
          gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
          gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
          gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
+         gr.Textbox(label="LoRA ID (e.g., ostris/super-lora)"),
      ],
      outputs=gr.Image(label="output"),
+     title="Wan 2.1 Image Generator (Wrapper Function Test)",
+     description="A test to dynamically load LoRAs in a wrapper function before calling the GPU-decorated function."
  )

  iface.launch()
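
Note: the per-request LoRA handling that `call_infer` introduces is the diffusers PEFT adapter API (`load_lora_weights` with `adapter_name`, `set_adapters`, `unload_lora_weights`), all of which appear in the diff above. A minimal sketch of that load/activate/unload cycle in isolation, outside Gradio and ZeroGPU; the model, repo, and adapter names are the ones from this commit, while "some-user/some-wan-lora" and the prompt are placeholders:

import torch
from diffusers import WanPipeline, AutoencoderKLWan
from huggingface_hub import hf_hub_download

# Build the same pipeline this commit uses (fp32 VAE, bf16 transformer).
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16).to("cuda")

# Register the base CausVid LoRA and one user-supplied LoRA under named adapters.
causvid_path = hf_hub_download(
    repo_id="Kijai/WanVideo_comfy",
    filename="Wan21_CausVid_14B_T2V_lora_rank32.safetensors",
)
pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
pipe.load_lora_weights("some-user/some-wan-lora", adapter_name="custom_lora")  # placeholder repo id

# Activate both adapters for this request; adapter_weights scales each LoRA's effect.
pipe.set_adapters(["causvid_lora", "custom_lora"], adapter_weights=[1.0, 1.0])
frame = pipe(prompt="a lighthouse at dawn", num_frames=1).frames[0][0]  # placeholder prompt

# Remove every adapter so the next request starts from the bare base model,
# mirroring the finally-block cleanup in call_infer.
pipe.unload_lora_weights()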