Fly-ShuAI committed on
Commit
f55b8f5
·
verified ·
1 Parent(s): c412d8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -14
app.py CHANGED
@@ -15,15 +15,22 @@ num_frames, width, height = 49, 832, 480
15
  gpu_id = 0
16
  device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
17
 
18
- from huggingface_hub import snapshot_download, hf_hub_download
19
- # snapshot_download(
20
- # repo_id="briaai/RMBG-2.0",
21
- # local_dir="ckpt/RMBG-2.0",
22
- # local_dir_use_symlinks=False,
23
- # resume_download=True,
24
- # repo_type="model"
25
  # )
26
- # snapshot_download( # 下载整个仓库
 
 
 
 
 
 
 
 
 
 
27
  # repo_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
28
  # local_dir="ckpt/Wan2.1-Fun-1.3B-Control",
29
  # local_dir_use_symlinks=False,
@@ -31,14 +38,23 @@ from huggingface_hub import snapshot_download, hf_hub_download
31
  # repo_type="model"
32
  # )
33
 
34
- # rmbg_model = AutoModelForImageSegmentation.from_pretrained('ckpt/RMBG-2.0', trust_remote_code=True)
 
 
 
 
 
 
 
 
35
  # torch.set_float32_matmul_precision(['high', 'highest'][0])
36
  # rmbg_model.to(device)
37
  # rmbg_model.eval()
38
 
39
- model_manager = ModelManager(device="cpu") # 1.3b: device=cpu: uses 6G VRAM, device=device: uses 16G VRAM; about 1-2 min per video
40
 
 
41
  wan_dit_path = 'train_res/wan1.3b_zh/full_wc0.5_f1gt0.5_real1_2_zh_en_l_s/lightning_logs/version_0/checkpoints/step-step=30000.ckpt'
 
42
  if 'wan14b' in wan_dit_path.lower(): # 14B: uses about 36G, about 10 min per video
43
  model_manager.load_models(
44
  [
@@ -50,13 +66,12 @@ if 'wan14b' in wan_dit_path.lower(): # 14B: uses about 36G, about 10 min per vid
50
  torch_dtype=torch.bfloat16, # float8_e4m3fn fp8量化; bfloat16
51
  )
52
  else:
53
- wan_dit_path = None
54
  model_manager.load_models(
55
  [
56
- wan_dit_path if wan_dit_path else 'ckpt/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors',
57
  'ckpt/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth',
58
- 'ckpt/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth',
59
- 'ckpt/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
60
  ],
61
  torch_dtype=torch.bfloat16,
62
  )
 
15
  gpu_id = 0
16
  device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
17
 
18
+ # from modelscope import snapshot_download
19
+ # model_dir = snapshot_download(
20
+ # model_id = 'AI-ModelScope/RMBG-2.0',
21
+ # local_dir = 'ckpt/',
 
 
 
22
  # )
23
+
24
+ from huggingface_hub import snapshot_download, hf_hub_download
25
+ hf_hub_download(
26
+ repo_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
27
+ filename="Wan2.1_VAE.pth",
28
+ local_dir="ckpt/Wan2.1-Fun-1.3B-Control/",
29
+ local_dir_use_symlinks=False,
30
+ resume_download=True,
31
+ )
32
+
33
+ # snapshot_download( # 下载整个仓库; 下briaai/RMBG-2.0需要token
34
  # repo_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
35
  # local_dir="ckpt/Wan2.1-Fun-1.3B-Control",
36
  # local_dir_use_symlinks=False,
 
38
  # repo_type="model"
39
  # )
40
 
41
+ # hf_hub_download(
42
+ # repo_id="Kunbyte/Lumen",
43
+ # filename="Lumen-T2V-1.3B.ckpt",
44
+ # local_dir="ckpt/",
45
+ # local_dir_use_symlinks=False,
46
+ # resume_download=True,
47
+ # )
48
+
49
+ # rmbg_model = AutoModelForImageSegmentation.from_pretrained('ckpt/RMBG-2.0', trust_remote_code=True) # ckpt/RMBG-2.0
50
  # torch.set_float32_matmul_precision(['high', 'highest'][0])
51
  # rmbg_model.to(device)
52
  # rmbg_model.eval()
53
 
 
54
 
55
+ model_manager = ModelManager(device="cpu") # 1.3b: device=cpu: uses 6G VRAM, device=device: uses 16G VRAM; about 1-2 min per video
56
  wan_dit_path = 'train_res/wan1.3b_zh/full_wc0.5_f1gt0.5_real1_2_zh_en_l_s/lightning_logs/version_0/checkpoints/step-step=30000.ckpt'
57
+
58
  if 'wan14b' in wan_dit_path.lower(): # 14B: uses about 36G, about 10 min per video
59
  model_manager.load_models(
60
  [
 
66
  torch_dtype=torch.bfloat16, # float8_e4m3fn fp8量化; bfloat16
67
  )
68
  else:
 
69
  model_manager.load_models(
70
  [
71
+ # wan_dit_path if wan_dit_path else 'ckpt/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors',
72
  'ckpt/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth',
73
+ # 'ckpt/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth',
74
+ # 'ckpt/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
75
  ],
76
  torch_dtype=torch.bfloat16,
77
  )