rahul7star committed on
Commit 4e485ec · verified · 1 Parent(s): 62960dc

Update generate.py

Files changed (1)
  1. generate.py +23 -11
generate.py CHANGED
@@ -1,24 +1,37 @@
  import argparse
- import subprocess
  import os
  import torch
  from huggingface_hub import snapshot_download

  # Arguments
  parser = argparse.ArgumentParser()
- parser.add_argument("--task", type=str, default="t2v-14B")
- parser.add_argument("--size", type=str, default="200*200")
+ parser.add_argument("--task", type=str, default="t2v-1.3B")
+ parser.add_argument("--size", type=str, default="832*480")
  parser.add_argument("--frame_num", type=int, default=60)
  parser.add_argument("--sample_steps", type=int, default=20)
- parser.add_argument("--ckpt_dir", type=str, default="./Wan2.1-T2V-14B")
+ parser.add_argument("--ckpt_dir", type=str, default="./Wan2.1-T2V-1.3B")
  parser.add_argument("--offload_model", type=str, default="True")
+ parser.add_argument("--t5_cpu", action="store_true", help="Use CPU for T5 model (optional)")
+ parser.add_argument("--sample_shift", type=int, default=8, help="Sampling shift for generation")
+ parser.add_argument("--sample_guide_scale", type=int, default=6, help="Sampling guide scale for generation")
  parser.add_argument("--prompt", type=str, required=True)
  args = parser.parse_args()
- print("🔄 Downloading WAN 2.1 start..")
+
+ # Log input parameters
+ print(f"Generating video with the following settings:\n"
+       f"Task: {args.task}\n"
+       f"Resolution: {args.size}\n"
+       f"Frames: {args.frame_num}\n"
+       f"Sample Steps: {args.sample_steps}\n"
+       f"Prompt: {args.prompt}\n"
+       f"Sample Shift: {args.sample_shift}\n"
+       f"Sample Guide Scale: {args.sample_guide_scale}\n"
+       f"Using T5 on CPU: {args.t5_cpu}")
+
  # Ensure the model is downloaded
  if not os.path.exists(args.ckpt_dir):
-     print("🔄 Downloading WAN 2.1 - 14B model from Hugging Face...")
-     snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir=args.ckpt_dir)
+     print("🔄 Downloading WAN 2.1 - 1.3B model from Hugging Face...")
+     snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir=args.ckpt_dir)

  # Free up GPU memory
  if torch.cuda.is_available():
@@ -26,9 +39,8 @@ if torch.cuda.is_available():
      torch.backends.cudnn.benchmark = False
      torch.backends.cudnn.deterministic = True

- # Run WAN 2.1 - 14B Model
- #command = f"python generate.py --task {args.task} --size {args.size} --frame_num {args.frame_num} --sample_steps {args.sample_steps} --ckpt_dir {args.ckpt_dir} --offload_model {args.offload_model} --prompt \"{args.prompt}\""
- command = f"python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+ # Run the model (Ensure that `generate.py` includes these new params in its model call)
+ command = f"python generate.py --task {args.task} --size {args.size} --frame_num {args.frame_num} --sample_steps {args.sample_steps} --ckpt_dir {args.ckpt_dir} --offload_model {args.offload_model} --t5_cpu {args.t5_cpu} --sample_shift {args.sample_shift} --sample_guide_scale {args.sample_guide_scale} --prompt \"{args.prompt}\""

  process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  stdout, stderr = process.communicate()
@@ -41,4 +53,4 @@ print("🔺 Error:", stderr.decode())
  if os.path.exists("output.mp4"):
      print("✅ Video generated successfully: output.mp4")
  else:
-     print("❌ Error: Video file not found!")
+     print("❌ Error: Video file not found!")
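
Two details of the new version may deserve a follow-up commit: import subprocess is removed at the top even though subprocess.Popen is still called further down, and --t5_cpu {args.t5_cpu} interpolates a literal True/False after a switch that the removed hard-coded command passed as a bare flag (the argument is declared with action="store_true", which takes no value). Below is a minimal sketch of how the invocation could be assembled instead, assuming args is the Namespace produced by the parser above and that the inner generate.py accepts --t5_cpu only as a valueless flag; passing the arguments as a list also removes the need to hand-escape the prompt.

import subprocess  # still required by the new version of the script

# Build the command as an argument list; subprocess quotes each item for us.
cmd = [
    "python", "generate.py",
    "--task", args.task,
    "--size", args.size,
    "--frame_num", str(args.frame_num),
    "--sample_steps", str(args.sample_steps),
    "--ckpt_dir", args.ckpt_dir,
    "--offload_model", args.offload_model,
    "--sample_shift", str(args.sample_shift),
    "--sample_guide_scale", str(args.sample_guide_scale),
    "--prompt", args.prompt,
]
if args.t5_cpu:
    cmd.append("--t5_cpu")  # bare switch, matching the store_true declaration

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.returncode != 0:
    print("🔺 Error:", result.stderr)

Since the command is not run through a shell, quotes or backslashes in the prompt no longer need special handling.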
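
A related note on the download guard: os.path.exists(args.ckpt_dir) only checks that the directory exists, not that the checkpoint finished downloading, so a partial download would be accepted silently. One option, assuming a recent huggingface_hub where snapshot_download skips files that are already present on disk, is to call it unconditionally and let it fill in whatever is missing:

# Assumption: snapshot_download is incremental in the installed huggingface_hub
# version, so rerunning it over an existing local_dir only fetches missing files.
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir=args.ckpt_dir)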
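
Finally, the script checks for output.mp4 at the end, but the forwarded command never tells the inner generate.py where to save the video, so the check only works if that script happens to write output.mp4 by default. A hedged sketch of making the check deterministic, assuming the inner script accepts an explicit output path (the --save_file flag name is an assumption and should be checked against the inner script's argument parser):

OUTPUT_PATH = "output.mp4"

# Assumption: the inner generate.py accepts --save_file; verify the flag name
# before relying on this.
command += f' --save_file "{OUTPUT_PATH}"'

# ... run the command as above, then check the file that was requested:
if os.path.exists(OUTPUT_PATH):
    print("✅ Video generated successfully:", OUTPUT_PATH)
else:
    print("❌ Error: Video file not found!")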