rahul7star commited on
Commit
3148252
·
verified ·
1 Parent(s): 228a1e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -1,10 +1,10 @@
1
- import spaces
2
  import gradio as gr
3
  import torch
4
  from diffusers import DiffusionPipeline
5
  from diffusers.quantizers import PipelineQuantizationConfig
6
  import imageio
7
-
 
8
 
9
  # Checkpoint ID
10
  ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
@@ -33,9 +33,20 @@ torch._dynamo.config.recompile_limit = 1000
33
  torch._dynamo.config.capture_dynamic_output_shape_ops = True
34
 
35
  # Duration function
 
 
 
 
 
 
 
 
 
 
36
 
37
- @spaces.GPU(duration=100)
38
- def generate_video(prompt, seed, steps, duration_seconds,progress=gr.Progress(track_tqdm=True)):
 
39
  generator = torch.manual_seed(seed) if seed else None
40
  fps = 8
41
  num_frames = duration_seconds * fps if duration_seconds else 16
@@ -48,8 +59,14 @@ def generate_video(prompt, seed, steps, duration_seconds,progress=gr.Progress(tr
48
  num_inference_steps=steps
49
  ).frames[0]
50
 
 
 
 
 
 
 
51
  out_path = "output.gif"
52
- imageio.mimsave(out_path, video_frames, fps=fps)
53
  return out_path
54
 
55
  # Build Gradio UI
 
 
1
  import gradio as gr
2
  import torch
3
  from diffusers import DiffusionPipeline
4
  from diffusers.quantizers import PipelineQuantizationConfig
5
  import imageio
6
+ import numpy as np
7
+ import spaces
8
 
9
  # Checkpoint ID
10
  ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 
33
  torch._dynamo.config.capture_dynamic_output_shape_ops = True
34
 
35
  # Duration function
36
+ def get_duration(prompt, height, width,
37
+ negative_prompt, duration_seconds,
38
+ guidance_scale, steps,
39
+ seed, randomize_seed):
40
+ if steps > 4 and duration_seconds > 2:
41
+ return 90
42
+ elif steps > 4 or duration_seconds > 2:
43
+ return 75
44
+ else:
45
+ return 60
46
 
47
+ # Gradio inference function with spaces GPU decorator
48
+ @spaces.GPU(duration=90)
49
+ def generate_video(prompt, seed, steps, duration_seconds):
50
  generator = torch.manual_seed(seed) if seed else None
51
  fps = 8
52
  num_frames = duration_seconds * fps if duration_seconds else 16
 
59
  num_inference_steps=steps
60
  ).frames[0]
61
 
62
+ # Ensure frames are uint8 numpy arrays for imageio
63
+ processed_frames = [
64
+ (np.clip(frame * 255, 0, 255).astype(np.uint8) if frame.dtype in [np.float32, np.float64] else frame)
65
+ for frame in video_frames
66
+ ]
67
+
68
  out_path = "output.gif"
69
+ imageio.mimsave(out_path, processed_frames, fps=fps)
70
  return out_path
71
 
72
  # Build Gradio UI