# Wan22-Light / app.py
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video
import imageio
import numpy as np
import random
import tempfile
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
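# With FIXED_FPS = 16: MIN_DURATION = round(8 / 16, 1) = 0.5 s and
# MAX_DURATION = round(81 / 16, 1) = 5.1 s of video.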
# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
# Configure quantization (bitsandbytes 4-bit)
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder"],
)
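# Only the transformer and text encoder are quantized to 4-bit NF4; the remaining
# components (e.g. the VAE) are loaded in bfloat16 via torch_dtype below.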
# Load pipeline with quantization
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
# Offload idle components to CPU to reduce peak VRAM usage
pipe.enable_model_cpu_offload()
# Relax dynamo limits so varying resolutions/frame counts do not trip
# the recompilation cap or fail on data-dependent output shapes
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
# Dynamic ZeroGPU duration: spaces.GPU calls this with the same arguments as the
# decorated function and uses the returned value as the GPU time budget (seconds)
def get_duration(prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress):
    # Scale the GPU time budget with the step count and requested clip length
    if duration_seconds <= 2.5:
        return steps * 18
    else:
        return steps * 25
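# Example: steps=20 with duration_seconds=2.0 requests 20 * 18 = 360 s of GPU time,
# while duration_seconds=4.0 requests 20 * 25 = 500 s.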
# Gradio inference function; the spaces decorator requests a GPU for the
# dynamically estimated duration
@spaces.GPU(duration=get_duration)
def generate_video(prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
    # Convert the requested duration to a frame count within the model's limits
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    output_frames_list = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0]
    # Write the frames to a temporary MP4 and return its path plus the seed used
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
    return video_path, current_seed
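# Example: duration_seconds=2.0 at 16 fps maps to 32 frames; a 10-second request
# (160 frames) is clipped to the model maximum of 81 frames (about 5.1 s).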
# Build Gradio UI with all parameters
with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Smart Duration)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=3, value="")
            height_input = gr.Slider(256, 1024, step=8, value=512, label="Height")
            width_input = gr.Slider(256, 1024, step=8, value=512, label="Width")
            # Duration range follows the model's frame limits at FIXED_FPS
            duration_input = gr.Slider(MIN_DURATION, MAX_DURATION, value=2, step=0.1, label="Duration (seconds)")
            steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            guidance_scale_input = gr.Slider(0.0, 20.0, step=0.5, value=7.5, label="Guidance Scale")
            seed_input = gr.Number(value=42, label="Seed (optional)")
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
    # Input order must match the generate_video signature (progress is injected by Gradio)
    ui_inputs = [prompt_input, height_input, width_input, negative_prompt_input, duration_input, guidance_scale_input, steps_input, seed_input, randomize_seed_checkbox]
    # generate_video returns (video_path, seed), so expose the seed actually used
    run_btn.click(fn=generate_video, inputs=ui_inputs, outputs=[output_video, seed_input])
# Launch demo
demo.launch()