import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
import imageio
import numpy as np
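
# Assumed runtime dependencies (not pinned in this file): diffusers, transformers,
# accelerate, bitsandbytes, gradio, spaces, numpy, and imageio with the
# imageio-ffmpeg plugin for the MP4 export below.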


# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Configure quantization (bitsandbytes 4-bit)
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16
    },
    components_to_quantize=["transformer", "text_encoder"]
)
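
# NF4 stores weights in 4 bits and dequantizes to bfloat16 at compute time,
# roughly quartering the memory footprint of the 14B transformer and the text
# encoder (a rule of thumb, not measured here).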

# Load pipeline with quantization
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
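# torch.compile tuning (assumption: targeting ZeroGPU's ahead-of-time compile):
# raise the recompile limit so varied input shapes don't disable dynamo, and
# allow capture of ops with dynamic output shapes.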
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# Estimate the ZeroGPU time budget (in seconds) from the UI params. `spaces`
# calls this with the same arguments as the decorated function, so the
# signatures must match.
def get_duration(prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress=None):
    # Budget roughly 18 s per denoising step for short clips, 25 s for longer ones
    if duration_seconds <= 2.5:
        return steps * 18
    return steps * 25

# Gradio inference function with spaces GPU decorator
@spaces.GPU(duration=get_duration)
def generate_video(prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
    # Honor the "Randomize seed" checkbox; otherwise use the user-supplied seed
    if randomize_seed:
        seed = int(np.random.randint(0, 2**31 - 1))
    generator = torch.Generator().manual_seed(int(seed))

    fps = 8
    # Wan's VAE compresses time 4x, so the frame count must be of the form 4k + 1
    num_frames = int(duration_seconds * fps) if duration_seconds else 16
    num_frames = max(5, (num_frames // 4) * 4 + 1)

    video_frames = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        generator=generator
    ).frames[0]

    # The pipeline returns float frames in [0, 1]; convert to uint8 for encoding
    processed_frames = [
        (np.clip(frame * 255, 0, 255).astype(np.uint8) if frame.dtype in (np.float32, np.float64) else frame)
        for frame in video_frames
    ]

    # Write MP4 rather than GIF, since gr.Video does not render GIFs
    # (assumes imageio-ffmpeg is available)
    out_path = "output.mp4"
    imageio.mimsave(out_path, processed_frames, fps=fps)
    return out_path

# Build Gradio UI with all parameters
with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Smart Duration)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=3, value="")
            # Wan expects spatial dims divisible by 16, so step in multiples of 16
            height_input = gr.Slider(256, 1024, step=16, value=512, label="Height")
            width_input = gr.Slider(256, 1024, step=16, value=512, label="Width")
            duration_input = gr.Slider(1, 10, value=2, step=0.1, label="Duration (seconds)")
            steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            guidance_scale_input = gr.Slider(0.0, 20.0, step=0.5, value=7.5, label="Guidance Scale")
            seed_input = gr.Number(value=42, precision=0, label="Seed (used when randomization is off)")
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    ui_inputs = [prompt_input, height_input, width_input, negative_prompt_input, duration_input, guidance_scale_input, steps_input, seed_input, randomize_seed_checkbox]
    run_btn.click(fn=generate_video, inputs=ui_inputs, outputs=output_video)
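
# NOTE (assumption): the `spaces` decorators are no-ops outside Hugging Face
# Spaces, so this script should also run locally given sufficient GPU memory.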

# Launch demo
demo.launch()