import os

import torch
import torchaudio

from functools import wraps
from types import SimpleNamespace

from torch.nn import SyncBatchNorm
from torch.nn import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP

from hyperpyyaml import load_hyperpyyaml

# Nesting depth of main-process-only sections. While it is non-zero,
# ddp_barrier() is a no-op, so code that runs only on the main process
# cannot deadlock the other processes.
MAIN_PROC_ONLY = 0


def fetch(filename, source):
    """Resolve ``filename`` to an absolute path inside the ``source`` directory."""
    return os.path.abspath(os.path.join(source, filename))


def run_on_main(
    func,
    args=None,
    kwargs=None,
    post_func=None,
    post_args=None,
    post_kwargs=None,
    run_post_on_main=False,
):
    """Run ``func`` on the main process only, then run ``post_func``
    (on the other processes by default, or everywhere if
    ``run_post_on_main`` is True)."""
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    if post_args is None:
        post_args = []
    if post_kwargs is None:
        post_kwargs = {}

    main_process_only(func)(*args, **kwargs)
    ddp_barrier()

    if post_func is not None:
        if run_post_on_main:
            # Run on every process, without a barrier.
            post_func(*post_args, **post_kwargs)
        else:
            # Run only on the non-main processes, then synchronize.
            if not if_main_process():
                post_func(*post_args, **post_kwargs)
            ddp_barrier()


def is_distributed_initialized():
    """Check whether the torch.distributed process group is initialized."""
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def if_main_process():
    """Return True on the rank-0 process, or always when not distributed."""
    if is_distributed_initialized():
        return torch.distributed.get_rank() == 0
    return True


class MainProcessContext:
    def __enter__(self):
        global MAIN_PROC_ONLY
        MAIN_PROC_ONLY += 1
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        global MAIN_PROC_ONLY
        MAIN_PROC_ONLY -= 1


def main_process_only(function):
    """Decorator: the wrapped function runs on the main process only and
    returns None elsewhere."""

    @wraps(function)
    def main_proc_wrapped_func(*args, **kwargs):
        with MainProcessContext():
            return function(*args, **kwargs) if if_main_process() else None

    return main_proc_wrapped_func


def ddp_barrier():
    """Synchronize all processes, unless inside a main-process-only section."""
    if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized():
        return

    if torch.distributed.get_backend() == torch.distributed.Backend.NCCL:
        # The NCCL backend needs the current device passed explicitly.
        torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
    else:
        torch.distributed.barrier()


class Resample(torch.nn.Module):
    """Wraps torchaudio's Resample to accept [batch, time] or
    [batch, time, channels] waveforms."""

    def __init__(self, orig_freq=16000, new_freq=16000, *args, **kwargs):
        super().__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=orig_freq, new_freq=new_freq, *args, **kwargs
        )

    def forward(self, waveforms):
        if self.orig_freq == self.new_freq:
            return waveforms

        # torchaudio resamples along the last dimension, so move time there.
        unsqueezed = False
        if waveforms.dim() == 2:
            # [batch, time] -> [batch, 1, time]
            waveforms = waveforms.unsqueeze(1)
            unsqueezed = True
        elif waveforms.dim() == 3:
            # [batch, time, channels] -> [batch, channels, time]
            waveforms = waveforms.transpose(1, 2)
        else:
            raise ValueError("Input waveforms must have 2 or 3 dimensions.")

        self.resampler.to(waveforms.device)
        resampled_waveform = self.resampler(waveforms)

        # Restore the original layout.
        if unsqueezed:
            return resampled_waveform.squeeze(1)
        return resampled_waveform.transpose(1, 2)


class AudioNormalizer:
    """Normalizes audio to the expected sample rate and channel layout."""

    def __init__(self, sample_rate=16000, mix="avg-to-mono"):
        self.sample_rate = sample_rate
        if mix not in ["avg-to-mono", "keep"]:
            raise ValueError(f"Unexpected mixing configuration: {mix}")
        self.mix = mix
        self._cached_resamplers = {}

    def __call__(self, audio, sample_rate):
        # Cache one resampler per encountered input sample rate.
        if sample_rate not in self._cached_resamplers:
            self._cached_resamplers[sample_rate] = Resample(sample_rate, self.sample_rate)
        # Resample expects a batch dimension; add and remove it around the call.
        resampled = self._cached_resamplers[sample_rate](audio.unsqueeze(0)).squeeze(0)
        return self._mix(resampled)

    def _mix(self, audio):
        flat_input = audio.dim() == 1
        if self.mix == "avg-to-mono":
            if flat_input:
                return audio
            # Average the channel dimension of a [time, channels] tensor.
            return torch.mean(audio, 1)
        if self.mix == "keep":
            return audio


class Pretrained(torch.nn.Module):
    """Base class that loads trained modules and runs inference on new data."""

    HPARAMS_NEEDED = []
    MODULES_NEEDED = []

    def __init__(self, modules=None, hparams=None, run_opts=None, freeze_params=True):
        super().__init__()

        # run_opts take precedence over hparams, which override the defaults.
        run_opt_defaults = {
            "device": "cpu",
            "data_parallel_count": -1,
            "data_parallel_backend": False,
            "distributed_launch": False,
            "distributed_backend": "nccl",
            "jit": False,
            "jit_module_keys": None,
            "compile": False,
            "compile_module_keys": None,
            "compile_mode": "reduce-overhead",
            "compile_using_fullgraph": False,
            "compile_using_dynamic_shape_tracing": False,
        }
        for arg, default in run_opt_defaults.items():
            if run_opts is not None and arg in run_opts:
                setattr(self, arg, run_opts[arg])
            elif hparams is not None and arg in hparams:
                setattr(self, arg, hparams[arg])
            else:
                setattr(self, arg, default)

        self.mods = torch.nn.ModuleDict(modules)
        for module in self.mods.values():
            if module is not None:
                module.to(self.device)

        if self.HPARAMS_NEEDED and hparams is None:
            raise ValueError("Need to provide the hparams dict.")
        if hparams is not None:
            for hp in self.HPARAMS_NEEDED:
                if hp not in hparams:
                    raise ValueError(f"Need hparams['{hp}']")
            self.hparams = SimpleNamespace(**hparams)

        self._prepare_modules(freeze_params)
        # Guard against hparams being None before looking up the normalizer.
        self.audio_normalizer = (hparams or {}).get("audio_normalizer", AudioNormalizer())

    def _prepare_modules(self, freeze_params):
        self._compile()
        self._wrap_distributed()

        if freeze_params:
            self.mods.eval()
            for p in self.mods.parameters():
                p.requires_grad = False

    def _compile(self):
        """Compile requested modules with torch.compile and/or TorchScript."""
        compile_available = hasattr(torch, "compile")
        if not compile_available and self.compile_module_keys is not None:
            raise ValueError(
                "'compile_module_keys' is set, but this version of PyTorch "
                "does not support torch.compile."
            )

        compile_module_keys = set()
        if self.compile:
            compile_module_keys = (
                set(self.mods)
                if self.compile_module_keys is None
                else set(self.compile_module_keys)
            )

        jit_module_keys = set()
        if self.jit:
            jit_module_keys = (
                set(self.mods)
                if self.jit_module_keys is None
                else set(self.jit_module_keys)
            )

        for name in compile_module_keys | jit_module_keys:
            if name not in self.mods:
                raise ValueError(f"module {name} is not defined in your hparams file.")

        for name in compile_module_keys:
            try:
                module = torch.compile(
                    self.mods[name],
                    mode=self.compile_mode,
                    fullgraph=self.compile_using_fullgraph,
                    dynamic=self.compile_using_dynamic_shape_tracing,
                )
            except Exception:
                # Leave the module unchanged; it may still be JIT-compiled below.
                continue

            self.mods[name] = module.to(self.device)
            # A module compiled successfully is not JIT-compiled again.
            jit_module_keys.discard(name)

        for name in jit_module_keys:
            module = torch.jit.script(self.mods[name])
            self.mods[name] = module.to(self.device)

    def _compile_jit(self):
        """Backwards-compatible alias for ``_compile``."""
        self._compile()

    def _wrap_distributed(self):
        """Wrap trainable modules with DDP or DP, depending on the run options."""
        if not self.distributed_launch and not self.data_parallel_backend:
            return
        elif self.distributed_launch:
            for name, module in self.mods.items():
                if any(p.requires_grad for p in module.parameters()):
                    # Sync BatchNorm statistics across processes before wrapping.
                    module = SyncBatchNorm.convert_sync_batchnorm(module)
                    self.mods[name] = DDP(module, device_ids=[self.device])
        else:
            for name, module in self.mods.items():
                if any(p.requires_grad for p in module.parameters()):
                    if self.data_parallel_count == -1:
                        self.mods[name] = DP(module)
                    else:
                        self.mods[name] = DP(module, list(range(self.data_parallel_count)))

    @classmethod
    def from_hparams(
        cls,
        source,
        hparams_file="hyperparams.yaml",
        overrides={},
        download_only=False,
        overrides_must_match=True,
        **kwargs
    ):
        """Load the hyperparameters file from ``source`` and construct an
        instance (unless ``download_only`` is True)."""
        with open(fetch(filename=hparams_file, source=source)) as fin:
            hparams = load_hyperpyyaml(fin, overrides, overrides_must_match=overrides_must_match)

        # Collect pretrained parameter files on the main process only.
        pretrainer = hparams.get("pretrainer", None)
        if pretrainer is not None:
            run_on_main(pretrainer.collect_files, kwargs={"default_source": source})
            if not download_only:
                pretrainer.load_collected()
                return cls(hparams["modules"], hparams, **kwargs)
        else:
            return cls(hparams["modules"], hparams, **kwargs)


class EncoderClassifier(Pretrained):
    """Computes embeddings for waveforms and classifies them."""

    MODULES_NEEDED = ["compute_features", "mean_var_norm", "embedding_model", "classifier"]

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        # Treat a single [time] waveform as a batch of one.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)
        # Assume full-length signals when relative lengths are not given.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)

        # Features -> input normalization -> embeddings.
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)

        if normalize:
            embeddings = self.hparams.mean_var_norm_emb(
                embeddings, torch.ones(embeddings.shape[0], device=self.device)
            )
        return embeddings

    def classify_batch(self, wavs, wav_lens=None):
        """Return (posterior probabilities, best score, best index, decoded labels)."""
        out_prob = self.mods.classifier(self.encode_batch(wavs, wav_lens)).squeeze(1)
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def forward(self, wavs, wav_lens=None):
        return self.classify_batch(wavs, wav_lens)
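

# A minimal usage sketch, not part of the module proper. It assumes a local
# directory "pretrained_model/" whose hyperparams.yaml defines "modules"
# (compute_features, mean_var_norm, embedding_model, classifier) and a
# "label_encoder"; the directory name and "example.wav" are hypothetical
# placeholders.
if __name__ == "__main__":
    classifier = EncoderClassifier.from_hparams(source="pretrained_model")

    # torchaudio.load returns a [channels, time] tensor and its sample rate.
    signal, sample_rate = torchaudio.load("example.wav")

    # Convert to [time, channels], resample, and mix down to mono [time].
    audio = classifier.audio_normalizer(signal.t(), sample_rate)

    # classify_batch accepts a single [time] waveform and batches it itself.
    out_prob, score, index, text_lab = classifier.classify_batch(audio)
    print(text_lab, score)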