import torch
from diffusers import UniPCMultistepScheduler
from diffusers import WanPipeline, AutoencoderKLWan
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
import gradio as gr
import spaces
import gc
# --- INITIAL SETUP ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
print("Loading VAE...")
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
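# Keeping the VAE in float32 (while the transformer below runs in bfloat16)
# follows the usual Wan 2.1 guidance and avoids precision artifacts when decoding.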
print("Loading WanPipeline in bfloat16...")
# This will use ZeroGPU/accelerate with meta devices
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
flow_shift = 1.0
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
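# Note: Wan 2.1 video guidance usually suggests flow_shift around 3.0 (480p) to
# 5.0 (720p); 1.0 here appears to be a deliberate choice for single-frame output.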
# Move the base pipeline to the GPU. ZeroGPU will manage this.
# This is the critical step that puts the model into a sharded state.
print("Moving pipeline to device (ZeroGPU will handle offloading)...")
pipe.to(device)
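# (A memory-saving alternative would be pipe.enable_model_cpu_offload(), but on
# ZeroGPU a plain pipe.to(device) is the usual pattern and is kept here.)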
# --- LORA SETUP ---
CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
DEFAULT_LORA_NAME = "causvid_lora"
CUSTOM_LORA_NAME = "custom_lora"
print("Initialization complete. Gradio is starting...")
# The decorated function that will run on the GPU. It only does inference.
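# If a run needs more than the default ZeroGPU time slice, a longer window can be
# requested via the decorator, e.g. @spaces.GPU(duration=120) (seconds).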
@spaces.GPU()
def generate(prompt, negative_prompt, width, height, num_inference_steps):
print("--- Inside generate() [GPU function] ---")
# The `pipe` object should already be configured with LoRAs by `call_infer`.
# This function's only job is to run the pipeline.
apply_cache_on_pipe(pipe)
print("Running inference...")
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=1,
num_inference_steps=num_inference_steps,
guidance_scale=1.0,
)
image = output.frames[0][0]
image = (image * 255).astype(np.uint8)
return Image.fromarray(image)
# The wrapper function that the Gradio UI calls. It handles LoRA logic.
def call_infer(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
print("--- Inside call_infer() [CPU function] ---")
try:
# This section attempts to load LoRAs dynamically into the ZeroGPU-managed model.
# This is the expected point of failure.
clean_lora_id = lora_id.strip() if lora_id else ""
print("Loading base LoRA for this run...")
causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
# If a custom LoRA is provided, load it as well.
if clean_lora_id:
print(f"Loading custom LoRA '{clean_lora_id}' for this run...")
pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
# If a custom LoRA is present, activate both.
print("Activating both LoRAs...")
pipe.set_adapters([DEFAULT_LORA_NAME, CUSTOM_LORA_NAME], adapter_weights=[1.0, 1.0])
else:
# If no custom LoRA, just activate the base one.
print("Activating base LoRA only.")
pipe.set_adapters([DEFAULT_LORA_NAME], adapter_weights=[1.0])
print("LoRA setup complete. Calling the GPU function...")
# Now, call the decorated function to perform the actual generation
return generate(prompt, negative_prompt, width, height, num_inference_steps)
except Exception as e:
print(f"ERROR DURING INFERENCE SETUP: {e}")
raise gr.Error(f"Failed during LoRA loading or inference: {e}")
finally:
# --- CLEANUP ---
# This will run after `generate` has finished.
print("Unloading all LoRAs to clean up...")
pipe.unload_lora_weights()
gc.collect()
torch.cuda.empty_cache()
print("Cleanup complete.")
# The interface is now pointed at the `call_infer` wrapper function.
iface = gr.Interface(
    fn=call_infer,
    inputs=[
        gr.Textbox(label="Input prompt"),
        gr.Textbox(label="Negative prompt", value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
        gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(label="Inference Steps", minimum=1, maximum=80, step=1, value=10),
        gr.Textbox(label="LoRA ID (e.g., ostris/super-lora)"),
    ],
    outputs=gr.Image(label="output"),
    title="Wan 2.1 Image Generator (Wrapper Function Test)",
    description="A test to dynamically load LoRAs in a wrapper function before calling the GPU-decorated function.",
)

iface.launch()