|
|
|
|
|
|
|
# Public API for `from <module> import *`. It is left empty, so a star-import
# exports no names; callers import the functions below explicitly.
__all__ = []
|
|
|
|
|
import os |
|
import torch |
|
import torchaudio |
|
|
|
from pathlib import Path |
|
from fastprogress import progress_bar |
|
from fastcore.script import call_parse |
|
|
|
import whisperx |
|
import random |
|
import numpy as np |
|
import webdataset as wds |
|
|
|
|
|
|
|
|
|
def fix_dots_in_names(name):
    """Replace dots in a file name's stem with underscores, keeping the extension.

    Used as the ``rename_files`` hook for ``load_dataset`` when ``fix_dots`` is
    set. Only the text after the last dot is kept as the extension. Every other
    dot, including any in a directory prefix, becomes ``_``.
    Example: ``"a.b.flac" -> "a_b.flac"``. WebDataset presumably derives the
    sample key from the part before the first dot of the basename, which is why
    extra dots are a problem.

    Args:
        name: file name as stored in the tar (may include a directory prefix).

    Returns:
        The renamed file name. Names without any dot are returned unchanged.
    """
    stem, dot, ext = name.rpartition('.')
    if not dot:
        # No extension at all: leave the name alone instead of failing to unpack.
        return name
    return f"{stem.replace('.', '_')}.{ext}"
|
|
|
def load_dataset(url, decode=True, rename_files=None):
    """Open a WebDataset shard (or shard spec) at *url*.

    Args:
        url: shard URL/path, passed straight to ``wds.WebDataset``.
        decode: when true, the pipeline decodes entries with ``wds.torch_audio``.
            When false, the raw, undecoded pipeline is returned.
        rename_files: optional callable forwarded to ``wds.WebDataset``'s
            ``rename_files`` argument (e.g. ``fix_dots_in_names``).

    Returns:
        The ``wds.WebDataset`` pipeline, with the audio decoder attached only
        when *decode* is true.
    """
    dataset = wds.WebDataset(url, rename_files=rename_files)
    if decode:
        dataset = dataset.decode(wds.torch_audio)
    return dataset
|
|
|
|
|
def extract_segments(vad_result, max_duration):
    """Turn raw VAD output into a list of ``(start, end)`` segments.

    Args:
        vad_result: output of the whisperx VAD model (see ``segment_audio``).
        max_duration: ``max_duration`` setting for ``whisperx.vad.Binarize``.
            It presumably caps segment length in seconds; confirm against the
            whisperx docs.

    Returns:
        List of ``(start, end)`` tuples, one per segment in the binarized
        timeline, in timeline order.
    """
    binarizer = whisperx.vad.Binarize(max_duration=max_duration)
    timeline = binarizer(vad_result).get_timeline()
    return [(segment.start, segment.end) for segment in timeline]
|
|
|
def segment_audio(vad_model, audio, sr=16000, max_duration=30):
    """Run voice activity detection on *audio* and return speech segments.

    Args:
        vad_model: VAD model from ``whisperx.vad.load_vad_model``. It is called
            with a ``{"waveform": ..., "sample_rate": ...}`` dict.
        audio: waveform tensor. It is presumably ``(channels, samples)`` as
            produced by torchaudio decoding (TODO: confirm against the model).
        sr: sample rate of *audio*, in Hz.
        max_duration: longest segment to request from ``extract_segments``.
            Defaults to 30, presumably to match Whisper's 30-second window.

    Returns:
        List of ``(start, end)`` tuples as returned by ``extract_segments``.
    """
    vad_result = vad_model({"waveform": audio, "sample_rate": sr})
    return extract_segments(vad_result, max_duration)
|
|
|
|
|
def flac_to_vad_name(input):
    """Derive the VAD output file name from an input audio shard URL/path.

    Only the last ``/``-separated component of *input* is used, so the result
    is a bare file name. ``process_shard`` therefore writes it relative to the
    current directory. If that name contains ``-flac-``, every ``flac`` in it
    becomes ``vad``; otherwise every ``raw`` becomes ``vad``. ``.gz`` is then
    appended.

    Examples:
        ``https://host/ds/librilight-flac-000001.tar`` -> ``librilight-vad-000001.tar.gz``
        ``mls-raw-000.tar`` -> ``mls-vad-000.tar.gz``
    """
    # `[-1]` also handles names without any "/" (plain local file names).
    basename = input.rsplit("/", 1)[-1]
    # Test the same string we rewrite, so a "-flac-" in a parent directory
    # cannot select the wrong branch.
    if '-flac-' in basename:
        # NOTE(review): str.replace rewrites *every* "flac" in the name, not only
        # the "-flac-" marker; kept so existing output names do not change.
        return basename.replace('flac', 'vad') + ".gz"
    return basename.replace('raw', 'vad') + ".gz"
|
|
|
@call_parse
def process_shard(
    input:str,           # Input shard URL or path (anything `wds.WebDataset` accepts)
    output:str=None,     # Output tar path; derived with `flac_to_vad_name` when omitted
    fix_dots:bool=False, # Rename tar entries with `fix_dots_in_names` while reading
    device:str='cuda',   # Torch device for the whisperx VAD model
):
    """Run VAD over every sample of one WebDataset shard and write the segments.

    For each sample with a ``flac`` (or else ``wav``) entry, the ``(start, end)``
    segments from ``segment_audio`` are stored as ``vad.npy`` under the same
    ``__key__`` in a new tar shard. Samples without audio are skipped with a
    warning. The shard is written to ``output + ".tmp"`` and renamed to
    *output* only after it is complete. On any error the temporary file is
    removed and the exception re-raised.
    """
    if output is None: output = flac_to_vad_name(input)

    ds = torch.utils.data.DataLoader(
        load_dataset(input, rename_files=fix_dots_in_names if fix_dots else None),
        num_workers=2, batch_size=None)
    vad_model = whisperx.vad.load_vad_model(device)

    tmp = output + ".tmp"
    try:
        with wds.TarWriter(tmp) as sink:
            for s in progress_bar(ds, total='noinfer'):
                audio, sr = s.get('flac', s.get('wav', (None, None)))
                if audio is None:
                    print(f"warning: '{s['__key__']}' does not contain an audio file")
                    continue
                sink.write({
                    "__key__": s['__key__'],
                    # NOTE(review): float16 has an 11-bit significand, so values at or
                    # above 1024 (presumably seconds) are rounded to 1 s steps (2 s
                    # above 2048). Consider float32 for long recordings; kept as-is
                    # because downstream readers may expect float16.
                    "vad.npy": np.array(segment_audio(vad_model, audio, sr=sr), dtype=np.float16)
                })
    except BaseException:
        # Do not leave a partial shard behind after a failure or Ctrl-C.
        if os.path.exists(tmp): os.remove(tmp)
        raise
    # os.replace (unlike os.rename) also overwrites an existing `output` on Windows.
    os.replace(tmp, output)
|
|