roychao19477 committed on
Commit
27bac1c
·
1 Parent(s): d3458da

Update chunk feature

Browse files
Files changed (1) hide show
  1. app.py +37 -5
app.py CHANGED
@@ -103,13 +103,45 @@ def run_avse_inference(video_path, audio_path):
103
 
104
 
105
  # Combine into input dict (match what model.enhance expects)
106
- data = {
107
- "noisy_audio": noisy,
108
- "video_frames": bg_frames[np.newaxis, ...]
109
- }
110
 
111
  with torch.no_grad():
112
- estimated = avse_model.enhance(data).reshape(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  # Save result
115
  tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
 
103
 
104
 
105
  # Combine into input dict (match what model.enhance expects)
106
+ #data = {
107
+ # "noisy_audio": noisy,
108
+ # "video_frames": bg_frames[np.newaxis, ...]
109
+ #}
110
 
111
  with torch.no_grad():
112
+ # Version 1
113
+ #estimated = avse_model.enhance(data).reshape(-1)
114
+ # Version 2
115
+ chunk_sec = 3
116
+ sr = 16000
117
+ audio_chunk_len = chunk_sec * sr # 48000
118
+ video_chunk_len = chunk_sec * 25 # 75
119
+
120
+ estimated_chunks = []
121
+
122
+ for i in range(0, len(noisy), audio_chunk_len):
123
+ audio_chunk = noisy[i:i+audio_chunk_len]
124
+ if len(audio_chunk) < audio_chunk_len:
125
+ pad = np.zeros(audio_chunk_len - len(audio_chunk), dtype=audio_chunk.dtype)
126
+ audio_chunk = np.concatenate([audio_chunk, pad])
127
+
128
+ vid_idx = i // sr * 25 # convert audio index to video frame index
129
+ video_chunk = bg_frames[0, vid_idx:vid_idx+video_chunk_len, :, :]
130
+ if video_chunk.shape[0] < video_chunk_len:
131
+ pad = np.zeros((video_chunk_len - video_chunk.shape[0], *video_chunk.shape[1:]), dtype=video_chunk.dtype)
132
+ video_chunk = np.concatenate([video_chunk, pad], axis=0)
133
+
134
+ data = {
135
+ "noisy_audio": audio_chunk,
136
+ "video_frames": video_chunk[np.newaxis, ...]
137
+ }
138
+
139
+ with torch.no_grad():
140
+ out = avse_model.enhance(data).reshape(-1)
141
+ estimated_chunks.append(out)
142
+
143
+ estimated = np.concatenate(estimated_chunks)[:len(noisy)]
144
+
145
 
146
  # Save result
147
  tmp_wav = audio_path.replace(".wav", "_enhanced.wav")