import gradio as gr
import torch
import spaces
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video
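# Likely runtime dependencies (not pinned in the original): gradio, spaces, torch,
# diffusers, transformers, accelerate, bitsandbytes (for the 4-bit backend), and
# imageio/imageio-ffmpeg for MP4 export via export_to_video.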

# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Configure quantization (bitsandbytes 4-bit)
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16
    },
    components_to_quantize=["transformer", "text_encoder"]
)
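# NF4 stores the quantized weights in 4 bits (roughly a 4x size reduction versus
# bfloat16 for the transformer and text encoder); matmuls still run in bfloat16.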

# Load pipeline with quantization
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16
)

# Offload idle components to CPU to reduce peak VRAM usage; offloading manages
# device placement itself, so the pipeline is not moved to CUDA explicitly here
pipe.enable_model_cpu_offload()

# Compile the transformer for faster inference; relax dynamo limits so varying
# frame counts and shapes do not exhaust the recompilation budget
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
pipe.transformer.compile()

# GPU-duration estimator for spaces.GPU: it is called with the same arguments as
# the decorated function, so its signature must match generate_video below
def get_duration(prompt, seed, steps, duration_seconds):
    if steps > 4 and duration_seconds > 2:
        return 90
    elif steps > 4 or duration_seconds > 2:
        return 75
    else:
        return 60

# Gradio inference function with GPU duration control
@spaces.GPU(duration=get_duration)
def generate_video(prompt, seed, steps, duration_seconds):
    # Seeded generator for reproducible results (None -> random seed)
    generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed is not None else None

    # Derive the frame count from the requested duration
    fps = 8
    num_frames = int(duration_seconds * fps) if duration_seconds else 16

    video_frames = pipe(
        prompt=prompt,
        num_frames=num_frames,
        generator=generator,
        num_inference_steps=int(steps)
    ).frames[0]  # First (and only) video in the batch

    # Export as MP4 so the Gradio Video component can play it back directly
    out_path = "output.mp4"
    export_to_video(video_frames, out_path, fps=fps)
    return out_path
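
# Example with the UI defaults below (steps=20, duration_seconds=2): get_duration
# grants a 75-second GPU slot and generate_video renders 16 frames (~2 s at 8 fps).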

# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Dynamic Duration)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            seed = gr.Number(value=42, label="Seed (optional)")
            steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            duration_seconds = gr.Slider(1, 10, value=2, step=1, label="Video Duration (seconds)")
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    run_btn.click(fn=generate_video, inputs=[prompt, seed, steps, duration_seconds], outputs=output_video)

# Launch demo
demo.launch()