__all__ = [] |
|
import copy
import io
import json
import os
import random
import time

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from pathlib import Path
from torch import nn
from torch.utils.data import DataLoader

import webdataset as wds
import whisper
from fastcore.script import *
from fastprogress import progress_bar, master_bar

from . import vad
|
|
def random_cutter(dur):
    """Stochastic alternative to the fixed 28s threshold used by `chunk_merger`:
    half of the time it cuts at a random threshold between 1.4 and 28 seconds,
    which diversifies the lengths of the generated chunks."""
    if random.random() < 0.5:
        return dur > 28 * (random.random()*0.95+0.05)
    else:
        return dur > 28
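
# For illustration: above 28 seconds `random_cutter` always cuts; between 1.4s
# and 28s it cuts with a duration-dependent probability:
#
#     >>> random_cutter(30.0)   # always True above the 28s ceiling
#     True
#     >>> random_cutter(10.0)   # True with probability ~0.16, else False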
|
|
|
def chunk_merger(segments, should_cut=lambda x: x > 28):
    """Merge consecutive VAD segments into chunks, starting a new chunk whenever
    `should_cut` fires for the merged duration (by default: over 28 seconds)."""
    if len(segments) == 0: return segments
    curr_start = segments[0][0]
    curr_end = 0
    merged = []
    for ts,te in segments:
        if should_cut(te - curr_start) and curr_end - curr_start > 0:
            merged.append((curr_start, curr_end))
            curr_start = ts
        curr_end = te
    merged.append((curr_start, curr_end))
    return merged
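
# A minimal usage sketch (the segment values are made up): three short VAD
# segments stay under the 28s default and are merged into one chunk, while a
# segment that would stretch the chunk past 28s starts a new one. Passing
# `random_cutter` as `should_cut` randomizes the cut points instead:
#
#     >>> chunk_merger([(0.0, 10.0), (11.0, 20.0), (21.0, 27.5)])
#     [(0.0, 27.5)]
#     >>> chunk_merger([(0.0, 15.0), (16.0, 31.0)])
#     [(0.0, 15.0), (16.0, 31.0)]
#     >>> chunk_merger(segments, should_cut=random_cutter)  # stochastic variant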
|
|
def merge_in(*datasets):
    """Merge multiple datasets into the current one, returning samples with the union of keys.

    It requires (and validates) that all datasets have the same ordering of keys, so it has
    to be applied before any sample shuffling. Shard shuffling is fine.
    """
    def merge_loop(main_samples):
        for samples in zip(*[main_samples]+[iter(x) for x in datasets]):
            key = samples[0]['__key__']
            news = {}
            for s in samples:
                assert s['__key__'] == key
                news.update(s)
            yield news
    return merge_loop
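
# A minimal usage sketch (the shard name is made up): join an audio shard with
# its matching VAD shard, sample by sample, using `wds_compose` defined below
# (this mirrors how `process_shard` builds its pipeline):
#
#     >>> audio_ds = vad.load_dataset("librilight-flac-000000.tar")
#     >>> vad_ds = wds.WebDataset(vad.flac_to_vad_name("librilight-flac-000000.tar")).decode()
#     >>> merged = wds_compose(audio_ds, merge_in(vad_ds))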
|
|
def wds_compose(ds, *args):
    """Return a shallow copy of the WebDataset `ds` with the given pipeline stages
    appended, leaving the original dataset's pipeline untouched."""
    ds = copy.copy(ds)
    ds.pipeline = copy.copy(ds.pipeline)
    for f in args:
        ds.append(f)
    return ds
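
# For illustration (continuing the sketch above): derive a new dataset with the
# VAD segments merged into ~28s chunks, leaving `merged` itself untouched:
#
#     >>> ds = wds_compose(merged, wds.map_dict(**{"vad.npy": chunk_merger}))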
|
|
def split_to_chunks(stream, pad_to_seconds=30, random_shift=False):
    """Split every sample in `stream` into one chunk per VAD segment, zero-padded
    to `pad_to_seconds`; with `random_shift` the padding is split randomly between
    the left and right side. Assumes segments are at most `pad_to_seconds` long."""
    for s in stream:
        audio, sr = s.get('flac', s.get('wav', (None, None)))
        if audio is None:
            print(f"warning: '{s['__key__']}' does not contain an audio file")
            continue
        imax = len(s['vad.npy']) - 1
        for i,(ts,te) in enumerate(s['vad.npy']):
            samples = audio[0,int(ts*sr):int(te*sr)]
            # initialize so the yield below also works when padding is disabled
            lpad = padding = 0
            if pad_to_seconds is not None:
                padding = pad_to_seconds*sr - samples.shape[-1]
                lpad = random.randint(0, padding) if random_shift else 0
                samples = F.pad(samples, (lpad, padding-lpad))
            yield {"__key__": s['__key__'] + f"_{i:03d}",
                   "__url__": s['__url__'],
                   "i": i, "imax": imax,
                   "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
                   "lpad": lpad, "rpad": padding-lpad,
                   "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
                   "samples": samples, "sample_rate": sr}
|
|
def flac_to_txt_name(input, model_size):
    """Derive the transcript shard name from the audio shard name by swapping
    'flac' for '{model_size}-txt' and appending '.gz'."""
    # [-1] also handles bare filenames without a directory part
    return input.rsplit("/", 1)[-1].replace('flac', f'{model_size}-txt') + ".gz"
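
# For illustration (the shard name is made up):
#
#     >>> flac_to_txt_name("data/librilight-flac-000000.tar", "base.en")
#     'librilight-base.en-txt-000000.tar.gz'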
|
|
|
@call_parse
def process_shard(
    input:str,                   # input shard URL/path
    output:str=None,             # output shard URL/path (defaults to `flac_to_txt_name(input, whisper_model)`)
    bs:int=None,                 # batch size (defaults to 16)
    n_samples:int=None,          # expected number of samples (only used to size the progress bar)
    whisper_model:str="base.en"  # Whisper model to use for transcription
):
    if output is None: output = flac_to_txt_name(input, whisper_model)
    if bs is None: bs = 16
    if n_samples is None: n_samples = 'noinfer'
    else: n_samples = n_samples // bs

    ds = wds_compose(vad.load_dataset(input),
        merge_in(wds.WebDataset(vad.flac_to_vad_name(input)).decode()),
        wds.map_dict(**{"vad.npy":chunk_merger}),
        split_to_chunks,
        wds.to_tuple('__key__', 'samples'),
        wds.batched(bs),
    )
    dl = DataLoader(ds, num_workers=2, batch_size=None)

    whmodel = whisper.load_model(whisper_model)
    decoding_options = whisper.DecodingOptions(language='en')

    # write to a temporary file and rename only on success so an interrupted
    # run never leaves behind a truncated shard that looks complete
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        for keys, samples in progress_bar(dl, total=n_samples):
            with torch.no_grad():
                embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
                decs = whmodel.decode(embs, decoding_options)
            for key, dec in zip(keys, decs):
                sink.write({
                    "__key__": key,
                    "txt": dec.text,
                })
    os.rename(tmp, output)
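
# Thanks to `@call_parse` this module doubles as a command-line script. A usage
# sketch (the module and shard names are placeholders for the actual package layout):
#
#     python -m <package>.<this_module> librilight-flac-000000.tar --bs 16 --whisper_model base.en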
|
|