# Hugging Face Space (status: Paused)
import tempfile

import spaces  # kept first, as in the original; ZeroGPU expects it imported before CUDA is initialized
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video
# Hub repository of the Wan 2.1 text-to-video 14B checkpoint (diffusers layout).
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# bitsandbytes 4-bit (NF4) quantization, applied only to the components named
# in `components_to_quantize`; the remaining pipeline components are loaded
# unquantized. bfloat16 is used as the compute dtype for the 4-bit layers.
quant_config = PipelineQuantizationConfig(
    components_to_quantize=["transformer", "text_encoder"],
    quant_backend="bitsandbytes_4bit",
    quant_kwargs=dict(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
# Load the pipeline; the components listed in `quant_config` are 4-bit.
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Model CPU offload moves each sub-model onto the GPU only while it runs and
# manages device placement itself. The pipeline is therefore deliberately NOT
# moved with `.to("cuda")` first: the offload call moves it back to the CPU
# anyway, so an explicit move is redundant and only raises peak GPU memory
# during startup (diffusers docs advise against combining the two).
pipe.enable_model_cpu_offload()

# torch.compile settings: allow many recompilations and let dynamo capture ops
# with data-dependent output shapes instead of graph-breaking on them.
# NOTE(review): these mirror the diffusers quantization + torch.compile
# examples — confirm they are still needed with the installed torch version.
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
pipe.transformer.compile()
# Gradio inference function | |
def generate_video(prompt, seed): | |
generator = torch.manual_seed(seed) if seed else None | |
# Force ~2 second video (e.g., fps=8, frames=16) | |
num_frames = 16 | |
fps = 8 | |
video_frames = pipe( | |
prompt=prompt, | |
num_frames=num_frames, | |
generator=generator | |
).frames[0] # Take first video | |
# Save as GIF for Gradio preview | |
import imageio | |
out_path = "output.gif" | |
imageio.mimsave(out_path, video_frames, fps=fps) | |
return out_path | |
# Build the Gradio UI: a prompt and seed on the left, the generated video on
# the right, wired to `generate_video` through the button.
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (2 sec duration, 4-bit quantized)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            # precision=0 makes the component return an int; seeds are integers,
            # and the default would hand the callback a float such as 42.0.
            seed = gr.Number(value=42, label="Seed (optional)", precision=0)
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
    # Event listeners must be attached inside the Blocks context.
    run_btn.click(fn=generate_video, inputs=[prompt, seed], outputs=output_video)

# Launch only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()