__all__ = ['flac_to_s2a_name', 'resampler', 'prepare_s2a']

import sys
import os
import itertools
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

from fastprogress import progress_bar
from fastcore.script import *

import whisper
from . import vad, wh_transcribe, vq_stoks, extract_acoustic
import webdataset as wds


def flac_to_s2a_name(input):
    "Derive the output `-s2a-` shard name from a `-flac-` or `-raw-` input shard name."
    if '-flac-' in input:
        return input.rsplit("/", 1)[1].replace('flac', 's2a') + ".gz"
    else:
        return input.rsplit("/", 1)[1].replace('raw', 's2a') + ".gz"
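# Example of the name mapping above (the shard name is illustrative, not a
# real dataset path):
#
#   flac_to_s2a_name('librilight/large-flac-000002.tar')  # -> 'large-s2a-000002.tar.gz'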
|
|
def resampler(newsr = 24000, key = 'samples_24k'):
    "Create a pipeline stage that resamples each sample's audio to `newsr` and stores it under `key`."
    _last_sr = None
    tform = None

    def _resample(samples):
        # `nonlocal` is needed so the cached transform survives across samples
        nonlocal _last_sr, tform
        for s in samples:
            sr = s['sample_rate']
            if sr != newsr:
                # reuse the cached Resample transform as long as the input sample rate stays the same
                if sr != _last_sr:
                    tform = torchaudio.transforms.Resample(sr, newsr)
                    _last_sr = sr
                s[key] = tform(s['samples'])
            else:
                s[key] = s['samples']
            yield s

    return _resample
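# A minimal sketch of using the stage above on its own (the 8 kHz input is
# made up for the example):
#
#   stage = resampler(16000, 'samples_16k')
#   sample = {'sample_rate': 8000, 'samples': torch.zeros(1, 8000)}
#   out = next(stage([sample]))
#   out['samples_16k'].shape  # torch.Size([1, 16000])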
|
|
@call_parse
def prepare_s2a(
    input:str,              # input shard URL/path (or "-" to read a list of shards from stdin)
    proc_dataset_path:Path, # path to the processed VAD dataset
    output:str=None,        # output shard name (required when reading the input from stdin)
    vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # semantic token model ("repo:filename" for the HF hub, otherwise a local path)
    n_samples:int=None,     # process only a limited number of samples (for quick tests)
    batch_size:int=1,       # process several segments at once
    fix_dots:bool=False,    # fix dots in file names
):
    # load the semantic token (stoks) model and the EnCodec acoustic model
    if ":" in vq_model:
        repo, fname = vq_model.split(":", 1)
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
    else:
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
    amodel = extract_acoustic.load_model()
    amodel.set_target_bandwidth(3)

    if input == "-":
        input = [f.strip() for f in sys.stdin.readlines()]
        assert output, "please provide the output shard name"
    else:
        if output is None: output = flac_to_s2a_name(input)
        input = [input]

    # 'noinfer' tells fastprogress not to call len() on the length-less WebLoader
    total = n_samples//batch_size if n_samples else 'noinfer'

    ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names if fix_dots else None).compose(
        wds.decode(wds.torch_audio),
        wds.select(lambda x: 'wav' in x or 'flac' in x), # skip samples without an audio payload
        vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')), # merge in the precomputed VAD segmentation
        wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
        wh_transcribe.split_to_chunks,
        resampler(),
        resampler(16000, 'samples_16k'),
        wds.to_tuple('__key__', 'rpad_s', 'samples_16k', 'samples_24k'),
        wds.batched(64),
    )

    # batch inside the workers, then unbatch, shuffle across workers, and re-batch to `batch_size`
    dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)

    speakers = set()
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        # the record_function markers only have an effect when run under torch.profiler
        for keys, rpad_ss, samples, samples24k in progress_bar(dl, total=total):
            with record_function('to_cuda'):
                samples, samples24k = samples.cuda(), samples24k.unsqueeze(1).cuda()
            with record_function('encodec'):
                atoks = amodel.encode(samples24k)[0][0]
            with record_function('vq_stoks'):
                stoks = vq_model.encode_audio(samples)
            with record_function('from_cuda'):
                # token ids comfortably fit in int16, which halves the shard size
                atoks, stoks = atoks.cpu().numpy().astype(np.int16), stoks.cpu().numpy().astype(np.int16)
            for key, rpad_s, _atoks, _stoks in zip(keys, rpad_ss, atoks, stoks):
                # the speaker id is the second segment of the sample key
                speakers.add(key.split('/')[1])
                # trim the right padding (EnCodec: 75 frames/s, stoks: 25 tokens/s);
                # `or None` keeps the full sequence when there is no padding to trim
                sink.write({
                    "__key__": key,
                    "atoks.npy": _atoks[:,:int(-rpad_s * 75) or None],
                    "stoks.npy": _stoks[:int(-rpad_s * 25) or None],
                })
    with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
    # only replace the .tmp file with the final output when we processed the whole shard
    if not n_samples:
        os.rename(tmp, output)
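# Example invocation (@call_parse exposes the arguments above as a CLI; the
# module name and paths below are illustrative):
#
#   python -m whisperspeech.prepare_s2a_dataset \
#       librilight/large-flac-000002.tar ./librilight-processed/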
|
|