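Below is the full `app.py` for the Space: it loads Wan2.1-T2V-14B through diffusers with bitsandbytes 4-bit (NF4) quantization on the transformer and text encoder, enables model CPU offload plus `torch.compile` on the transformer, and exposes a small Gradio UI that renders a text prompt to a short video.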
```python
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video

# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Configure bitsandbytes 4-bit (NF4) quantization for the two largest components
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder"],
)

# Load the pipeline with quantization. Don't call .to("cuda") here:
# device placement is handled by enable_model_cpu_offload() below, and
# moving bitsandbytes-quantized components manually can fail.
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Optimize memory and performance
pipe.enable_model_cpu_offload()
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
pipe.transformer.compile()

# GPU-time budget (seconds) for ZeroGPU. Currently unused because the
# @spaces.GPU decorator is disabled below; if re-enabled, this function
# must take the same arguments as the decorated function.
def get_duration(prompt, seed, steps, duration_seconds):
    if steps > 4 and duration_seconds > 2:
        return 90
    elif steps > 4 or duration_seconds > 2:
        return 75
    else:
        return 60

# Gradio inference function (no @spaces.GPU decorator, to avoid the
# progress ContextVar error)
def generate_video(prompt, seed, steps, duration_seconds):
    # Seed 0 is valid, so test against None rather than truthiness
    generator = torch.Generator().manual_seed(int(seed)) if seed is not None else None
    fps = 8
    # Wan expects num_frames of the form 4k + 1; 8 * seconds + 1 satisfies that
    num_frames = int(duration_seconds) * fps + 1 if duration_seconds else 17
    video_frames = pipe(
        prompt=prompt,
        num_frames=num_frames,
        generator=generator,
        num_inference_steps=int(steps),
    ).frames[0]
    # Write an MP4 that gr.Video can play (GIFs are not reliably supported)
    out_path = "output.mp4"
    export_to_video(video_frames, out_path, fps=fps)
    return out_path

# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Dynamic Duration)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            seed_input = gr.Number(value=42, label="Seed (optional)")
            steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            duration_input = gr.Slider(1, 10, value=2, step=1, label="Video Duration (seconds)")
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
    run_btn.click(
        fn=generate_video,
        inputs=[prompt_input, seed_input, steps_input, duration_input],
        outputs=output_video,
    )

# Launch demo
demo.launch()
```
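If the Space runs on ZeroGPU, the dynamic duration logic can be wired back in once the progress ContextVar issue is resolved. A minimal sketch, assuming the `spaces` package that Hugging Face injects into ZeroGPU Spaces, whose `GPU` decorator accepts a callable `duration` that is invoked with the same arguments as the wrapped function:

```python
import spaces

# Sketch only: decorate the existing generate_video so each request
# reserves a GPU-time budget computed by get_duration above.
@spaces.GPU(duration=get_duration)
def generate_video(prompt, seed, steps, duration_seconds):
    ...  # same body as in the app above
```

With this in place, short low-step previews release the GPU after 60 seconds, while longer, higher-step renders reserve up to 90.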