rahul7star committed
Commit 52b1dc0 · verified · 1 Parent(s): a4fe0cd

Update app.py

Files changed (1)
  1. app.py +28 -13
app.py CHANGED
@@ -1,10 +1,9 @@
-
-import spaces
-
 import gradio as gr
 import torch
+import spaces
 from diffusers import DiffusionPipeline
 from diffusers.quantizers import PipelineQuantizationConfig
+import imageio
 
 # Checkpoint ID
 ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
@@ -27,45 +26,61 @@ pipe = DiffusionPipeline.from_pretrained(
     torch_dtype=torch.bfloat16
 ).to("cuda")
 
-# Optimize memory
+# Optimize memory and performance
 pipe.enable_model_cpu_offload()
 torch._dynamo.config.recompile_limit = 1000
 torch._dynamo.config.capture_dynamic_output_shape_ops = True
 pipe.transformer.compile()
 
-# Gradio inference function
-@spaces.GPU(duration=20)
-def generate_video(prompt, seed):
+# Duration function
+
+def get_duration(prompt, height, width,
+                 negative_prompt, duration_seconds,
+                 guidance_scale, steps,
+                 seed, randomize_seed,
+                 progress):
+    if steps > 4 and duration_seconds > 2:
+        return 90
+    elif steps > 4 or duration_seconds > 2:
+        return 75
+    else:
+        return 60
+
+# Gradio inference function with GPU duration control
+@spaces.GPU(duration=get_duration)
+def generate_video(prompt, seed, steps, duration_seconds):
     generator = torch.manual_seed(seed) if seed else None
 
-    # Force ~2 second video (e.g., fps=8, frames=16)
-    num_frames = 16
+    # Force duration-based frames
     fps = 8
+    num_frames = duration_seconds * fps if duration_seconds else 16
 
     video_frames = pipe(
         prompt=prompt,
         num_frames=num_frames,
-        generator=generator
+        generator=generator,
+        num_inference_steps=steps
     ).frames[0]  # Take first video
 
     # Save as GIF for Gradio preview
-    import imageio
     out_path = "output.gif"
     imageio.mimsave(out_path, video_frames, fps=fps)
     return out_path
 
 # Build Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (2 sec duration, 4-bit quantized)")
+    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Dynamic Duration)")
     with gr.Row():
         with gr.Column():
             prompt = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
             seed = gr.Number(value=42, label="Seed (optional)")
+            steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
+            duration_seconds = gr.Slider(1, 10, value=2, step=1, label="Video Duration (seconds)")
             run_btn = gr.Button("Generate Video")
         with gr.Column():
             output_video = gr.Video(label="Generated Video")
 
-    run_btn.click(fn=generate_video, inputs=[prompt, seed], outputs=output_video)
+    run_btn.click(fn=generate_video, inputs=[prompt, seed, steps, duration_seconds], outputs=output_video)
 
 # Launch demo
 demo.launch()
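
Note on the dynamic GPU duration: in ZeroGPU's documented pattern, passing a callable as `duration` to `@spaces.GPU` makes `spaces` call it with the same arguments as the decorated function and use the returned number as the allocation in seconds. The committed `get_duration` keeps a parameter list (`height`, `width`, `negative_prompt`, `guidance_scale`, `randomize_seed`, `progress`) that does not match `generate_video(prompt, seed, steps, duration_seconds)`. A minimal sketch of a matching pair with the same tiered logic, assuming that argument pass-through behavior, would look like:

import spaces

# Sketch, not the committed code: the duration callable must accept the
# same arguments as the function it gates and return seconds as a number.
def get_duration(prompt, seed, steps, duration_seconds):
    if steps > 4 and duration_seconds > 2:
        return 90   # more steps AND a longer clip: largest allocation
    elif steps > 4 or duration_seconds > 2:
        return 75   # only one of the two is heavy
    return 60       # light default

@spaces.GPU(duration=get_duration)
def generate_video(prompt, seed, steps, duration_seconds):
    ...  # body as in the commit

One more point worth checking against the model card: `num_frames = duration_seconds * fps` with `fps = 8` yields counts like 16 or 80, while the Wan2.1 Diffusers pipeline defaults to 81 frames (counts of the form 4k + 1); if the pipeline rejects or adjusts other values, rounding the computed count to the nearest 4k + 1 would keep the duration slider honest.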