liuhuadai commited on
Commit
353e603
·
verified ·
1 Parent(s): 2bda0a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -232,13 +232,15 @@ if args.remove_pretransform_weight_norm == "post_load":
232
  ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
233
  training_wrapper = create_training_wrapper_from_config(model_config, model)
234
  # 加载模型权重时根据设备选择map_location
235
- training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict']).to("cuda")
 
 
236
 
237
  def get_video_duration(video_path):
238
  video = VideoFileClip(video_path)
239
  return video.duration
240
 
241
- @spaces.GPU(duration=200)
242
  @torch.inference_mode()
243
  @torch.no_grad()
244
  def get_audio(video_path, caption):
 
232
  ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
233
  training_wrapper = create_training_wrapper_from_config(model_config, model)
234
  # 加载模型权重时根据设备选择map_location
235
+ training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
236
+
237
+ training_wrapper.to("cuda")
238
 
239
  def get_video_duration(video_path):
240
  video = VideoFileClip(video_path)
241
  return video.duration
242
 
243
+ @spaces.GPU(duration=60)
244
  @torch.inference_mode()
245
  @torch.no_grad()
246
  def get_audio(video_path, caption):