# Wan22-Light / app_fast.py
# PyTorch 2.8 (temporary hack)
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
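# Note (a hedged alternative, not part of the original Space): the same pin could instead live in
# the Space's requirements.txt so it is resolved at build time rather than at import time, e.g.:
#   --pre
#   --extra-index-url https://download.pytorch.org/whl/nightly/cu126
#   torch<2.9
#   spaces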
# Actual demo code
import spaces
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc
# Optimization imports (optional; only referenced in the commented-out notes below)
try:
    import intel_extension_for_pytorch as ipex
    from ipex_llm.transformers.optimize import optimize_model
except ImportError:  # neither package is required for the default GPU path
    ipex = optimize_model = None
# --- Configuration ---
MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
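# With FIXED_FPS = 16 this works out to MIN_DURATION = 8/16 = 0.5 s and MAX_DURATION = 81/16 ≈ 5.1 s.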
# --- Model Loading ---
print("Loading VAE...")
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
print("Loading Transformer 1...")
transformer_1 = WanTransformer3DModel.from_pretrained(
'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
subfolder='transformer',
torch_dtype=torch.bfloat16,
device_map='cpu', # Load to CPU first for optimization
)
print("Loading Transformer 2...")
transformer_2 = WanTransformer3DModel.from_pretrained(
'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
subfolder='transformer_2',
torch_dtype=torch.bfloat16,
device_map='cpu', # Load to CPU first for optimization
)
print("Loading Wan Pipeline...")
pipe = WanPipeline.from_pretrained(MODEL_ID,
transformer=transformer_1,
transformer_2=transformer_2,
vae=vae,
torch_dtype=torch.bfloat16,
)
pipe.to("cpu") # Ensure pipeline is on CPU before optimization
# --- Optimization ---
print("Starting optimization...")
# Note: ipex.optimize / optimize_model target a single model rather than a multi-component
# diffusers pipeline, so they are not applied here; the optimization below relies on torch.compile.
# 1. FP8 quantization (skipped: not supported on this runtime)
# If FP8 were supported, it would be applied to the transformers/VAE here.
# Conceptual CPU-side example (actual API may vary), left disabled:
# pipe.transformer = ipex.quantization.quantize_dynamic(pipe.transformer, {torch.nn.Linear}, dtype=torch.qint8)
# pipe.transformer_2 = ipex.quantization.quantize_dynamic(pipe.transformer_2, {torch.nn.Linear}, dtype=torch.qint8)
# pipe.vae = ipex.quantization.quantize_dynamic(pipe.vae, {torch.nn.Linear}, dtype=torch.qint8)
print("FP8 quantization is disabled: not supported on this runtime.")
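# A hedged, optional sketch (not part of the original flow): PyTorch's built-in dynamic int8
# quantization of nn.Linear layers is the closest CPU-side substitute available today.
# ENABLE_INT8_DYNAMIC is a hypothetical flag added purely for illustration; it is left off,
# and the bf16 modules would likely need a float32 cast first, with untested quality impact.
ENABLE_INT8_DYNAMIC = False
if ENABLE_INT8_DYNAMIC and not torch.cuda.is_available():
    from torch.ao.quantization import quantize_dynamic
    # Replace nn.Linear weights with int8 and quantize activations on the fly (CPU inference only).
    pipe.transformer = quantize_dynamic(pipe.transformer, {torch.nn.Linear}, dtype=torch.qint8)
    pipe.transformer_2 = quantize_dynamic(pipe.transformer_2, {torch.nn.Linear}, dtype=torch.qint8)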
# 2. Ahead-of-time-style compilation via torch.compile
# WanPipeline exposes no single forward() to export, so the core generation call is wrapped
# in a plain function and compiled as a whole further below.
# Prepare a dummy input for tracing
example_prompt = "A test prompt for tracing."
example_negative_prompt = ""
example_height = LANDSCAPE_HEIGHT
example_width = LANDSCAPE_WIDTH
example_num_frames = MIN_FRAMES_MODEL
example_guidance_scale = 1.0
example_guidance_scale_2 = 3.0
example_num_inference_steps = 4
example_seed = 42
# Create a generator for the dummy input
example_generator = torch.Generator(device="cpu").manual_seed(example_seed)
# Move components to GPU if available for compilation, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
print("Moving pipeline components to CUDA for compilation...")
pipe.to(device)
# Ensure VAE is also on the correct device if used separately during generation
vae.to(device)
    # Compiling an entire pipeline __call__ is tricky and may need careful tracing; the more
    # common pattern is to compile individual modules (e.g. the transformers). Here the whole
    # generation call is wrapped in a function and compiled, with a fallback if it fails.
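    # A hedged, optional sketch (assumption, not what this script does by default): compile the
    # individual transformer modules instead of the whole pipeline call. COMPILE_MODULES_ONLY is
    # a hypothetical flag added purely for illustration and is left off.
    COMPILE_MODULES_ONLY = False
    if COMPILE_MODULES_ONLY:
        # torch.compile returns wrappers that keep the modules' forward() interface,
        # so WanPipeline continues to call them transparently.
        pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
        pipe.transformer_2 = torch.compile(pipe.transformer_2, mode="reduce-overhead")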
try:
print("Attempting to compile the pipeline with torch.compile...")
# Wrap the call in a function for compile
def pipeline_call(prompt, negative_prompt, height, width, num_frames, guidance_scale, guidance_scale_2, num_inference_steps, generator):
return pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=guidance_scale,
guidance_scale_2=guidance_scale_2,
num_inference_steps=num_inference_steps,
generator=generator,
).frames[0]
# Using mode='reduce-overhead' for faster compilation and potentially good performance.
# Other modes like 'max-autotune' might yield better performance but take longer.
# 'aot_eager' can also be useful for debugging.
compiled_pipeline = torch.compile(pipeline_call, mode="reduce-overhead", fullgraph=True)
print("Pipeline compilation successful. Warming up...")
# Warm-up run
with torch.no_grad():
_ = compiled_pipeline(
prompt=example_prompt,
negative_prompt=example_negative_prompt,
height=example_height,
width=example_width,
num_frames=example_num_frames,
guidance_scale=example_guidance_scale,
                guidance_scale_2=example_guidance_scale_2,
num_inference_steps=example_num_inference_steps,
generator=example_generator,
)
print("Compilation warmup complete.")
# Replace the original pipeline call with the compiled version
# This requires modifying the generate_video function to use `compiled_pipeline`
# instead of `pipe`. We'll adjust the `generate_video` function below.
except Exception as e:
print(f"Torch compile failed: {e}")
print("Proceeding without torch.compile.")
compiled_pipeline = None # Fallback to uncompiled pipeline
else:
print("CUDA not available. Proceeding without GPU acceleration and torch.compile.")
compiled_pipeline = None
# --- Gradio App ---
default_prompt_t2v = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
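# English gloss of the Chinese default above (kept in Chinese, which Wan expects): vivid/gaudy colors,
# overexposed, static, blurry details, subtitles, style, artwork, painting, frame, stillness, overall gray,
# worst quality, low quality, JPEG artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly
# drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
# three legs, crowded background, walking backwards.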
def get_duration(
prompt,
negative_prompt,
duration_seconds,
guidance_scale,
guidance_scale_2,
steps,
seed,
randomize_seed,
):
    # Used by the @spaces.GPU decorator to estimate the GPU time budget for this call.
    # Rough linear model: a fixed base plus a per-step cost and a per-second-of-video cost.
    estimated_time = 10 + (steps * 2) + (duration_seconds * 5)
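    # Worked example: steps=4 and duration_seconds≈5.1 give 10 + 4*2 + 5.1*5 ≈ 43.5 seconds.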
return estimated_time
@spaces.GPU(duration=get_duration)
def generate_video(
prompt,
negative_prompt=default_negative_prompt,
duration_seconds = MAX_DURATION,
guidance_scale = 1,
guidance_scale_2 = 3,
steps = 4,
seed = 42,
randomize_seed = False,
progress=gr.Progress(track_tqdm=True),
):
"""
    Generate a video from a text prompt using the Wan 2.2 14B T2V model with Lightning LoRA.
    This function generates a video animation from the provided prompt and parameters. It uses
    the Wan 2.2 14B Text-to-Video model (BF16 transformers) set up for fast 4-8 step generation
    following the Lightning LoRA recipe; FP8 quantization is not applied in this build.
    Args:
        prompt (str): Text prompt describing the desired animation or motion.
        negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
            Defaults to default_negative_prompt (lists unwanted visual artifacts).
        duration_seconds (float, optional): Duration of the generated video in seconds.
            Defaults to MAX_DURATION. Clamped between MIN_FRAMES_MODEL/FIXED_FPS and MAX_FRAMES_MODEL/FIXED_FPS.
        guidance_scale (float, optional): Prompt adherence for the high-noise stage. Higher values = more adherence.
            Defaults to 1.0. Range: 0.0-10.0.
        guidance_scale_2 (float, optional): Prompt adherence for the low-noise stage. Higher values = more adherence.
            Defaults to 3.0. Range: 0.0-10.0.
        steps (int, optional): Number of inference steps. More steps = higher quality but slower.
            Defaults to 4. Range: 1-30.
        seed (int, optional): Random seed for reproducible results. Defaults to 42.
            Range: 0 to MAX_SEED (2147483647).
        randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
            Defaults to False.
        progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
    Returns:
        tuple: A tuple containing:
            - video_path (str): Path to the generated video file (.mp4)
            - current_seed (int): The seed used for generation (useful when randomize_seed=True)
    Raises:
        gr.Error: If video generation fails; the underlying exception message is included.
    Note:
        - Frame count is duration_seconds * FIXED_FPS (16), clamped to 8-81 frames
        - Output resolution is fixed at 832x480 (landscape)
        - The function uses GPU acceleration via the @spaces.GPU decorator
        - Generation time varies with steps and duration (see get_duration)
"""
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
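    # e.g. 2.0 s * 16 fps = 32 frames; np.clip keeps the count within the model's 8-81 frame range.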
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
# Ensure pipeline is on the correct device before generation
if device == "cuda":
pipe.to(device)
else:
pipe.to("cpu") # Explicitly move to CPU if CUDA is not available
# Prepare generator on the correct device
generator = torch.Generator(device=device).manual_seed(current_seed)
print(f"Generating video with prompt: '{prompt}'")
print(f"Settings: duration={duration_seconds}s ({num_frames} frames), steps={steps}, guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}, seed={current_seed}")
try:
if compiled_pipeline and device == "cuda":
print("Using compiled pipeline for generation.")
output_frames_list = compiled_pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=LANDSCAPE_HEIGHT,
width=LANDSCAPE_WIDTH,
num_frames=num_frames,
guidance_scale=float(guidance_scale),
guidance_scale_2=float(guidance_scale_2),
num_inference_steps=int(steps),
generator=generator,
)
else:
print("Using uncompiled pipeline for generation.")
with torch.no_grad(): # Ensure no gradients are computed during inference
output_frames_list = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=LANDSCAPE_HEIGHT,
width=LANDSCAPE_WIDTH,
num_frames=num_frames,
guidance_scale=float(guidance_scale),
guidance_scale_2=float(guidance_scale_2),
num_inference_steps=int(steps),
generator=generator,
).frames[0]
        # The pipeline normally returns a list of PIL images in frames[0]; if tensors come back
        # instead, convert them (assuming HWC float frames in [0, 1]) before export_to_video.
        if isinstance(output_frames_list, torch.Tensor):
            output_frames_list = [Image.fromarray((frame.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)) for frame in output_frames_list]
        elif isinstance(output_frames_list[0], torch.Tensor):
            output_frames_list = [Image.fromarray((frame.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)) for frame in output_frames_list]
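        # delete=False keeps the .mp4 on disk after the handle closes so Gradio can serve it;
        # the temporary directory is ephemeral on Spaces and is discarded with the container.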
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
video_path = tmpfile.name
print(f"Exporting video to {video_path}...")
export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
print("Video exported successfully.")
# Clean up GPU memory after generation
if device == "cuda":
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
return video_path, current_seed
except Exception as e:
print(f"An error occurred during video generation: {e}")
# Clean up GPU memory in case of error
if device == "cuda":
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
raise gr.Error(f"Video generation failed: {e}")
with gr.Blocks() as demo:
    gr.Markdown("# Fast 4-step Wan 2.2 T2V (14B) with Lightning LoRA")
    gr.Markdown("Run Wan 2.2 in just 4-8 steps with the [Wan 2.2 Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), FP8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
with gr.Row():
with gr.Column():
prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v)
duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=MAX_DURATION, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
with gr.Accordion("Advanced Settings", open=False):
negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=3, label="Guidance Scale 2 - low noise stage")
generate_button = gr.Button("Generate Video", variant="primary")
with gr.Column():
video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
ui_inputs = [
prompt_input,
negative_prompt_input, duration_seconds_input,
guidance_scale_input, guidance_scale_2_input, steps_slider, seed_input, randomize_seed_checkbox
]
generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
gr.Examples(
examples=[
[
"POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cat’s face, then cat resurfaces, still filming selfie, playful summer vacation mood.",
],
[
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
],
[
"A cinematic shot of a boat sailing on a calm sea at sunset.",
],
[
"Drone footage flying over a futuristic city with flying cars.",
],
],
inputs=[prompt_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
)
if __name__ == "__main__":
    # Optimization (the torch.compile warm-up above) already ran at import time, and
    # generate_video automatically uses `compiled_pipeline` when it is available.
    if device == "cuda":
        print("CUDA is available. Proceeding with GPU acceleration.")
    else:
        print("CUDA not available. The application will run on CPU, which will be significantly slower.")
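    # queue() queues incoming requests (by default one generation at a time on the GPU worker);
    # mcp_server=True additionally exposes the app's functions over MCP (requires a Gradio
    # version with MCP support installed).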
demo.queue().launch(mcp_server=True)