import os

# Install a PyTorch nightly build (cu126) and the `spaces` package before torch is imported.
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

import spaces
import gradio as gr
import torch
import numpy as np
import tempfile
import random
import gc

from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import UMT5EncoderModel
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- DEFAULT PROMPTS ---
default_prompt_t2v = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
# Standard Wan negative prompt, kept in Chinese as shipped with the model. It lists artifacts to avoid:
# oversaturated colors, overexposure, static image, blurry details, subtitles, style, artwork, painting,
# still frame, overall gray tint, worst/low quality, JPEG artifacts, ugly, mutilated, extra fingers,
# poorly drawn hands/face, deformed, disfigured, malformed limbs, fused fingers, motionless frame,
# cluttered background, three legs, crowded background, walking backwards.
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"

# --- CONSTANTS ---
FIXED_FPS = 16
MIN_FRAMES_MODEL = 10
MAX_FRAMES_MODEL = 200
MIN_DURATION = MIN_FRAMES_MODEL / FIXED_FPS
MAX_DURATION = MAX_FRAMES_MODEL / FIXED_FPS
MAX_SEED = 2147483647  # 2**31 - 1
LANDSCAPE_HEIGHT = 512
LANDSCAPE_WIDTH = 512

# --- SETUP PIPELINE ---
# Relax torch.compile limits so the compiled transformer tolerates recompilations
# and dynamic output shapes (the frame count varies between requests).
torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# Quantize the transformer to 4-bit NF4; the text encoder is loaded separately in bf16 below.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer"],
)

text_encoder = UMT5EncoderModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16
)

pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    text_encoder=text_encoder,
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to(device)

# Group offloading: keep only the layers currently executing on the GPU and
# stream the rest between CPU and GPU memory.
onload_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
offload_device = torch.device("cpu")
pipeline.transformer.enable_group_offload(onload_device, offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
pipeline.vae.enable_group_offload(onload_device, offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
apply_group_offloading(pipeline.text_encoder, onload_device, offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
# Compile the transformer in place with torch.compile.
pipeline.transformer.compile()


# --- HELPER FUNCTIONS ---
def get_duration(prompt, negative_prompt, duration_seconds, guidance_scale, guidance_scale_2, steps, seed, randomize_seed):
    # Rough estimate (in seconds) of the GPU time to request from ZeroGPU.
    return int(10 + (steps * 2) + (duration_seconds * 5))


@spaces.GPU(duration=get_duration)
def generate_video(
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=3,
    steps=4,
    seed=42,
    randomize_seed=False,
):
    num_frames = int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # Device placement is handled by the group-offloading hooks set up above,
    # so the pipeline is not moved here.
    generator = torch.Generator(device=device).manual_seed(current_seed)
    try:
        with torch.no_grad():
            frames = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=LANDSCAPE_HEIGHT,
                width=LANDSCAPE_WIDTH,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                guidance_scale_2=float(guidance_scale_2),  # low-noise-stage guidance; only used by checkpoints that load a second transformer
                num_inference_steps=int(steps),
                generator=generator,
            ).frames[0]  # first (and only) video in the batch
            # The pipeline normally returns float frames in [0, 1], which
            # export_to_video handles directly; convert tensors to PIL images if necessary.
            if isinstance(frames[0], torch.Tensor):
                frames = [
                    Image.fromarray((frame.cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8))
                    for frame in frames
                ]
        # Write the frames to a temporary MP4 file for the Gradio video component.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        export_to_video(frames, tmp_file.name, fps=FIXED_FPS)
        # Clean up GPU memory
        if device == "cuda":
            gc.collect()
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        return tmp_file.name, current_seed
    except Exception as e:
        if device == "cuda":
            gc.collect()
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        raise gr.Error(f"Video generation failed: {e}")


# --- GRADIO UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 Text-to-Video Generator 🎬")
    gr.Markdown("Generate videos in a few inference steps with the Wan 2.1 T2V 14B model, using 4-bit quantization and GPU group offloading.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v)
            duration_input = gr.Slider(MIN_DURATION, MAX_DURATION, value=MAX_DURATION, step=0.1, label="Duration (seconds)")
            with gr.Accordion("Advanced Settings", open=False):
                negative_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                steps_input = gr.Slider(1, 30, value=4, step=1, label="Inference Steps")
                guidance_input = gr.Slider(0.0, 10.0, value=1, step=0.5, label="Guidance Scale - High Noise Stage")
                guidance2_input = gr.Slider(0.0, 10.0, value=3, step=0.5, label="Guidance Scale 2 - Low Noise Stage")
                seed_input = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)

    # Input order must match the signature of generate_video (and get_duration).
    ui_inputs = [prompt_input, negative_input, duration_input, guidance_input, guidance2_input, steps_input, seed_input, randomize_seed]
    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

    gr.Examples(
        examples=[
            ["POV selfie video, white cat with sunglasses standing on surfboard, tropical beach."],
            ["Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."],
            ["Cinematic shot of a boat sailing on a calm sea at sunset."],
            ["Drone footage flying over a futuristic city with flying cars."],
        ],
        inputs=[prompt_input],
        outputs=[video_output, seed_input],
        fn=generate_video,
        cache_examples="lazy",
    )

if __name__ == "__main__":
    demo.queue().launch()