import os
import random
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.export_utils import export_to_video
# Native landscape resolution for Wan2.1 T2V, kept for reference.
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max

# The model supports 8-81 frames at a fixed 16 fps.
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
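# Worked example: MIN_DURATION = round(8 / 16, 1) = 0.5 s and
# MAX_DURATION = round(81 / 16, 1) = 5.1 s, so requested clip lengths are
# effectively limited to the 0.5-5.1 second range.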
# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Configure quantization (bitsandbytes 4-bit)
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16
    },
    components_to_quantize=["transformer", "text_encoder"]
)
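# Note (rough estimate): NF4 stores weights in 4 bits, so the quantized
# transformer and text encoder need about a quarter of their bf16 weight
# memory; exact savings depend on which layers bitsandbytes keeps in
# higher precision.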

# Load pipeline with quantization
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16
)
# Offload idle components to CPU so the 14B model fits in GPU memory.
pipe.enable_model_cpu_offload()
# Relax dynamo limits so varying shapes (height, width, num_frames) do not
# hit the recompile cap or fail on dynamic output shapes.
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# Estimate the ZeroGPU time budget (seconds) from the requested step count
# and clip length; it receives the same arguments as generate_video.
def get_duration(prompt, height, width, negative_prompt, duration_seconds,
                 guidance_scale, steps, seed, randomize_seed, progress):
    # Longer clips get a larger per-step allowance.
    if duration_seconds <= 2.5:
        return steps * 18
    else:
        return steps * 25
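# Example: steps=20 with duration_seconds=2.0 budgets 20 * 18 = 360 s of GPU
# time; a 4-second clip at the same step count gets 20 * 25 = 500 s.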

# ZeroGPU entry point: Spaces calls get_duration with the same arguments to
# size the GPU allocation before inference starts.
@spaces.GPU(duration=get_duration)
def generate_video(prompt, height, width, negative_prompt, duration_seconds,
                   guidance_scale, steps, seed, randomize_seed,
                   progress=gr.Progress(track_tqdm=True)):
    
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)),
                         MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
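    # Example: duration_seconds=2.0 maps to round(2.0 * 16) = 32 frames,
    # which already lies inside the [8, 81] model range.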
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    output_frames_list = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0]

    # Write the frames to an MP4 in a fresh temp directory.
    filename = "t2v_output.mp4"
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, filename)
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

    print(f"✅ Video saved to: {video_path}")
    download_label = f"📥 Download: {filename}"
    return video_path, current_seed, gr.File(value=video_path, visible=True, label=download_label)
  

# Build Gradio UI with all parameters
with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Smart Duration)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=3, value="")
            height_input = gr.Slider(256, 1024, step=8, value=512, label="Height")
            width_input = gr.Slider(256, 1024, step=8, value=512, label="Width")
            duration_input = gr.Slider(MIN_DURATION, MAX_DURATION, value=2, step=0.1, label="Duration (seconds)")
            steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            guidance_scale_input = gr.Slider(0.0, 20.0, step=0.5, value=7.5, label="Guidance Scale")
            seed_input = gr.Number(value=42, label="Seed (optional)")
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            download_output = gr.File(label="Download", visible=False)

    ui_inputs = [prompt_input, height_input, width_input, negative_prompt_input,
                 duration_input, guidance_scale_input, steps_input,
                 seed_input, randomize_seed_checkbox]
    # generate_video returns (video_path, seed, file); wiring all three outputs
    # surfaces the seed actually used and reveals the download link.
    run_btn.click(fn=generate_video, inputs=ui_inputs,
                  outputs=[output_video, seed_input, download_output])

# Launch demo
demo.launch()
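
# To run locally (assumption: a CUDA GPU with enough memory for the 4-bit
# 14B checkpoint), something like `pip install gradio spaces diffusers
# transformers accelerate bitsandbytes` followed by `python app.py` should
# suffice; the exact dependency set may differ per diffusers version.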