Spaces:
Sleeping
Sleeping
roychao19477
commited on
Commit
·
dba6227
1
Parent(s):
5a2e862
Upload to debug
Browse files
app.py
CHANGED
@@ -82,6 +82,9 @@ avse_model.load_state_dict(avse_state_dict, strict=True)
|
|
82 |
avse_model.to("cuda")
|
83 |
avse_model.eval()
|
84 |
|
|
|
|
|
|
|
85 |
@spaces.GPU
|
86 |
def run_avse_inference(video_path, audio_path):
|
87 |
estimated = run_avse(video_path, audio_path)
|
@@ -101,19 +104,39 @@ def run_avse_inference(video_path, audio_path):
|
|
101 |
]).astype(np.float32)
|
102 |
bg_frames /= 255.0
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
|
109 |
# Combine into input dict (match what model.enhance expects)
|
110 |
-
data = {
|
111 |
-
|
112 |
-
|
113 |
-
}
|
|
|
|
|
|
|
|
|
114 |
|
115 |
with torch.no_grad():
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
# Save result
|
119 |
tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
|
|
|
82 |
avse_model.to("cuda")
|
83 |
avse_model.eval()
|
84 |
|
85 |
+
CHUNK_SIZE_AUDIO = 48000 # 3 sec at 16kHz
|
86 |
+
CHUNK_SIZE_VIDEO = 75 # 25fps × 3 sec
|
87 |
+
|
88 |
@spaces.GPU
|
89 |
def run_avse_inference(video_path, audio_path):
|
90 |
estimated = run_avse(video_path, audio_path)
|
|
|
104 |
]).astype(np.float32)
|
105 |
bg_frames /= 255.0
|
106 |
|
107 |
+
audio_chunks = [
|
108 |
+
noisy[i:i + CHUNK_SIZE_AUDIO]
|
109 |
+
for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)
|
110 |
+
]
|
111 |
+
|
112 |
+
video_chunks = [
|
113 |
+
bg_frames[i:i + CHUNK_SIZE_VIDEO]
|
114 |
+
for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)
|
115 |
+
]
|
116 |
+
|
117 |
+
min_len = min(len(audio_chunks), len(video_chunks)) # sync length
|
118 |
|
119 |
|
120 |
# Combine into input dict (match what model.enhance expects)
|
121 |
+
#data = {
|
122 |
+
# "noisy_audio": noisy,
|
123 |
+
# "video_frames": bg_frames[np.newaxis, ...]
|
124 |
+
#}
|
125 |
+
|
126 |
+
#with torch.no_grad():
|
127 |
+
# estimated = avse_model.enhance(data).reshape(-1)
|
128 |
+
estimated_chunks = []
|
129 |
|
130 |
with torch.no_grad():
|
131 |
+
for i in range(min_len):
|
132 |
+
chunk_data = {
|
133 |
+
"noisy_audio": audio_chunks[i],
|
134 |
+
"video_frames": video_chunks[i][np.newaxis, ...]
|
135 |
+
}
|
136 |
+
est = avse_model.enhance(chunk_data).reshape(-1)
|
137 |
+
estimated_chunks.append(est)
|
138 |
+
|
139 |
+
estimated = torch.cat(estimated_chunks).cpu().numpy()
|
140 |
|
141 |
# Save result
|
142 |
tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
|