rahul7star commited on
Commit
df00973
·
verified ·
1 Parent(s): 16572f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -1,10 +1,9 @@
1
- import spaces
2
  import gradio as gr
3
  import torch
4
  from diffusers import DiffusionPipeline
5
  from diffusers.quantizers import PipelineQuantizationConfig
6
  import imageio
7
-
8
 
9
  # Checkpoint ID
10
  ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
@@ -20,25 +19,23 @@ quant_config = PipelineQuantizationConfig(
20
  components_to_quantize=["transformer", "text_encoder"]
21
  )
22
 
23
- # Load pipeline with quantization
24
  pipe = DiffusionPipeline.from_pretrained(
25
  ckpt_id,
26
  quantization_config=quant_config,
27
  torch_dtype=torch.bfloat16
28
- ).to("cuda")
29
 
30
- # Optimize memory and performance
31
  pipe.enable_model_cpu_offload()
32
  torch._dynamo.config.recompile_limit = 1000
33
  torch._dynamo.config.capture_dynamic_output_shape_ops = True
34
- pipe.transformer.compile()
35
 
36
- # Duration function using progress context safely
37
  def get_duration(prompt, height, width,
38
  negative_prompt, duration_seconds,
39
  guidance_scale, steps,
40
- seed, randomize_seed,
41
- progress=None): # Added default None for progress
42
  if steps > 4 and duration_seconds > 2:
43
  return 90
44
  elif steps > 4 or duration_seconds > 2:
@@ -46,13 +43,14 @@ def get_duration(prompt, height, width,
46
  else:
47
  return 60
48
 
49
- # Gradio inference function with safe GPU duration
50
- @spaces.GPU(duration=50)
51
- def generate_video(prompt, seed, steps, duration_seconds, progress=gr.Progress(track_tqdm=True)):
52
  generator = torch.manual_seed(seed) if seed else None
53
  fps = 8
54
  num_frames = duration_seconds * fps if duration_seconds else 16
55
 
 
56
  video_frames = pipe(
57
  prompt=prompt,
58
  num_frames=num_frames,
 
 
1
  import gradio as gr
2
  import torch
3
  from diffusers import DiffusionPipeline
4
  from diffusers.quantizers import PipelineQuantizationConfig
5
  import imageio
6
+ import spaces
7
 
8
  # Checkpoint ID
9
  ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 
19
  components_to_quantize=["transformer", "text_encoder"]
20
  )
21
 
22
+ # Load pipeline with quantization, without immediately moving to CUDA to avoid ConstantVariable error
23
  pipe = DiffusionPipeline.from_pretrained(
24
  ckpt_id,
25
  quantization_config=quant_config,
26
  torch_dtype=torch.bfloat16
27
+ )
28
 
29
+ # Enable CPU offload and compile after offload
30
  pipe.enable_model_cpu_offload()
31
  torch._dynamo.config.recompile_limit = 1000
32
  torch._dynamo.config.capture_dynamic_output_shape_ops = True
 
33
 
34
+ # Duration function
35
  def get_duration(prompt, height, width,
36
  negative_prompt, duration_seconds,
37
  guidance_scale, steps,
38
+ seed, randomize_seed):
 
39
  if steps > 4 and duration_seconds > 2:
40
  return 90
41
  elif steps > 4 or duration_seconds > 2:
 
43
  else:
44
  return 60
45
 
46
+ # Gradio inference function with spaces GPU decorator
47
+ @spaces.GPU(duration=30)
48
+ def generate_video(prompt, seed, steps, duration_seconds):
49
  generator = torch.manual_seed(seed) if seed else None
50
  fps = 8
51
  num_frames = duration_seconds * fps if duration_seconds else 16
52
 
53
+ # Run pipeline on default device with automatic offload
54
  video_frames = pipe(
55
  prompt=prompt,
56
  num_frames=num_frames,