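# Runtime dependency install (a common Hugging Face Spaces pattern): pin a torch nightly and
# pull in the `spaces` client before anything below imports torch.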
import os

os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
import spaces
import gradio as gr
import torch
import numpy as np
import tempfile
import random
import gc
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import UMT5EncoderModel
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
# --- DEFAULT PROMPTS ---
default_prompt_t2v = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
# Wan's recommended negative prompt is in Chinese; roughly: "garish colors, overexposed, static,
# blurry details, subtitles, style, artwork, painting, frame, still, overall gray, worst quality,
# low quality, JPEG artifacts, ugly, mutilated, extra fingers, badly drawn hands, badly drawn face,
# deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
# three legs, crowded background, walking backwards".
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
# --- CONSTANTS ---
FIXED_FPS = 16
MIN_FRAMES_MODEL = 10
MAX_FRAMES_MODEL = 200
MIN_DURATION = MIN_FRAMES_MODEL / FIXED_FPS
MAX_DURATION = MAX_FRAMES_MODEL / FIXED_FPS
MAX_SEED = 2147483647
LANDSCAPE_HEIGHT = 512  # output is square despite the "landscape" naming
LANDSCAPE_WIDTH = 512
# --- SETUP PIPELINE ---
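# Raise torch.compile's recompilation cache limit and allow ops with dynamic output shapes,
# since the requested frame count varies between calls.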
torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer"],  # the text encoder is loaded separately below in bf16
)
# Load the text encoder in plain bf16 so it can be group-offloaded (see below).
text_encoder = UMT5EncoderModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16
)
pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    text_encoder=text_encoder,  # reuse the bf16 encoder instead of loading a second copy
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
# No .to(device) here: group offloading below manages device placement itself.
# Group offloading: keep only the layers currently executing on the GPU and stream the rest
# from CPU RAM; leaf_level granularity plus CUDA streams hides most of the transfer latency.
onload_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
offload_device = torch.device("cpu")
pipeline.transformer.enable_group_offload(onload_device, offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
pipeline.vae.enable_group_offload(onload_device, offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
apply_group_offloading(pipeline.text_encoder, onload_device, offload_device, offload_type="leaf_level", use_stream=True, non_blocking=True)
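# Compiling the transformer trades a slow first call (compile warm-up) for faster denoising steps afterwards.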
pipeline.transformer.compile()
# --- HELPER FUNCTIONS ---
def get_duration(prompt, negative_prompt, duration_seconds, guidance_scale, guidance_scale_2, steps, seed, randomize_seed):
    # Rough estimate of the GPU seconds this request needs (consumed by the decorator below).
    return 10 + (steps * 2) + (duration_seconds * 5)
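# If this runs on a ZeroGPU Space, `spaces.GPU` also accepts a callable that receives the same
# arguments as the wrapped function and returns the GPU time to reserve; `get_duration` fills
# that role here (otherwise the imported-but-unused `spaces` module and helper serve no purpose).
@spaces.GPU(duration=get_duration)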
def generate_video(
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=3,
    steps=4,
    seed=42,
    randomize_seed=False,
):
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # Note: no pipeline.to(device) here -- group-offloaded modules must not be moved wholesale,
    # and diffusers raises if you try.
    generator = torch.Generator(device=device).manual_seed(current_seed)
    try:
        with torch.no_grad():
            frames = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=LANDSCAPE_HEIGHT,
                width=LANDSCAPE_WIDTH,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                guidance_scale_2=float(guidance_scale_2),  # only consumed by two-stage (Wan 2.2-style) checkpoints
                num_inference_steps=int(steps),
                generator=generator,
            ).frames[0]  # .frames is batched; take the first (and only) video
        # Convert tensors to PIL images if necessary, assuming float frames in [0, 1]
        if isinstance(frames[0], torch.Tensor):
            frames = [Image.fromarray((frame.cpu().numpy() * 255).round().astype(np.uint8)) for frame in frames]
        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        export_to_video(frames, tmp_file.name, fps=FIXED_FPS)
        # Clean up GPU memory
        if device == "cuda":
            gc.collect()
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        return tmp_file.name, current_seed
    except Exception as e:
        # Free GPU memory on failure too, before surfacing the error in the UI.
        if device == "cuda":
            gc.collect()
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        raise gr.Error(f"Video generation failed: {e}")
# --- GRADIO UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 Text-to-Video Generator 🎬")
    gr.Markdown("Generate videos in a few steps using the Wan 2.1 T2V 14B model with 4-bit quantization and GPU offloading.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v)
            duration_input = gr.Slider(MIN_DURATION, MAX_DURATION, value=MAX_DURATION, step=0.1, label="Duration (seconds)")
            with gr.Accordion("Advanced Settings", open=False):
                negative_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                steps_input = gr.Slider(1, 30, value=4, step=1, label="Inference Steps")
                guidance_input = gr.Slider(0.0, 10.0, value=1, step=0.5, label="Guidance Scale - High Noise Stage")
                guidance2_input = gr.Slider(0.0, 10.0, value=3, step=0.5, label="Guidance Scale 2 - Low Noise Stage")
                seed_input = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
    ui_inputs = [prompt_input, negative_input, duration_input, guidance_input, guidance2_input, steps_input, seed_input, randomize_seed]
    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
    gr.Examples(
        examples=[
            ["POV selfie video, white cat with sunglasses standing on surfboard, tropical beach."],
            ["Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."],
            ["Cinematic shot of a boat sailing on a calm sea at sunset."],
            ["Drone footage flying over a futuristic city with flying cars."],
        ],
        inputs=[prompt_input],
        outputs=[video_output, seed_input],
        fn=generate_video,
        cache_examples="lazy",
    )
if __name__ == "__main__":
    demo.queue().launch()