import torch
from diffusers import UniPCMultistepScheduler
from diffusers import WanPipeline, AutoencoderKLWan
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
import gradio as gr
import spaces

# --- INITIAL SETUP ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
print("Loading VAE...")
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
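# Note: the VAE stays in float32 (full precision) for decode quality, matching
# the diffusers Wan2.1 examples, while the transformer below runs in bfloat16.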

print("Loading WanPipeline in bfloat16...")
# This will use ZeroGPU/accelerate with meta devices
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

flow_shift = 1.0
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
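# flow_shift rescales the flow-matching timestep schedule; 1.0 leaves it
# unshifted. (The Wan docs reportedly suggest larger shifts, e.g. 3.0 for 480p
# and 5.0 for 720p, for the undistilled model; 1.0 here is paired with the
# few-step CausVid LoRA.)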

# Move the base pipeline to the GPU. ZeroGPU will manage this.
print("Moving pipeline to device (ZeroGPU will handle offloading)...")
pipe.to(device)

# --- LORA SETUP ---
# We will NOT fuse anything. Everything will be handled dynamically.
CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
BASE_LORA_NAME = "causvid_lora"
CUSTOM_LORA_NAME = "custom_lora"
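# (Fusing a LoRA would write it into the base weights, which ZeroGPU may hold
# as meta/offloaded tensors; keeping adapters separate lets each request attach
# and detach them cleanly instead.)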

print("Downloading base LoRA...")
try:
    causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
    print("✅ Base LoRA downloaded.")
except Exception as e:
    causvid_path = None
    print(f"⚠️ Could not download base LoRA: {e}")

print("Initialization complete. Gradio is starting...")

def move_adapter_to_device(pipe, adapter_name, device):
    """
    Surgically moves only the parameters of a specific LoRA adapter to the
    target device, without touching the base model's meta tensors.

    PEFT stores LoRA weights in ModuleDicts keyed by adapter name, so the
    adapter name appears in the parameter path (e.g.
    "...lora_A.causvid_lora.weight"). We match on that path; torch Parameters
    do not carry an `adapter_name` attribute we could check instead.
    """
    print(f"Moving adapter '{adapter_name}' to {device}...")
    for name, param in pipe.transformer.named_parameters():
        if f".{adapter_name}." in name:
            param.data = param.data.to(device, non_blocking=True)
            if param.grad is not None:
                param.grad.data = param.grad.data.to(device, non_blocking=True)
    print(f"✅ Adapter '{adapter_name}' moved.")

@spaces.GPU()
def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
    
    # --- DYNAMIC LORA MANAGEMENT FOR EACH RUN ---
    # Start with a clean slate by disabling any active adapters from previous runs
    pipe.disable_lora()
    
    active_adapters = []
    adapter_weights = []

    # 1. Load the Base LoRA
    if causvid_path:
        try:
            # We load it for every run to ensure a clean state
            print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
            pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME)
            
            # THE CRITICAL FIX: Move only this adapter's weights to the GPU
            move_adapter_to_device(pipe, BASE_LORA_NAME, device)

            active_adapters.append(BASE_LORA_NAME)
            adapter_weights.append(1.0)
        except Exception as e:
            print(f"⚠️ Failed to load base LoRA: {e}")

    # 2. Load the Custom LoRA if provided
    clean_lora_id = lora_id.strip() if lora_id else ""
    if clean_lora_id:
        try:
            print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
            pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
            
            # THE CRITICAL FIX: Move only this adapter's weights to the GPU
            move_adapter_to_device(pipe, CUSTOM_LORA_NAME, device)

            active_adapters.append(CUSTOM_LORA_NAME)
            adapter_weights.append(1.0)
        except Exception as e:
            print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
            # If loading failed partway, remove any partially-registered adapter
            # so a later run can retry under the same name
            if CUSTOM_LORA_NAME in getattr(pipe.transformer, "peft_config", {}):
                pipe.delete_adapters(CUSTOM_LORA_NAME)

    # 3. Activate the successfully loaded adapters
    if active_adapters:
        print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
        pipe.set_adapters(active_adapters, adapter_weights)
    
    apply_cache_on_pipe(pipe)
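    # Assumption: para-attn's apply_cache_on_pipe tolerates being called on
    # every request. If re-application ever stacks cache wrappers on the
    # transformer, calling it once at startup (right after pipe.to(device))
    # would be the safer placement.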
    
    try:
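        # guidance_scale=1.0 effectively disables classifier-free guidance,
        # which is the usual sampling regime for distilled few-step LoRAs like
        # CausVid (assumption based on common usage, not the model card).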
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=1,
            num_inference_steps=num_inference_steps,
            guidance_scale=1.0,
        )
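        # The pipeline returns video frames as float arrays in [0, 1] (the
        # default "np" output type); with num_frames=1 the first frame of the
        # first video is our single image.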
        image = output.frames[0][0]
        image = (image * 255).astype(np.uint8)
        return Image.fromarray(image)
    finally:
        # --- PROPER CLEANUP ---
        print("Cleaning up LoRAs for this run...")
        # Disable adapters to stop them from being used
        pipe.disable_lora()
        
        # Delete the adapters so the next run can re-attach them under the same
        # names. diffusers' delete_adapters removes the injected LoRA modules
        # themselves; deleting entries from pipe.transformer.peft_config would
        # leave the weights behind.
        peft_config = getattr(pipe.transformer, "peft_config", {}) or {}
        for name in (BASE_LORA_NAME, CUSTOM_LORA_NAME):
            if name in peft_config:
                pipe.delete_adapters(name)
        print("✅ LoRAs cleaned up.")


iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input prompt"),
        gr.Textbox(label="Negative prompt", value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
        gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
        gr.Textbox(label="LoRA ID (Optional)"),
    ],
    outputs=gr.Image(label="output"),
)

iface.launch()