|
"""Audio processer for talking data. |
|
Author: linzhihui.lzh |
|
Date: 2024-12-12 |
|
""" |
|
import os
import os.path as osp
import sys
import math
from typing import List, Dict, Tuple, Optional, Union, Any

import yaml
from omegaconf import OmegaConf

import librosa
import numpy as np
from einops import rearrange

import torch
import torch.nn.functional as F

from pydub import AudioSegment
# `Separator` is used by `set_audio_separator` below but was never imported; it is
# assumed to come from the `audio-separator` package.
from audio_separator.separator import Separator

sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))
from src.utils.rprint import rlog as log
from src.utils.util import resample_audio

from src.models.audio.wav2vec_modified import Wav2VecModel
from src.models.audio.hubert import HubertModel
|
|
|
|
|
def pad_audio(audio, audio_unit=320, pad_threshold=80):
    """Symmetrically pad a (batch, samples) waveform so its length reaches at least
    `audio_unit * n_units + pad_threshold`.

    Each side receives `side_len` samples in total: two reflect pads of
    `side_len // 2` each, plus a one-sample replicate pad when `side_len` is odd.
    """
    batch_size, audio_len = audio.shape
    n_units = audio_len // audio_unit
    side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
    if side_len >= 0:
        reflect_len = side_len // 2
        replicate_len = side_len % 2
        if reflect_len > 0:
            # The two reflect pads are intentional: together they add
            # 2 * reflect_len samples per side.
            audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
            audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
        if replicate_len > 0:
            audio = F.pad(audio, (1, 1), mode='replicate')

    return audio
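

# A quick sanity sketch for `pad_audio` (with the defaults audio_unit=320,
# pad_threshold=80): a 16000-sample clip has n_units=50, so
# side_len = ceil((320*50 + 80 - 16000) / 2) = 40 samples are added per side.
# >>> pad_audio(torch.randn(1, 16000)).shape
# torch.Size([1, 16080])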
|
|
|
|
|
def cut_audio(audio_path: str, save_dir: str, length=60) -> List[str]:
    """Cut an audio file into sub-clips and return their file paths. Supports WAV format.

    Args:
        audio_path (str): the source audio file path
        save_dir (str): the directory where the sub-clips are saved
        length (int, optional): the maximum length of each sub-clip in seconds. Defaults to 60.

    Returns:
        List[str]: the sub-clip file paths
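
    Example (hypothetical paths; a minimal usage sketch):
        >>> paths = cut_audio("demo.wav", "./tmp_segments", length=60)
        >>> paths[0]
        './tmp_segments/demo_segment_1.wav'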
|
""" |
|
    audio_name = osp.splitext(osp.basename(audio_path))[0]
    audio = AudioSegment.from_wav(audio_path)
    segment_length = length * 1000  # pydub positions are in milliseconds
    num_segments = math.ceil(len(audio) / segment_length)

    os.makedirs(save_dir, exist_ok=True)
    audio_list = []

    for i in range(num_segments):
        start_time = i * segment_length
        end_time = min((i + 1) * segment_length, len(audio))
        segment = audio[start_time:end_time]

        path = osp.join(save_dir, f"{audio_name}_segment_{i+1}.wav")
        audio_list.append(path)
        segment.export(path, format="wav")
    return audio_list
|
|
|
|
|
class AudioProcessor(object): |
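    """Extracts frame-aligned audio embeddings (wav2vec or HuBERT) for talking data.

    A minimal usage sketch (the config path below is a hypothetical placeholder):

        processor = AudioProcessor("configs/audio_processor.yaml")
        audio_emb, n_frames = processor.preprocess("demo.wav")
    """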
|
    def __init__(self, cfg_path: str, is_training: bool = False) -> None:
        cfg = OmegaConf.load(cfg_path)
        self.cfg = cfg
        self.is_training = is_training
        log("========================================= Audio Processor =========================================")
        log(OmegaConf.to_yaml(cfg))

        # Device selection: CPU can be forced; otherwise prefer MPS, then CUDA.
        self.device_id = cfg.device_params.device_id
        self.use_half = cfg.device_params.flag_use_half_precision
        if cfg.device_params.flag_force_cpu:
            self.device = 'cpu'
        else:
            try:
                if torch.backends.mps.is_available():
                    self.device = 'mps'
                else:
                    self.device = 'cuda:' + str(self.device_id)
            except Exception:
                # Older torch builds without the MPS backend.
                self.device = 'cuda:' + str(self.device_id)

        # Optional vocal separator.
        self.audio_separator = None
        self.cache_dir = cfg.cache_dir
        self.tmp_dir = cfg.tmp_dir
        self.use_audio_separator = cfg.model_params.use_audio_separator
        self.audio_separator_name = cfg.model_params.audio_separator_name
        self.audio_separator_path = cfg.model_weights.audio_separator_path
        self.set_audio_separator(cfg.cache_dir)

        # Audio encoder (wav2vec or hubert).
        self.model_name = cfg.model_params.model_name
        self.is_chinese = cfg.model_params.is_chinese
        self.audio_encoder = self.load_model(
            model_name=cfg.model_params.model_name,
            model_type=cfg.model_params.model_type,
            is_chinese=cfg.model_params.is_chinese,
        )
        self.only_last_features = cfg.model_params.only_last_features
        if cfg.model_params.only_last_features:
            self.feature_shape = (1, 768)
        else:
            self.feature_shape = (12, 768)

        # Data parameters. `audio_unit` is the number of audio samples per video frame.
        self.sample_strategy = cfg.data_params.sample_strategy
        self.sample_rate = cfg.data_params.sample_rate
        self.fps = cfg.data_params.fps
        self.audio_unit = cfg.data_params.sample_rate / cfg.data_params.fps
        self.max_length = cfg.data_params.max_length
        self.subclip_len = cfg.data_params.sub_clip_length
        self.save_to_cpu = cfg.data_params.save_to_cpu
        self.pad_mode = cfg.data_params.audio_pad_mode

        log("========================================= Audio Processor: Done =========================================")
|
|
|
    def load_model(self, model_name: str = "wav2vec", model_type: str = "base", is_chinese: bool = False):
        assert model_name in ["wav2vec", "hubert"], f"Unknown audio model {model_name}, only wav2vec and hubert are supported"
        assert model_type in ["base", "large"], f"Unknown audio model type {model_type}, only base and large are supported"

        # Resolve the weight path from the config by model family, language, and size.
        if model_name == "wav2vec":
            weight_paths = self.cfg.model_weights.wav2vec_path
        else:
            weight_paths = self.cfg.model_weights.hubert_path
        weight_paths = weight_paths.chinese if is_chinese else weight_paths.default
        model_weight_path = weight_paths.base if model_type == "base" else weight_paths.large
        if model_weight_path is None:
            raise ValueError(f"No weight path configured for {model_name}-{model_type} (is_chinese={is_chinese})")

        if model_name == "wav2vec":
            audio_encoder = Wav2VecModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
        else:
            audio_encoder = HubertModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)

        log(f"{model_name}-{model_type}-chinese-{is_chinese} model has been loaded from {model_weight_path}")
        total_params = sum(p.numel() for p in audio_encoder.parameters())
        log('Number of parameters: %.4fM' % (total_params / 1e6))

        # The CNN feature extractor is always frozen; for the modified (non-original)
        # encoders, the feature projection and the first two transformer layers are
        # frozen as well.
        audio_encoder.feature_extractor._freeze_parameters()
        if not self.cfg.model_params.is_original:
            frozen_layers = [0, 1]
            for name, param in audio_encoder.named_parameters():
                if name.startswith("feature_projection"):
                    param.requires_grad = False
                if name.startswith("encoder.layers"):
                    layer = int(name.split(".")[2])
                    if layer in frozen_layers:
                        param.requires_grad = False

        audio_encoder = audio_encoder.to(self.device)
        if self.use_half:
            audio_encoder = audio_encoder.half()
        audio_encoder.eval()
        return audio_encoder
|
|
|
    def set_audio_separator(self, output_dir: str) -> None:
        del self.audio_separator

        if self.audio_separator_name is not None and self.use_audio_separator:
            try:
                os.makedirs(output_dir, exist_ok=True)
            except OSError:
                log("Failed to create the output cache dir.")
            self.audio_separator = Separator(
                output_dir=output_dir,
                output_single_stem="vocals",
                model_file_dir=self.audio_separator_path,
            )
            self.audio_separator.load_model(self.audio_separator_name)
            assert self.audio_separator.model_instance is not None, "Failed to load the audio separation model."
        else:
            self.audio_separator = None
            log("Using audio directly without vocal separation.")
|
|
|
    def seperate_audio(self, audio_path: str, output_dir: Union[str, None] = None) -> str:
        # Rebuild the separator if a different output directory is requested.
        if output_dir is not None and output_dir != self.cache_dir:
            self.set_audio_separator(output_dir)

        if self.audio_separator is not None:
            try:
                outputs = self.audio_separator.separate(audio_path)
                if len(outputs) <= 0:
                    raise RuntimeError("Audio separation failed.")

                # Resample the separated vocal track to the target sample rate.
                vocal_audio_file = outputs[0]
                vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
                vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
                vocal_audio_file = resample_audio(
                    vocal_audio_file,
                    os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"),
                    self.sample_rate,
                )
            except Exception as e:
                log(f"Failed to separate vocals from {audio_path}, error info [{e}]")
                vocal_audio_file = audio_path
        else:
            vocal_audio_file = audio_path

        return vocal_audio_file
|
|
|
    def load_audio(self, audio_path: str, mono: bool = True, duration: Optional[float] = None) -> Tuple[np.ndarray, int]:
        try:
            audio_data, sampling_rate = librosa.load(audio_path, sr=self.sample_rate, mono=mono, duration=duration)
        except Exception as e:
            raise RuntimeError(f"Failed to load audio from {audio_path}, error info [{e}]")
        return audio_data, sampling_rate
|
|
|
    def prepare_audio_data(self, audio_data: Union[np.ndarray, torch.Tensor], n_frames: Optional[int] = None) -> Tuple[List[Any], int]:
        """Normalize the audio, pad it to a whole number of frames, and split it into sub-clips."""
        clip_len = int(len(audio_data) / self.audio_unit)
        if n_frames is not None:
            if abs(n_frames - clip_len) > 2:
                log(f"The number of frames must be close to the clip length (in frame units), got {n_frames} and {clip_len}")
                return [], n_frames
            clip_len = n_frames
        else:
            n_frames = clip_len

        # Normalize to zero mean and unit variance.
        if isinstance(audio_data, np.ndarray):
            audio_data = torch.from_numpy(audio_data).to(self.device)
        assert audio_data.ndim == 1, 'Audio must be a 1D tensor.'
        audio_data = (audio_data - torch.mean(audio_data)) / (torch.std(audio_data) + 1e-7)

        # Pad the tail so the audio covers exactly `clip_len` frames.
        n_audio_samples = round(self.audio_unit * clip_len)
        n_padding_audio_samples = n_audio_samples - len(audio_data)
        if n_padding_audio_samples > 0:
            if self.pad_mode == 'zero':
                padding_value = 0
            elif self.pad_mode == 'replicate':
                padding_value = float(audio_data[-1])
            else:
                raise ValueError(f'Unknown pad mode: {self.pad_mode}')
            audio_data = F.pad(audio_data, (0, n_padding_audio_samples), value=padding_value)

        # Split into sub-clips of at most `subclip_len` frames each.
        audio_segments = []
        if clip_len <= self.subclip_len:
            n_subdivision = 1
            subclip_len = clip_len
        else:
            n_subdivision = math.ceil(clip_len / self.subclip_len)
            subclip_len = self.subclip_len

        for i in range(n_subdivision):
            start_idx = i * subclip_len
            end_idx = min(start_idx + subclip_len, clip_len)
            audio_segments.append(
                {
                    "data": audio_data[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0),
                    "start_idx": start_idx,
                    "end_idx": end_idx,
                    "length": end_idx - start_idx,
                }
            )
        return audio_segments, n_frames
|
|
|
    def get_audio_embedding(self, audio, clip_len: int) -> torch.Tensor:
        if audio.ndim == 2:
            # Raw waveform of shape (1, n_samples): clip_len frames at self.fps
            # correspond to clip_len * sample_rate / fps samples (the original
            # hard-coded 16000 as the sample rate here).
            assert audio.shape[1] == self.sample_rate * clip_len / self.fps, \
                f'Incorrect audio length {audio.shape[1]}'

            if self.use_half:
                audio = audio.half()
            embeddings = self.audio_encoder(
                pad_audio(audio), seq_len=clip_len, sample_strategy=self.sample_strategy, output_hidden_states=True
            )
            assert len(embeddings) > 0, "Failed to extract audio embedding"

            if self.only_last_features:
                audio_emb = embeddings.last_hidden_state.squeeze(0)
            else:
                # Stack all 12 transformer hidden states, then move the sequence axis
                # first: (clip_len, 12, 768).
                audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
                audio_emb = rearrange(audio_emb, "b s d -> s b d")
        elif audio.ndim == 3:
            # Already-extracted features; pass through unchanged.
            assert audio.shape[1] == clip_len, f'Incorrect audio feature length {audio.shape[1]}'
            audio_emb = audio
        else:
            raise ValueError(f'Incorrect audio input shape {audio.shape}')

        return audio_emb
|
|
|
    def get_audio_embeddings(self, audio_segments: List[Any]) -> Optional[torch.Tensor]:
        audio_embs = []
        for audio_segment in audio_segments:
            if self.is_training:
                audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
            else:
                # Inference mode: no gradients are needed.
                with torch.no_grad():
                    audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])

            audio_emb = audio_emb.cpu() if self.save_to_cpu else audio_emb
            audio_embs.append(audio_emb)

        if len(audio_embs) == 0:
            return None

        # Concatenate the sub-clip embeddings along the frame axis.
        audio_emb = torch.cat(audio_embs, dim=0)
        return audio_emb
|
|
|
    def preprocess(
        self,
        audio_path: str,
        n_frames: Optional[int] = None,
        duration: Optional[float] = None,
        need_seperate: bool = False
    ):
        """Preprocess a WAV audio file: optionally separate the vocals from the background,
        resample to a 16 kHz sample rate, and encode the vocal track into wav2vec2/HuBERT
        features for further processing or analysis.
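
        Example (hypothetical path; a minimal sketch):
            >>> emb, n_frames = processor.preprocess("demo.wav")
            >>> # emb has shape (n_frames, 12, 768), or (n_frames, 768) when
            >>> # `only_last_features` is enabled.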
|
""" |
|
        if need_seperate:
            vocal_audio_file = self.seperate_audio(audio_path)
        else:
            vocal_audio_file = audio_path

        audio_data, sampling_rate = self.load_audio(vocal_audio_file, duration=duration)

        assert sampling_rate == 16000, "The sample rate of the audio must be 16000"
        audio_segments, n_frames = self.prepare_audio_data(audio_data, n_frames)
        audio_emb = self.get_audio_embeddings(audio_segments)
        if audio_emb is None:
            log(f"{audio_path} has been processed, but produced no audio embedding; returning 'None'.")

        return audio_emb, n_frames
|
|
|
    def preprocess_long(
        self,
        audio_path: str,
        need_seperate: bool = False
    ):
        """Process a long audio file by cutting it into `max_length`-second sub-clips
        and concatenating the per-clip embeddings.
        """
        audio_list = cut_audio(audio_path, self.tmp_dir, length=self.max_length)
        audio_emb_list = []
        total_frames = 0

        for idx, audio_path in enumerate(audio_list):
            emb, length = self.preprocess(audio_path, need_seperate=need_seperate)
            audio_emb_list.append(emb)
            log(f"Processing audio {idx+1}/{len(audio_list)}, path: {audio_path}, length: {length}")
            total_frames += length

        audio_emb = torch.cat(audio_emb_list)
        audio_length = total_frames

        # Clean up the temporary sub-clips.
        for audio_path in audio_list:
            os.remove(audio_path)

        return audio_emb, audio_length
|
|
|
    def __enter__(self):
        return self
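
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Completes the context-manager protocol started by __enter__ above
        # (not in the original module). Nothing needs explicit cleanup;
        # returning None propagates any exception.
        return None


# A minimal usage sketch (not part of the original module; the config and audio
# paths below are hypothetical placeholders):
if __name__ == "__main__":
    with AudioProcessor("configs/audio_processor.yaml") as processor:
        emb, n_frames = processor.preprocess("examples/demo.wav")
        if emb is not None:
            log(f"audio embedding shape: {tuple(emb.shape)}, frames: {n_frames}")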
|
|
|
|
|
|