liuhuadai commited on
Commit
4d03f49
·
verified ·
1 Parent(s): 8b65e18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -232,10 +232,7 @@ if args.remove_pretransform_weight_norm == "post_load":
232
  ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
233
  training_wrapper = create_training_wrapper_from_config(model_config, model)
234
  # 加载模型权重时根据设备选择map_location
235
- if device == 'cuda':
236
- training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
237
- else:
238
- training_wrapper.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict'])
239
 
240
  def get_video_duration(video_path):
241
  video = VideoFileClip(video_path)
@@ -340,14 +337,14 @@ with gr.Blocks() as demo:
340
 
341
  gr.Examples(
342
  examples=[
343
- ["./examples/1_mute.mp4", "Playing Trumpet"],
344
- ["./examples/2_mute.mp4", "Axe striking"],
345
- ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier"],
346
- ["./examples/4_mute.mp4", "train passing by"],
347
- ["./examples/5_mute.mp4", "Lighting Firecrackers"]
348
  ],
349
- inputs=[video_input, caption_input],
350
  )
351
-
352
  demo.launch(share=True)
353
 
 
232
  ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
233
  training_wrapper = create_training_wrapper_from_config(model_config, model)
234
  # 加载模型权重时根据设备选择map_location
235
+ training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict']).to("cuda")
 
 
 
236
 
237
  def get_video_duration(video_path):
238
  video = VideoFileClip(video_path)
 
337
 
338
  gr.Examples(
339
  examples=[
340
+ ["./examples/1_mute.mp4", "Playing Trumpet", "./examples/1.mp4"],
341
+ ["./examples/2_mute.mp4", "Axe striking", "./examples/2.mp4"],
342
+ ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "./examples/3.mp4"],
343
+ ["./examples/4_mute.mp4", "train passing by", "./examples/4.mp4"],
344
+ ["./examples/5_mute.mp4", "Lighting Firecrackers", "./examples/5.mp4"]
345
  ],
346
+ inputs=[video_input, caption_input,output_video],
347
  )
348
+
349
  demo.launch(share=True)
350