ovi054 committed on
Commit
4294cf6
·
verified ·
1 Parent(s): b0aa5c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -66
app.py CHANGED
@@ -14,76 +14,52 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
14
  print(f"Using device: {device}")
15
 
16
  model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
17
- print("Loading VAE...")
18
- vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
19
 
20
- print("Loading WanPipeline in bfloat16...")
21
- # This will use ZeroGPU/accelerate with meta devices
22
- pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
 
 
 
23
 
24
- flow_shift = 1.0
25
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
26
-
27
- # Move the base pipeline to the GPU. ZeroGPU will manage this.
28
- print("Moving pipeline to device (ZeroGPU will handle offloading)...")
29
- pipe.to(device)
30
-
31
- # --- LORA SETUP ---
32
  CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
33
  CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
34
- BASE_LORA_NAME = "causvid_lora"
35
- CUSTOM_LORA_NAME = "custom_lora"
36
-
37
- print("Downloading base LoRA...")
38
  try:
39
  causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
40
- print("✅ Base LoRA downloaded.")
 
 
 
 
 
 
 
 
 
41
  except Exception as e:
42
- causvid_path = None
43
- print(f"⚠️ Could not download base LoRA: {e}")
44
 
45
- print("Initialization complete. Gradio is starting...")
 
 
 
46
 
47
- @spaces.GPU()
48
- def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
49
-
50
- # --- DYNAMIC LORA MANAGEMENT FOR EACH RUN ---
51
- active_adapters = []
52
- adapter_weights = []
53
 
54
- # 1. Load the Base LoRA directly onto the correct device
55
- if causvid_path:
56
- try:
57
- print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
58
- pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME, device_map={"":device})
59
- active_adapters.append(BASE_LORA_NAME)
60
- adapter_weights.append(1.0)
61
- print("✅ Base LoRA loaded to device.")
62
- except Exception as e:
63
- print(f"⚠️ Failed to load base LoRA: {e}")
64
 
65
- # 2. Load the Custom LoRA if provided, also directly to the device
66
- clean_lora_id = lora_id.strip() if lora_id else ""
67
- if clean_lora_id:
68
- try:
69
- print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
70
- pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME, device_map={"":device})
71
- active_adapters.append(CUSTOM_LORA_NAME)
72
- adapter_weights.append(1.0)
73
- print("✅ Custom LoRA loaded to device.")
74
- except Exception as e:
75
- print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
76
- if CUSTOM_LORA_NAME in getattr(pipe.transformer, 'peft_config', {}):
77
- del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
78
 
79
- # 3. Activate the successfully loaded adapters
80
- if active_adapters:
81
- print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
82
- pipe.set_adapters(active_adapters, adapter_weights)
83
- pipe.transformer.to(device) # Explicitly move transformer to GPU after setting adapters
84
- else:
85
- # Ensure LoRA is disabled if no adapters were loaded
86
- pipe.disable_lora()
87
 
88
  apply_cache_on_pipe(pipe)
89
 
@@ -101,22 +77,18 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
101
  image = (image * 255).astype(np.uint8)
102
  return Image.fromarray(image)
103
  finally:
104
- # --- PROPER CLEANUP ---
105
- print("Unloading all LoRAs to ensure a clean state...")
106
- pipe.unload_lora_weights()
107
- gc.collect() # Force garbage collection
108
- torch.cuda.empty_cache() # Clear CUDA cache
109
- print("✅ LoRAs unloaded and memory cleaned.")
110
 
 
111
  iface = gr.Interface(
112
  fn=generate,
113
  inputs=[
114
  gr.Textbox(label="Input prompt"),
115
- gr.Textbox(label="Negative prompt", value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
116
  gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
117
  gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
118
  gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
119
- gr.Textbox(label="LoRA ID (Optional, loads dynamically)"),
120
  ],
121
  outputs=gr.Image(label="output"),
122
  )
 
14
  print(f"Using device: {device}")
15
 
16
  model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 
 
17
 
18
+ # --- PRE-FUSION STRATEGY ---
19
+ # 1. Load everything on the CPU first. Do NOT use device_map or torch_dtype here yet.
20
+ # This prevents ZeroGPU/accelerate from taking over too early.
21
+ print("Loading all components on CPU for pre-fusion...")
22
+ vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae")
23
+ pipe = WanPipeline.from_pretrained(model_id, vae=vae)
24
 
25
+ # 2. Load and FUSE the LoRA while the entire pipeline is still on the CPU.
26
+ print("Loading and fusing base LoRA on CPU...")
 
 
 
 
 
 
27
  CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
28
  CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
 
 
 
 
29
  try:
30
  causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
31
+ pipe.load_lora_weights(causvid_path)
32
+ print("✅ Base LoRA loaded on CPU.")
33
+
34
+ pipe.fuse_lora()
35
+ print("✅ Base LoRA fused on CPU.")
36
+
37
+ # After fusing, we must unload the original LoRA weights to save memory
38
+ pipe.unload_lora_weights()
39
+ print("✅ Original LoRA weights unloaded.")
40
+
41
  except Exception as e:
42
+ print(f"⚠️ Could not pre-fuse the LoRA: {e}. The model will run without it.")
 
43
 
44
+ # 3. NOW, move the final, fused model to the target device and set the dtype.
45
+ # ZeroGPU will now take over a single, coherent model.
46
+ print("Moving the final, fused pipeline to device and setting dtype...")
47
+ pipe.to(device=device, dtype=torch.bfloat16)
48
 
49
+ # 4. Configure the scheduler
50
+ flow_shift = 1.0
51
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
 
 
 
52
 
53
+ # Clean up memory before starting the server
54
+ del vae
55
+ gc.collect()
56
+ torch.cuda.empty_cache()
 
 
 
 
 
 
57
 
58
+ print("Initialization complete. Gradio is starting...")
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ @spaces.GPU()
61
+ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, progress=gr.Progress(track_tqdm=True)):
62
+ # The LoRA is permanently fused. No dynamic loading is possible or needed.
 
 
 
 
 
63
 
64
  apply_cache_on_pipe(pipe)
65
 
 
77
  image = (image * 255).astype(np.uint8)
78
  return Image.fromarray(image)
79
  finally:
80
+ # No cleanup needed as the state is static
81
+ pass
 
 
 
 
82
 
83
+ # Interface is simplified as the custom LoRA option is removed for stability.
84
  iface = gr.Interface(
85
  fn=generate,
86
  inputs=[
87
  gr.Textbox(label="Input prompt"),
88
+ gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
89
  gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
90
  gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
91
  gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
 
92
  ],
93
  outputs=gr.Image(label="output"),
94
  )