File size: 5,179 Bytes
812e69e
17aa94d
09fcd4a
85f6fcb
1b24a66
812e69e
 
ab613df
69a9a62
3e81ff5
ab613df
09fcd4a
5d27053
09fcd4a
5d27053
adefe82
09fcd4a
1947659
09fcd4a
adefe82
4166d00
1947659
09fcd4a
 
ef22e42
2aad192
4166d00
 
5d27053
 
4166d00
15a6d13
f2c0f66
4166d00
adefe82
ffc79bb
4166d00
09fcd4a
 
4166d00
09fcd4a
4166d00
 
09fcd4a
 
812e69e
 
ec1fc83
09fcd4a
4166d00
 
 
 
3e81ff5
4166d00
 
 
3e81ff5
4166d00
 
3e81ff5
4166d00
 
 
3e81ff5
4166d00
63d55b2
09fcd4a
4166d00
3e81ff5
4166d00
 
3e81ff5
09fcd4a
4166d00
3e81ff5
4166d00
09fcd4a
4166d00
 
 
 
b0aa5c4
3e81ff5
 
 
4166d00
09fcd4a
 
ec1fc83
 
 
 
 
 
 
 
09fcd4a
ec1fc83
 
 
 
 
4166d00
3e81ff5
 
b0aa5c4
 
3e81ff5
4166d00
812e69e
 
 
 
b0aa5c4
28ae721
 
75ef805
b0aa5c4
78004f2
812e69e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from diffusers import UniPCMultistepScheduler
from diffusers import WanPipeline, AutoencoderKLWan
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
import gradio as gr
import spaces
import gc

# --- INITIAL SETUP ---
# Pick the compute device: CUDA when torch can see a GPU, CPU otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# The VAE is loaded separately so it can stay in float32 while the rest of the
# pipeline (passed this VAE explicitly below) is loaded in bfloat16.
print("Loading VAE...")
vae = AutoencoderKLWan.from_pretrained(
    model_id,
    subfolder="vae",
    torch_dtype=torch.float32,
)

print("Loading WanPipeline in bfloat16...")
# NOTE(review): the original author expects ZeroGPU/accelerate to manage
# placement here (meta devices) -- confirm against the Space's runtime.
pipe = WanPipeline.from_pretrained(
    model_id,
    vae=vae,
    torch_dtype=torch.bfloat16,
)

# Replace the pipeline's scheduler with UniPC, reusing the existing scheduler
# config but overriding flow_shift.
flow_shift = 1.0
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    flow_shift=flow_shift,
)

# Place the base pipeline on `device`; offloading is expected to be handled by
# ZeroGPU (per the log message) -- TODO confirm.
print("Moving pipeline to device (ZeroGPU will handle offloading)...")
pipe.to(device)

# --- LORA SETUP ---
# Hub location of the CausVid base LoRA that is applied to every generation.
CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
# Adapter names used when LoRAs are loaded into the pipeline at request time.
BASE_LORA_NAME = "causvid_lora"
CUSTOM_LORA_NAME = "custom_lora"

# Best-effort download: on any failure the app still starts, with
# causvid_path left as None so the base LoRA is simply skipped later.
print("Downloading base LoRA...")
causvid_path = None
try:
    causvid_path = hf_hub_download(
        repo_id=CAUSVID_LORA_REPO,
        filename=CAUSVID_LORA_FILENAME,
    )
except Exception as e:
    print(f"⚠️ Could not download base LoRA: {e}")
else:
    print("✅ Base LoRA downloaded.")

print("Initialization complete. Gradio is starting...")

@spaces.GPU()
def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the Wan pipeline, applying LoRAs for this call only.

    Each call does the following:

    1. Loads the CausVid base LoRA, if its download succeeded at startup.
    2. Loads an optional user-supplied LoRA.
    3. Activates whichever of them loaded, at weight 1.0.
    4. Runs the pipeline for a single frame.
    5. Always unloads all LoRAs again, so the shared global ``pipe`` starts the
       next request without leftover adapters.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt. NOTE(review): the pipeline is called
            with ``guidance_scale=1.0``. At that value classifier-free guidance is
            presumably disabled, so this argument may have no effect -- confirm.
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        num_inference_steps: Number of denoising steps (cast to ``int``).
        lora_id: Optional LoRA reference passed to ``load_lora_weights``.
            ``None``, empty, or whitespace-only means no custom LoRA.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars.

    Returns:
        PIL.Image.Image: the first frame of the first generated sample.
    """
    # Adapters that actually loaded this run, and their weights (parallel lists).
    active_adapters = []
    adapter_weights = []

    # Every step that can attach LoRA state to the global `pipe` sits inside
    # this try. That way the `finally` resets the pipeline even when loading,
    # activation, or caching raises. Otherwise adapters would stay attached after
    # a failed request, and reloading the same adapter name next time would
    # presumably be rejected and silently skipped.
    try:
        # 1. Base LoRA, only if the startup download produced a path.
        if causvid_path:
            try:
                print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
                pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME, device_map={"": device})
                active_adapters.append(BASE_LORA_NAME)
                adapter_weights.append(1.0)
                print("✅ Base LoRA loaded to device.")
            except Exception as e:
                # Best effort: generation proceeds without the base LoRA.
                print(f"⚠️ Failed to load base LoRA: {e}")

        # 2. Optional custom LoRA; blank input means "none".
        clean_lora_id = lora_id.strip() if lora_id else ""
        if clean_lora_id:
            try:
                print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
                pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME, device_map={"": device})
                active_adapters.append(CUSTOM_LORA_NAME)
                adapter_weights.append(1.0)
                print("✅ Custom LoRA loaded to device.")
            except Exception as e:
                print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
                # Drop any config still registered for the failed adapter.
                getattr(pipe.transformer, "peft_config", {}).pop(CUSTOM_LORA_NAME, None)

        # 3. Activate only the adapters that loaded successfully.
        if active_adapters:
            print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
            pipe.set_adapters(active_adapters, adapter_weights)
            pipe.transformer.to(device)  # Explicitly move transformer to the device after setting adapters
        else:
            # Ensure LoRA is disabled if no adapters were loaded.
            pipe.disable_lora()

        # NOTE(review): this is re-applied on every call. It is presumably a no-op
        # after the first application -- confirm, and if so consider moving it
        # into module setup.
        apply_cache_on_pipe(pipe)

        # A single frame turns the text-to-video pipeline into text-to-image.
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=int(height),
            width=int(width),
            num_frames=1,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=1.0,
        )
        # First sample, first frame. The frame is presumably an (H, W, C) float
        # array in [0, 1] -- TODO confirm against the pipeline's output_type.
        frame = output.frames[0][0]
        # Round and clip before the cast. A bare astype truncates, and casting
        # out-of-range floats to uint8 gives platform-dependent results.
        frame = np.clip(np.rint(frame * 255.0), 0, 255).astype(np.uint8)
        return Image.fromarray(frame)
    finally:
        # --- PROPER CLEANUP ---
        print("Unloading all LoRAs to ensure a clean state...")
        unloaded = False
        try:
            pipe.unload_lora_weights()
            unloaded = True
        except Exception as e:
            # Do not let a cleanup failure replace the result, or the exception,
            # that brought us here.
            print(f"⚠️ Failed to unload LoRAs: {e}")
        gc.collect()  # Force garbage collection
        torch.cuda.empty_cache()  # Clear CUDA cache
        if unloaded:
            print("✅ LoRAs unloaded and memory cleaned.")

# Default negative prompt shown in the UI; the implicit string concatenation
# below yields one single-line string.
_DEFAULT_NEGATIVE_PROMPT = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, "
    "works, paintings, images, static, overall gray, worst quality, low quality, "
    "JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, "
    "poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, "
    "still picture, messy background, three legs, many people in the background, "
    "walking backwards"
)


def _size_slider(label):
    """Return a pixel-size slider: 480 to 1280 in steps of 16, defaulting to 1024."""
    return gr.Slider(label=label, minimum=480, maximum=1280, step=16, value=1024)


# Input components are passed positionally to generate(), so their order must
# match its parameters: prompt, negative_prompt, width, height,
# num_inference_steps, lora_id.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input prompt"),
        gr.Textbox(label="Negative prompt", value=_DEFAULT_NEGATIVE_PROMPT),
        _size_slider("Width"),
        _size_slider("Height"),
        gr.Slider(label="Inference Steps", minimum=1, maximum=80, step=1, value=10),
        gr.Textbox(label="LoRA ID (Optional, loads dynamically)"),
    ],
    outputs=gr.Image(label="output"),
)

iface.launch()