__all__ = []

import sys
import os
import itertools
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

from fastprogress import progress_bar
from fastcore.script import *

import whisper, whisperx
from . import vad, wh_transcribe, vq_stoks, extract_acoustic
import webdataset as wds


def flac_to_t2s_name(input):
    # use [-1] so bare filenames (without a directory part) also work
    return input.rsplit("/", 1)[-1].replace('flac', 't2s') + ".gz"
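# A worked example of the name mapping (the shard name is made up):
#   flac_to_t2s_name("data/librilight-flac-000123.tar")
#   -> "librilight-t2s-000123.tar.gz"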
|
|
class Transcriber:
    """
    A helper class to transcribe a batch of 30 second audio chunks.
    """
    def __init__(self, model_size, lang=False):
        # lang=False (rather than None) makes whisperx build the tokenizer eagerly,
        # which we need below since we bypass whisperx's own language detection
        self.model = whisperx.asr.load_model(model_size, "cuda", compute_type="float16", language=lang)
        # warm up the VAD model; without at least one call here the rest of the pipeline can crash
        self.model.vad_model({"waveform": torch.zeros(1, 16000), "sample_rate": 16000})

    def transcribe(self, batch):
        batch = whisper.log_mel_spectrogram(batch)
        embs = self.model.model.encode(batch.cpu().numpy())
        return self.model.tokenizer.tokenizer.decode_batch([x.sequences_ids[0] for x in
            self.model.model.model.generate(
                embs,
                [self.model.model.get_prompt(self.model.tokenizer, [], without_timestamps=True)]*len(batch),
            )])
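# A minimal usage sketch (assuming a CUDA device; the silent two-chunk batch is
# made up for illustration):
#   transcriber = Transcriber("small.en")
#   chunks = torch.zeros(2, 30 * 16000).cuda()  # two 30 s chunks of 16 kHz audio
#   texts = transcriber.transcribe(chunks)      # -> a list of two strings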
|
|
@call_parse
def prepare_t2s(
    input:str,              # input FLAC webdataset shard (or "-" to read shard names from stdin)
    proc_dataset_path:Path, # path to the preprocessed VAD dataset
    output:str=None,        # output shard name (derived from the input name if not given)
    vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # semantic token model, either "repo:filename" or a local path
    n_samples:int=None,     # limit the number of samples (benchmarking mode, the output is discarded)
    batch_size:int=1,       # how many chunks to process at once
    transcription_model:str="small.en", # Whisper model size used for the transcriptions
):
    if ":" in vq_model:
        # "Hugging Face repo:filename" form
        repo, fname = vq_model.split(":", 1)
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
    else:
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
    transcriber = Transcriber(transcription_model)
|
    if input == "-":
        # read the list of shard file names from stdin
        input = [f.strip() for f in sys.stdin.readlines()]
        assert output, "please provide the output shard name"
    else:
        if output is None: output = flac_to_t2s_name(input)
        input = [input]

    total = n_samples//batch_size if n_samples else 'noinfer'  # 'noinfer' stops fastprogress from calling len() on the loader
    if n_samples: print(f"Benchmarking run of {n_samples} samples ({total} batches)")
|
    ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names).compose(
        wds.decode(wds.torch_audio),
        vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
        wds.map_dict(**{"vad.npy": lambda s: wh_transcribe.chunk_merger(s, wh_transcribe.random_cutter)}),
        wh_transcribe.split_to_chunks,
        # drop the first and last chunk of every file since their transcriptions
        # tend to be less reliable
        wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
        wds.to_tuple('__key__', 'rpad', 'samples'),
        wds.batched(64),
    )
|
dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size) |
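    # Each batch from `dl` unpacks into three parallel parts (structure follows
    # from the to_tuple/batched steps above): a list of webdataset keys, the
    # right-padding lengths in samples, and a (batch_size, 30 * 16000) tensor of
    # 16 kHz audio (30 s chunks, per the Transcriber docstring).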
|
    speakers = set()
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        # the record_function blocks only label these phases when run under torch.profiler
        for keys, rpads, samples in progress_bar(dl, total=total):
            with record_function('to_cuda'):
                csamples = samples.cuda()
            with record_function('transcribe'):
                txts = transcriber.transcribe(csamples)
            with record_function('vq_stoks'):
                stoks = vq_model.encode_audio(csamples)
            with record_function('from_cuda'):
                stoks = stoks.cpu().numpy().astype(np.int16)
            for key, rpad, txt, _stoks in zip(keys, rpads, txts, stoks):
                speakers.add(key.split('/')[1])  # the speaker id is the second '/'-separated field of the key
                # rpad counts padding samples at 16 kHz and the semantic tokens come out
                # at 25 per second, so cut off the tokens that cover the padding
                # (a negative slice would drop everything whenever rpad == 0)
                keep = len(_stoks) - int(rpad/16000 * 25)
                sink.write({
                    "__key__": key,
                    "txt": txt,
                    "stoks.npy": _stoks[:keep],
                })
    with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
|
    if not n_samples:
        # only keep the output for real runs, not benchmarking ones
        os.rename(tmp, output)
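# A hypothetical invocation (module and shard names are assumptions; @call_parse
# turns every keyword argument into a command line flag):
#   python -m whisperspeech.prepare_t2s librilight-flac-000123.tar ./processed-vad \
#       --batch_size 16 --transcription_model small.en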