rahul7star committed
Commit fc2df73 · verified · 1 Parent(s): 9fe0c69

Update app.py

Files changed (1): app.py +26 -47
app.py CHANGED
@@ -1,73 +1,51 @@
 import spaces
 import gradio as gr
 import torch
+import numpy as np
+import os
+import tempfile
 from diffusers import DiffusionPipeline
 from diffusers.quantizers import PipelineQuantizationConfig
-import imageio
 from diffusers.utils.export_utils import export_to_video
-import gradio as gr
-import tempfile

-import os
-import re
-import json
-import random
-import tempfile
-import traceback
-from functools import partial
-import numpy as np
-from PIL import Image
-import random
-import numpy as np
-import random
-import gradio as gr
-import tempfile
-import numpy as np
-from PIL import Image
-import random
+# Constants
 LANDSCAPE_WIDTH = 832
 LANDSCAPE_HEIGHT = 480
 MAX_SEED = np.iinfo(np.int32).max
-
 FIXED_FPS = 16
 MIN_FRAMES_MODEL = 8
 MAX_FRAMES_MODEL = 81
 T2V_FIXED_FPS = 16
-MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
-MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
+MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
+MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
+
 # Checkpoint ID
 ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

-# Configure quantization (bitsandbytes 4-bit)
+# Quantization config
 quant_config = PipelineQuantizationConfig(
     quant_backend="bitsandbytes_4bit",
     quant_kwargs={
         "load_in_4bit": True,
         "bnb_4bit_quant_type": "nf4",
-        "bnb_4bit_compute_dtype": torch.bfloat16
+        "bnb_4bit_compute_dtype": torch.bfloat16,
     },
-    components_to_quantize=["transformer", "text_encoder"]
+    components_to_quantize=["transformer", "text_encoder"],
 )

-# Load pipeline with quantization
+# Load pipeline
 pipe = DiffusionPipeline.from_pretrained(
     ckpt_id,
     quantization_config=quant_config,
-    torch_dtype=torch.bfloat16
+    torch_dtype=torch.bfloat16,
 )
 pipe.enable_model_cpu_offload()
-torch._dynamo.config.recompile_limit = 1000
-torch._dynamo.config.capture_dynamic_output_shape_ops = True

-# Smart duration function using all UI params
+# Duration estimator
 def get_duration(prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress):
-    # Calculate dynamic duration based on steps and requested duration
-    if duration_seconds <= 2.5:
-        return steps * 18
-    else:
-        return steps * 25
+    return steps * 18 if duration_seconds <= 2.5 else steps * 25

-# Gradio inference function with spaces GPU decorator
+# Inference function
 @spaces.GPU(duration=get_duration)
 def generate_video(prompt, height, width, negative_prompt, duration_seconds,
                    guidance_scale, steps, seed, randomize_seed,
@@ -75,7 +53,7 @@ def generate_video(prompt, height, width, negative_prompt, duration_seconds,

     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)),
                          MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
-    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+    current_seed = np.random.randint(0, MAX_SEED) if randomize_seed else int(seed)

     output_frames_list = pipe(
         prompt=prompt,
@@ -85,20 +63,17 @@ def generate_video(prompt, height, width, negative_prompt, duration_seconds,
         num_frames=num_frames,
         guidance_scale=float(guidance_scale),
         num_inference_steps=int(steps),
-        generator=torch.Generator(device="cuda").manual_seed(current_seed),
+        generator=torch.manual_seed(current_seed),
     ).frames[0]

-    filename = f"t2v_aaa.mp4"
     temp_dir = tempfile.mkdtemp()
-    video_path = os.path.join(temp_dir, filename)
+    video_path = os.path.join(temp_dir, "t2v_output.mp4")
     export_to_video(output_frames_list, video_path, fps=T2V_FIXED_FPS)

     print(f"✅ Video saved to: {video_path}")
-    download_label = f"📥 Download: {filename}"
-    return video_path, current_seed, gr.File(value=video_path, visible=True, label=download_label)
-
+    return video_path  # Only return video

-# Build Gradio UI with all parameters
+# Gradio UI
 with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
     gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Smart Duration)")
     with gr.Row():
@@ -116,8 +91,12 @@ with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
         with gr.Column():
             output_video = gr.Video(label="Generated Video")

-    ui_inputs = [prompt_input, height_input, width_input, negative_prompt_input, duration_input, guidance_scale_input, steps_input, seed_input, randomize_seed_checkbox]
+    ui_inputs = [
+        prompt_input, height_input, width_input, negative_prompt_input,
+        duration_input, guidance_scale_input, steps_input, seed_input,
+        randomize_seed_checkbox
+    ]
     run_btn.click(fn=generate_video, inputs=ui_inputs, outputs=output_video)

-# Launch demo
+# Launch
 demo.launch()
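
For readers who want to try the pattern this commit lands on outside of the Space, here is a minimal standalone sketch (not part of the commit) of the same 4-bit NF4 load-and-generate flow. It assumes a CUDA machine with enough memory for the quantized 14B checkpoint and with diffusers and bitsandbytes installed; the prompt, guidance scale, step count, and seed below are placeholder values.

import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.export_utils import export_to_video

# Same bitsandbytes 4-bit (NF4) config as the new app.py
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder"],
)

pipe = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # keep idle components on CPU to reduce VRAM use

frames = pipe(
    prompt="a red panda surfing a small wave at sunset",  # placeholder prompt
    height=480,                        # LANDSCAPE_HEIGHT in app.py
    width=832,                         # LANDSCAPE_WIDTH in app.py
    num_frames=33,                     # roughly 2 s at the fixed 16 fps
    guidance_scale=5.0,                # placeholder value
    num_inference_steps=25,            # placeholder value
    generator=torch.manual_seed(42),   # CPU generator, as in the new code
).frames[0]

export_to_video(frames, "t2v_output.mp4", fps=16)

Note that torch.manual_seed(...) returns the default CPU generator; diffusers accepts CPU generators, which fits the CPU-offload setup used here.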