import gradio as gr
import torch
import spaces
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video
# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
# Configure quantization (bitsandbytes 4-bit)
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16
    },
    components_to_quantize=["transformer", "text_encoder"]
)
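# NF4 weight storage with bfloat16 compute substantially reduces the memory
# needed for the 14B transformer and the text encoder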
# Load pipeline with quantization
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16
)  # device placement is left to enable_model_cpu_offload() below, so no explicit .to("cuda")
# Optimize memory and performance
pipe.enable_model_cpu_offload()
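# Raise the recompile limit and allow capturing ops with dynamic output shapes
# so torch.compile copes with the varying frame counts requested at inference time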
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
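# Compile the quantized transformer to speed up the repeated denoising steps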
pipe.transformer.compile()
# Duration function: @spaces.GPU calls this with the same arguments as the
# decorated function and uses the return value as the GPU time budget in seconds
def get_duration(prompt, seed, steps, duration_seconds):
    if steps > 4 and duration_seconds > 2:
        return 90
    elif steps > 4 or duration_seconds > 2:
        return 75
    else:
        return 60
# Gradio inference function with GPU duration control
@spaces.GPU(duration=get_duration)
def generate_video(prompt, seed, steps, duration_seconds):
    generator = torch.manual_seed(int(seed)) if seed else None
    # Derive the frame count from the requested duration
    fps = 8
    num_frames = int(duration_seconds * fps) if duration_seconds else 16
    video_frames = pipe(
        prompt=prompt,
        num_frames=num_frames,
        generator=generator,
        num_inference_steps=int(steps)
    ).frames[0]  # Take the first (and only) video in the batch
    # Save as MP4 so the gr.Video component can play it back
    out_path = "output.mp4"
    export_to_video(video_frames, out_path, fps=fps)
    return out_path
# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Dynamic Duration)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            seed = gr.Number(value=42, label="Seed (optional)")
            steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            duration_seconds = gr.Slider(1, 10, value=2, step=1, label="Video Duration (seconds)")
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
    run_btn.click(fn=generate_video, inputs=[prompt, seed, steps, duration_seconds], outputs=output_video)
# Launch demo
demo.launch()