roychao19477 committed
Commit b478c0f · 1 Parent(s): 9ecc54e

Upload to debug

Files changed (1)
  1. app.py  +8 -130

app.py CHANGED
@@ -75,13 +75,6 @@ import spaces
  # Load model once globally
  #ckpt_path = "ckpts/ep215_0906.oat.ckpt"
  #model = AVSEModule.load_from_checkpoint(ckpt_path)
- avse_model = AVSEModule()
- #avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
- avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
- avse_model.load_state_dict(avse_state_dict, strict=True)
- avse_model.to("cuda")
- avse_model.eval()
-
  CHUNK_SIZE_AUDIO = 2 * 48000 # 3 sec at 16kHz
  CHUNK_SIZE_VIDEO = 2 * 75 # 25fps × 3 sec

@@ -166,7 +159,15 @@ def extract_resampled_audio(video_path, target_sr=16000):
  def yolo_detection(frame, verbose=False):
      return model(frame, verbose=verbose)[0]

+ @spaces.GPU
  def extract_faces(video_file):
+     avse_model = AVSEModule()
+     #avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
+     avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
+     avse_model.load_state_dict(avse_state_dict, strict=True)
+     avse_model.to("cuda")
+     avse_model.eval()
+
      cap = cv2.VideoCapture(video_file)
      fps = cap.get(cv2.CAP_PROP_FPS)
      frames = []
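Note: this hunk follows the usual pattern for Hugging Face ZeroGPU Spaces, where the GPU is only attached while a @spaces.GPU-decorated function runs, so the AVSE model is now built, loaded, and moved to CUDA inside extract_faces instead of at import time. A minimal sketch of that pattern (TinyModel and run_inference are illustrative names, not part of app.py):

import torch
import torch.nn as nn
import spaces

class TinyModel(nn.Module):
    # Stand-in for AVSEModule; any nn.Module works the same way here.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

@spaces.GPU
def run_inference(batch):
    # All CUDA work happens inside the decorated call: construct the model,
    # move it to the GPU, run it, and return results on the CPU.
    model = TinyModel().to("cuda").eval()
    with torch.no_grad():
        return model(batch.to("cuda")).cpu()

Rebuilding and reloading the model on every call, as the hunk above does, is the simple variant; keeping the loaded state dict cached on the CPU and only moving it to CUDA inside the decorated function is a common refinement.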
@@ -264,126 +265,3 @@ iface = gr.Interface(
  )

  iface.launch()
-
-
-
- ckpt = "ckpts/SEMamba_advanced.pth"
- cfg_f = "recipes/SEMamba_advanced.yaml"
-
- # load config
- with open(cfg_f, 'r') as f:
-     cfg = yaml.safe_load(f)
-
-
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- device = "cuda"
- model = SEMamba(cfg).to(device)
- #sdict = torch.load(ckpt, map_location=device)
- #model.load_state_dict(sdict["generator"])
- #model.eval()
-
- @spaces.GPU
- def enhance(filepath, model_name):
-     # Load model based on selection
-     ckpt_path = {
-         "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
-         "VCTK+DNS": "ckpts/vd.pth"
-     }[model_name]
-
-     print("Loading:", ckpt_path)
-     model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
-     model.eval()
-     with torch.no_grad():
-         # load & resample
-         wav, orig_sr = librosa.load(filepath, sr=None)
-         noisy_wav = wav.copy()
-         if orig_sr != 16000:
-             wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
-         x = torch.from_numpy(wav).float().to(device)
-         norm = torch.sqrt(len(x)/torch.sum(x**2))
-         #x = (x * norm).unsqueeze(0)
-         x = (x * norm)
-
-         # split into 4s segments (64000 samples)
-         segment_len = 4 * 16000
-         chunks = x.split(segment_len)
-         enhanced_chunks = []
-
-         for chunk in chunks:
-             if len(chunk) < segment_len:
-                 #pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
-                 pad = (torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4)
-                 chunk = torch.cat([chunk, pad])
-             chunk = chunk.unsqueeze(0)
-
-             amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
-             amp2, pha2, _ = model(amp, pha)
-             out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
-             out = (out / norm).squeeze(0)
-             enhanced_chunks.append(out)
-
-         out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding
-
-         # back to original rate
-         if orig_sr != 16000:
-             out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)
-
-         # Normalize
-         peak = np.max(np.abs(out))
-         if peak > 0.05:
-             out = out / peak * 0.85
-
-         # write file
-         sf.write("enhanced.wav", out, orig_sr)
-
-         # spectrograms
-         fig, axs = plt.subplots(1, 2, figsize=(16, 4))
-
-         # noisy
-         D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
-         S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
-         librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
-         axs[0].set_title("Noisy Spectrogram")
-
-         # enhanced
-         D_clean = librosa.stft(out, n_fft=512, hop_length=256)
-         S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
-         librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
-         #librosa.display.specshow(S_clean, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
-         axs[1].set_title("Enhanced Spectrogram")
-
-         plt.tight_layout()
-
-         return "enhanced.wav", fig
-
- #with gr.Blocks() as demo:
- #    gr.Markdown(ABOUT)
- #    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
- #    enhance_btn = gr.Button("Enhance")
- #    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
- #    plot_output = gr.Plot(label="Spectrograms")
- #
- #    enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
- #
- #demo.queue().launch()
-
- with gr.Blocks() as demo:
-     gr.Markdown(ABOUT)
-     input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
-     model_choice = gr.Radio(
-         label="Choose Model (The use of VCTK+DNS is recommended)",
-         choices=["VCTK-Demand", "VCTK+DNS"],
-         value="VCTK-Demand"
-     )
-     enhance_btn = gr.Button("Enhance")
-     output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
-     plot_output = gr.Plot(label="Spectrograms")
-
-     enhance_btn.click(
-         fn=enhance,
-         inputs=[input_audio, model_choice],
-         outputs=[output_audio, plot_output]
-     )
-     gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")
-
- demo.queue().launch()
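For reference, the enhance() path removed above rescales the waveform, splits it into 4-second segments (64000 samples at 16 kHz), pads the final short segment with very low-level noise rather than zeros (presumably to avoid numerically degenerate, all-silent STFT frames), and trims the concatenated output back to the input length. A standalone sketch of just that chunk/pad/trim step, with the per-segment SEMamba forward pass replaced by an identity placeholder:

import torch

def chunk_pad_trim(x, segment_len=4 * 16000):
    # Split a 1-D waveform into fixed-length segments, padding the final
    # short segment with low-level noise as the removed enhance() did,
    # then reassemble and trim back to the original length.
    outputs = []
    for chunk in x.split(segment_len):
        if len(chunk) < segment_len:
            pad = torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4
            chunk = torch.cat([chunk, pad])
        outputs.append(chunk)  # the per-segment model call would go here
    return torch.cat(outputs)[:len(x)]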
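The same removed path scales the input by norm = torch.sqrt(len(x) / torch.sum(x**2)), i.e. the reciprocal of the signal's RMS, so x * norm has unit RMS before enhancement, and the output is divided by norm afterwards to restore the original level. A quick illustrative check of that identity:

import torch

x = torch.randn(16000) * 0.1                      # arbitrary test signal
norm = torch.sqrt(len(x) / torch.sum(x ** 2))     # 1 / RMS of x
rms_after = torch.sqrt(torch.mean((x * norm) ** 2))
print(float(rms_after))                           # ~1.0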