# Wan22-Light / app_fast1.py
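# Gradio app that generates short videos from text prompts with the Wan 2.1
# T2V 14B diffusers pipeline, using 4-bit (NF4) quantization, group offloading,
# and torch.compile to keep GPU memory use and latency manageable.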
import os
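# Install a PyTorch nightly build (CUDA 12.6 wheels, pinned below 2.9) plus the
# Hugging Face `spaces` package at startup, before torch and spaces are imported.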
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
import spaces
import gradio as gr
import torch
import numpy as np
import tempfile
import random
import gc
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import UMT5EncoderModel
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- DEFAULT PROMPTS ---
default_prompt_t2v = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
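# The default negative prompt is left in Chinese (the default used in Wan's own demos).
# Rough translation: "garish colors, overexposed, static, blurry details, subtitles, style,
# artwork, painting, still image, overall gray, worst quality, low quality, JPEG artifacts,
# ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured,
# malformed limbs, fused fingers, motionless frame, cluttered background, three legs,
# crowded background, walking backwards".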
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
# --- CONSTANTS ---
FIXED_FPS = 16
MIN_FRAMES_MODEL = 10
MAX_FRAMES_MODEL = 200
MIN_DURATION = MIN_FRAMES_MODEL / FIXED_FPS
MAX_DURATION = MAX_FRAMES_MODEL / FIXED_FPS
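# At 16 fps, the 10-200 frame limits correspond to clips of 0.625-12.5 seconds.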
MAX_SEED = 2147483647
LANDSCAPE_HEIGHT = 512
LANDSCAPE_WIDTH = 512
# --- SETUP PIPELINE ---
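# torch.compile settings: allow many recompilations and capture ops with dynamic
# output shapes, so the compiled transformer does not fall back to eager mode.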
torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
# Quantize the diffusion transformer to 4-bit NF4 with bitsandbytes to cut VRAM use.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer"],
)

# Load the UMT5 text encoder separately in bfloat16; it is passed to the pipeline below.
text_encoder = UMT5EncoderModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16
)
pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    text_encoder=text_encoder,
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
# The pipeline is deliberately not moved to the GPU with .to(): the group
# offloading hooks configured below manage device placement per component.
# Group offloading: keep weights on the CPU and stream each leaf module onto the
# GPU only while it is needed (use_stream=True overlaps transfers with compute).
onload_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
offload_device = torch.device("cpu")
pipeline.transformer.enable_group_offload(onload_device, offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
pipeline.vae.enable_group_offload(onload_device, offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
apply_group_offloading(pipeline.text_encoder, onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
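# Compile the denoising transformer with torch.compile for faster repeated sampling.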
pipeline.transformer.compile()
# --- HELPER FUNCTIONS ---
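# get_duration is passed to @spaces.GPU(duration=...) so the ZeroGPU scheduler can
# reserve an appropriate GPU slot per request; it receives the same arguments as
# generate_video.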
def get_duration(prompt, negative_prompt, duration_seconds, guidance_scale, guidance_scale_2, steps, seed, randomize_seed):
    # Rough GPU runtime estimate (in seconds).
    return 10 + (steps * 2) + (duration_seconds * 5)
@spaces.GPU(duration=get_duration)
def generate_video(
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=3,
    steps=4,
    seed=42,
    randomize_seed=False,
):
    # Clamp the requested duration to the supported frame range at 16 fps.
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # The pipeline is not moved with .to(device) here: group offloading already
    # controls which components sit on the GPU, and moving a pipeline with
    # offload hooks attached can conflict with them.
    generator = torch.Generator(device=device).manual_seed(current_seed)
    try:
        with torch.no_grad():
            frames = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=LANDSCAPE_HEIGHT,
                width=LANDSCAPE_WIDTH,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                guidance_scale_2=float(guidance_scale_2),
                num_inference_steps=int(steps),
                generator=generator,
            ).frames[0]  # first (and only) video in the batch
        # Convert tensors to PIL images if necessary (output is normally numpy frames or PIL images).
        if isinstance(frames[0], torch.Tensor):
            frames = [
                Image.fromarray((frame.cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8))
                for frame in frames
            ]
        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        export_to_video(frames, tmp_file.name, fps=FIXED_FPS)
        # Clean up GPU memory
        if device == "cuda":
            gc.collect()
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        return tmp_file.name, current_seed
    except Exception as e:
        # Free GPU memory before surfacing the error to the UI.
        if device == "cuda":
            gc.collect()
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        raise gr.Error(f"Video generation failed: {e}")
# --- GRADIO UI ---
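# Two-column layout: prompt, duration, and advanced settings on the left, the
# generated video on the right.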
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 Text-to-Video Generator 🎬")
    gr.Markdown("Generate videos in a few steps with the Wan 2.1 T2V 14B model, using 4-bit quantization and GPU offloading.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v)
            duration_input = gr.Slider(MIN_DURATION, MAX_DURATION, value=MAX_DURATION, step=0.1, label="Duration (seconds)")
            with gr.Accordion("Advanced Settings", open=False):
                negative_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                steps_input = gr.Slider(1, 30, value=4, step=1, label="Inference Steps")
                guidance_input = gr.Slider(0.0, 10.0, value=1, step=0.5, label="Guidance Scale - High Noise Stage")
                guidance2_input = gr.Slider(0.0, 10.0, value=3, step=0.5, label="Guidance Scale 2 - Low Noise Stage")
                seed_input = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)

    # Wire the form to generate_video; the seed slider is also an output so the
    # seed actually used (possibly randomized) is reflected back in the UI.
    ui_inputs = [prompt_input, negative_input, duration_input, guidance_input, guidance2_input, steps_input, seed_input, randomize_seed]
    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

    gr.Examples(
        examples=[
            ["POV selfie video, white cat with sunglasses standing on surfboard, tropical beach."],
            ["Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."],
            ["Cinematic shot of a boat sailing on a calm sea at sunset."],
            ["Drone footage flying over a futuristic city with flying cars."],
        ],
        inputs=[prompt_input],
        outputs=[video_output, seed_input],
        fn=generate_video,
        cache_examples="lazy",
    )
if __name__ == "__main__":
    demo.queue().launch()