Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -232,13 +232,15 @@ if args.remove_pretransform_weight_norm == "post_load":
|
|
232 |
ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
|
233 |
training_wrapper = create_training_wrapper_from_config(model_config, model)
|
234 |
# 加载模型权重时根据设备选择map_location
|
235 |
-
training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
|
|
|
|
|
236 |
|
237 |
def get_video_duration(video_path):
|
238 |
video = VideoFileClip(video_path)
|
239 |
return video.duration
|
240 |
|
241 |
-
@spaces.GPU(duration=
|
242 |
@torch.inference_mode()
|
243 |
@torch.no_grad()
|
244 |
def get_audio(video_path, caption):
|
|
|
232 |
ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
|
233 |
training_wrapper = create_training_wrapper_from_config(model_config, model)
|
234 |
# 加载模型权重时根据设备选择map_location
|
235 |
+
training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
|
236 |
+
|
237 |
+
training_wrapper.to("cuda")
|
238 |
|
239 |
def get_video_duration(video_path):
|
240 |
video = VideoFileClip(video_path)
|
241 |
return video.duration
|
242 |
|
243 |
+
@spaces.GPU(duration=60)
|
244 |
@torch.inference_mode()
|
245 |
@torch.no_grad()
|
246 |
def get_audio(video_path, caption):
|