# Wan22-Light / app.py
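"""Wan2.1 T2V text-to-video Gradio Space.

Loads the Wan-AI/Wan2.1-T2V-14B-Diffusers checkpoint with 4-bit bitsandbytes
(NF4) quantization and serves a small demo that renders a ~2 second clip per
prompt.
"""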
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video
# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
# Configure quantization (bitsandbytes 4-bit)
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder"],
)
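# NF4 keeps the quantized weights in 4 bits (about a quarter of their bf16
# footprint) while matmuls still run in bfloat16 via bnb_4bit_compute_dtype.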
# Load pipeline with quantization
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Optimize memory: offload idle components to CPU. Offloading also moves each
# component to the GPU on demand, so an explicit .to("cuda") is unnecessary.
pipe.enable_model_cpu_offload()
# Let torch.compile tolerate varying shapes instead of hitting the recompile
# limit and silently falling back to eager mode.
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
# Compile the transformer: the first call pays a tracing cost, later calls are faster.
pipe.transformer.compile()
# Gradio inference function
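# The @spaces.GPU duration tells ZeroGPU roughly how long each call is expected
# to hold the GPU; generation should finish within that window.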
@spaces.GPU(duration=20)
def generate_video(prompt, seed):
    # Seed only when a value is provided; note that 0 is a valid seed.
    generator = torch.Generator().manual_seed(int(seed)) if seed is not None else None
    # ~2 second clip: 16 frames played back at 8 fps
    num_frames = 16
    fps = 8
    video_frames = pipe(
        prompt=prompt,
        num_frames=num_frames,
        generator=generator,
    ).frames[0]  # first (and only) video in the batch
    # Export as MP4 so the gr.Video component can actually play the result
    out_path = "output.mp4"
    export_to_video(video_frames, out_path, fps=fps)
    return out_path
# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (2 sec duration, 4-bit quantized)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            seed = gr.Number(value=42, label="Seed (optional)")
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
    run_btn.click(fn=generate_video, inputs=[prompt, seed], outputs=output_video)
# Launch demo
demo.launch()