roychao19477 commited on
Commit
dee6815
·
1 Parent(s): 723a16d

Upload new model

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -61,6 +61,7 @@ import tempfile
61
  from ultralytics import YOLO
62
  from moviepy import ImageSequenceClip
63
  from scipy.io import wavfile
 
64
 
65
  # Load face detector
66
  model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available
@@ -75,18 +76,23 @@ import spaces
75
  #ckpt_path = "ckpts/ep215_0906.oat.ckpt"
76
  #model = AVSEModule.load_from_checkpoint(ckpt_path)
77
  avse_model = AVSEModule()
78
- avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
 
79
  avse_model.load_state_dict(avse_state_dict, strict=True)
80
  avse_model.to("cuda")
81
  avse_model.eval()
82
 
83
  @spaces.GPU
84
  def run_avse_inference(video_path, audio_path):
 
85
  # Load audio
86
  #noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
87
  #noisy = torch.tensor(noisy).unsqueeze(0) # (1, N)
88
  noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)
89
 
 
 
 
90
  # Load grayscale video
91
  vr = VideoReader(video_path, ctx=cpu(0))
92
  frames = vr.get_batch(list(range(len(vr)))).asnumpy()
 
61
  from ultralytics import YOLO
62
  from moviepy import ImageSequenceClip
63
  from scipy.io import wavfile
64
+ from avse_code import run_avse
65
 
66
  # Load face detector
67
  model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available
 
76
  #ckpt_path = "ckpts/ep215_0906.oat.ckpt"
77
  #model = AVSEModule.load_from_checkpoint(ckpt_path)
78
  avse_model = AVSEModule()
79
+ #avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
80
+ avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
81
  avse_model.load_state_dict(avse_state_dict, strict=True)
82
  avse_model.to("cuda")
83
  avse_model.eval()
84
 
85
  @spaces.GPU
86
  def run_avse_inference(video_path, audio_path):
87
+ run_avse(video_path, audio_path)
88
  # Load audio
89
  #noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
90
  #noisy = torch.tensor(noisy).unsqueeze(0) # (1, N)
91
  noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)
92
 
93
+ # Norm.
94
+ #noisy = noisy * (0.8 / np.max(np.abs(noisy)))
95
+
96
  # Load grayscale video
97
  vr = VideoReader(video_path, ctx=cpu(0))
98
  frames = vr.get_batch(list(range(len(vr)))).asnumpy()