"""Audio processer for talking data. | |
Author: linzhihui.lzh | |
Date: 2024-12-12 | |
""" | |
import os | |
from re import A | |
import sys | |
import os.path as osp | |
from typing import List, Dict, Tuple, Optional, Union, Any | |
import yaml | |
from omegaconf import OmegaConf | |
import math | |
import librosa | |
import numpy as np | |
from einops import rearrange | |
import torch | |
import torch.nn.functional as F | |
from pydub import AudioSegment | |
# from audio_separator.separator import Separator | |
sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))))) | |
from src.utils.rprint import rlog as log | |
from src.utils.util import resample_audio | |
from src.models.audio.wav2vec_modified import Wav2VecModel | |
from src.models.audio.hubert import HubertModel | |
def pad_audio(audio, audio_unit=320, pad_threshold=80):
    """Pad a batched waveform so its length lines up with whole `audio_unit` chunks."""
    batch_size, audio_len = audio.shape
    n_units = audio_len // audio_unit
    # padding needed per side so the length reaches `audio_unit * n_units + pad_threshold`
    side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
    if side_len >= 0:
        reflect_len = side_len // 2
        replicate_len = side_len % 2
        if reflect_len > 0:
            # two reflect pads of `reflect_len` add up to ~`side_len` samples per side
            audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
            audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
        if replicate_len > 0:
            audio = F.pad(audio, (1, 1), mode='replicate')
    return audio
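
# A minimal usage sketch (a random waveform stands in for real audio; values are illustrative only).
# An input that overshoots a whole number of 320-sample units is padded up to
# `audio_unit * n_units + pad_threshold` samples:
#
#     dummy = torch.randn(1, 320 * 50 + 70)   # 16070 samples, 70 past a clean 50-unit boundary
#     padded = pad_audio(dummy)               # padded.shape[1] == 320 * 50 + 80 == 16080
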
def cut_audio(audio_path: str, save_dir: str, length=60) -> List[str]:
    """Cut audio into sub-divisions and return the subfile paths. Supports the wav format.

    Args:
        audio_path (str): the source audio file path
        save_dir (str): the save directory of the sub-divisions
        length (int, optional): the max length of each sub-division. Defaults to 60 secs.

    Returns:
        List[str]: the subfile paths
    """
    audio_name = osp.basename(audio_path).split('.')[0]
    audio = AudioSegment.from_wav(audio_path)
    segment_length = length * 1000.  # pydub uses milliseconds
    num_segments = math.ceil(len(audio) / segment_length)
    os.makedirs(save_dir, exist_ok=True)
    audio_list = []
    for i in range(num_segments):
        start_time = i * segment_length
        end_time = min((i + 1) * segment_length, len(audio))
        segment = audio[start_time:end_time]
        path = osp.join(save_dir, f"{audio_name}_segment_{i+1}.wav")
        audio_list.append(path)
        segment.export(path, format="wav")
    return audio_list
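
# A minimal usage sketch (the paths below are assumptions for illustration only):
#
#     segments = cut_audio("data/speech.wav", "tmp/segments", length=60)
#     # e.g. ["tmp/segments/speech_segment_1.wav", "tmp/segments/speech_segment_2.wav", ...]
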
class AudioProcessor(object):
    def __init__(self, cfg_path: str, is_training: bool = False) -> None:
        cfg = OmegaConf.load(cfg_path)
        self.cfg = cfg
        self.is_training = is_training
        log("========================================= Audio Processor =========================================")
        log(OmegaConf.to_yaml(cfg))
        # set the device
        self.device_id = cfg.device_params.device_id
        self.use_half = cfg.device_params.flag_use_half_precision
        if cfg.device_params.flag_force_cpu:
            self.device = 'cpu'
        else:
            try:
                if torch.backends.mps.is_available():
                    self.device = 'mps'
                else:
                    self.device = 'cuda:' + str(self.device_id)
            except Exception:
                self.device = 'cuda:' + str(self.device_id)
        # init the audio separator
        self.audio_separator = None
        self.cache_dir = cfg.cache_dir
        self.tmp_dir = cfg.tmp_dir
        self.use_audio_separator = cfg.model_params.use_audio_separator
        self.audio_separator_name = cfg.model_params.audio_separator_name
        self.audio_separator_path = cfg.model_weights.audio_separator_path
        self.set_audio_separator(cfg.cache_dir)
        # load the audio encoder, wav2vec or hubert
        self.model_name = cfg.model_params.model_name
        self.is_chinese = cfg.model_params.is_chinese
        self.audio_encoder = self.load_model(
            model_name=cfg.model_params.model_name,
            model_type=cfg.model_params.model_type,
            is_chinese=cfg.model_params.is_chinese,
        )
        self.only_last_features = cfg.model_params.only_last_features
        if cfg.model_params.only_last_features:
            self.feature_shape = (1, 768)
        else:
            self.feature_shape = (12, 768)  # features of 12 transformer blocks
        # init the data params
        self.sample_strategy = cfg.data_params.sample_strategy
        self.sample_rate = cfg.data_params.sample_rate
        self.fps = cfg.data_params.fps
        self.audio_unit = cfg.data_params.sample_rate / cfg.data_params.fps  # number of audio samples per video frame
        self.max_length = cfg.data_params.max_length
        self.subclip_len = cfg.data_params.sub_clip_length
        self.save_to_cpu = cfg.data_params.save_to_cpu
        self.pad_mode = cfg.data_params.audio_pad_mode
        log("========================================= Audio Processor: Done =========================================")
    def load_model(self, model_name: str = "wav2vec", model_type: str = "base", is_chinese: bool = False):
        assert model_name in ["wav2vec", "hubert"], f"Unknown audio model {model_name}, only wav2vec and hubert are supported"
        assert model_type in ["base", "large"], f"Unknown audio model type {model_type}, only base and large are supported"
        if model_name == "wav2vec":
            # load the wav2vec model weights
            if is_chinese:
                if model_type == "base":
                    model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.base
                else:
                    model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.large
            else:
                if model_type == "base":
                    model_weight_path = self.cfg.model_weights.wav2vec_path.default.base
                else:
                    model_weight_path = self.cfg.model_weights.wav2vec_path.default.large
            if model_weight_path is None:
                raise ValueError("model_weight_path is None")
            audio_encoder = Wav2VecModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
        else:
            # load the hubert model weights
            if is_chinese:
                if model_type == "base":
                    model_weight_path = self.cfg.model_weights.hubert_path.chinese.base
                else:
                    model_weight_path = self.cfg.model_weights.hubert_path.chinese.large
            else:
                if model_type == "base":
                    model_weight_path = self.cfg.model_weights.hubert_path.default.base
                else:
                    model_weight_path = self.cfg.model_weights.hubert_path.default.large
            if model_weight_path is None:
                raise ValueError("model_weight_path is None")
            audio_encoder = HubertModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
        log(f"{model_name}-{model_type}-chinese-{is_chinese} model has been loaded from {model_weight_path}")
        total_params = sum(p.numel() for p in audio_encoder.parameters())
        print('Number of parameters: %.4fM' % (total_params / 1e6))
        # freeze the CNN feature extractor; optionally also freeze the feature projection and the first transformer layers
        audio_encoder.feature_extractor._freeze_parameters()
        if not self.cfg.model_params.is_original:
            frozen_layers = [0, 1]
            for name, param in audio_encoder.named_parameters():
                if name.startswith("feature_projection"):
                    param.requires_grad = False
                if name.startswith("encoder.layers"):
                    layer = int(name.split(".")[2])
                    if layer in frozen_layers:
                        param.requires_grad = False
        audio_encoder = audio_encoder.to(self.device)
        if self.use_half:
            audio_encoder = audio_encoder.half()
        audio_encoder.eval()
        return audio_encoder
    def set_audio_separator(self, output_dir: str) -> None:
        del self.audio_separator
        if self.audio_separator_name is not None and self.use_audio_separator:
            if Separator is None:
                raise ImportError("audio_separator is not installed but use_audio_separator is enabled.")
            try:
                os.makedirs(output_dir, exist_ok=True)
            except OSError as _:
                print("Failed to create the output cache dir.")
            self.audio_separator = Separator(
                output_dir=output_dir,
                output_single_stem="vocals",
                model_file_dir=self.audio_separator_path,
            )
            self.audio_separator.load_model(self.audio_separator_name)
            assert self.audio_separator.model_instance is not None, "Failed to load the audio separator model."
        else:
            self.audio_separator = None
            log("Use the audio directly without the vocal separator.")
    def seperate_audio(self, audio_path: str, output_dir: Union[str, None] = None) -> str:
        if output_dir is not None:
            if output_dir != self.cache_dir:
                # reload the audio separator with the new output directory
                self.set_audio_separator(output_dir)
        if self.audio_separator is not None:
            # 1. separate the vocals
            # TODO: process in memory
            try:
                outputs = self.audio_separator.separate(audio_path)
                if len(outputs) <= 0:
                    raise RuntimeError("Audio separation failed.")
                vocal_audio_file = outputs[0]
                vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
                vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
                vocal_audio_file = resample_audio(
                    vocal_audio_file,
                    os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"),
                    self.sample_rate,
                )
            except Exception as e:
                log(f"Failed to separate vocals from {audio_path}, error info [{e}]")
                vocal_audio_file = audio_path
        else:
            vocal_audio_file = audio_path
        return vocal_audio_file
    def load_audio(self, audio_path: str, mono: bool = True, duration: Optional[float] = None) -> Any:
        try:
            audio_data, sampling_rate = librosa.load(audio_path, sr=self.sample_rate, mono=mono, duration=duration)
        except Exception as e:
            raise RuntimeError(f"Failed to load audio from {audio_path}, error info [{e}]")
        return audio_data, sampling_rate
    def prepare_audio_data(self, audio_data: Union[np.ndarray, torch.Tensor], n_frames: Optional[int] = None) -> Tuple[List[Any], int]:
        """Prepare audio data for processing: normalize, pad to the clip length, and split into sub-clips."""
        clip_len = int(len(audio_data) / self.audio_unit)
        if n_frames is not None:
            if abs(n_frames - clip_len) > 2:
                log(f"The number of frames must be close to the clip length (in 80ms units), got {n_frames} and {clip_len}")
                return [], n_frames
            clip_len = n_frames
        else:
            n_frames = clip_len
        # normalize the audio, replacing Wav2Vec2FeatureExtractor
        if isinstance(audio_data, np.ndarray):
            audio_data = torch.from_numpy(audio_data).to(self.device)
        assert audio_data.ndim == 1, 'Audio must be a 1D tensor.'
        audio_data = (audio_data - torch.mean(audio_data)) / (torch.std(audio_data) + 1e-7)
        # log(f"audio loaded! {audio_data.shape}")
        # pad the audio to fit the clip length
        n_audio_samples = round(self.audio_unit * clip_len)
        n_padding_audio_samples = n_audio_samples - len(audio_data)
        n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
        if n_padding_audio_samples > 0:
            if self.pad_mode == 'zero':
                padding_value = 0
            elif self.pad_mode == 'replicate':
                padding_value = float(audio_data[-1])
            else:
                raise ValueError(f'Unknown pad mode: {self.pad_mode}')
            audio_data = F.pad(audio_data, (0, n_padding_audio_samples), value=padding_value)
        # divide the audio into sub-divisions to save GPU memory
        audio_segments = []
        if clip_len <= self.subclip_len:
            n_subdivision = 1
            subclip_len = clip_len
        else:
            n_subdivision = math.ceil(clip_len / self.subclip_len)
            subclip_len = self.subclip_len
        for i in range(0, n_subdivision):
            start_idx = i * subclip_len
            end_idx = min(start_idx + subclip_len, clip_len)
            # debug
            # log(f"[{i+1}/{n_subdivision}] data index [{round(start_idx * self.audio_unit)}, {round(end_idx * self.audio_unit)})")
            audio_segments.append(
                {
                    "data": audio_data[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0),
                    "start_idx": start_idx,
                    "end_idx": end_idx,
                    "length": end_idx - start_idx
                }
            )
        return audio_segments, n_frames
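
    # A worked example of the bookkeeping above (the 16 kHz / 25 fps / sub_clip_length = 100 numbers
    # are assumptions used only for illustration; the real values come from data_params):
    #
    #     audio_unit = 16000 / 25 = 640 samples per video frame
    #     10 s of audio -> clip_len = int(160000 / 640) = 250 frames
    #     the clip is split into ceil(250 / 100) = 3 segments
    #     covering frame ranges [0, 100), [100, 200), [200, 250)
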
    def get_audio_embedding(self, audio, clip_len: int) -> torch.Tensor:
        if audio.ndim == 2:
            # raw waveform input: check the length, then extract the audio features
            assert audio.shape[1] == 16000 * clip_len / self.fps, \
                f'Incorrect audio length {audio.shape[1]}'
            if self.use_half:
                audio = audio.half()
            embeddings = self.audio_encoder(
                pad_audio(audio), seq_len=clip_len, sample_strategy=self.sample_strategy, output_hidden_states=True
            )  # (N, L, 768)
            assert len(embeddings) > 0, "Failed to extract the audio embedding"
            if self.only_last_features:
                audio_emb = embeddings.last_hidden_state.squeeze(0)
            else:
                audio_emb = torch.stack(
                    embeddings.hidden_states[1:], dim=1
                ).squeeze(0)
                audio_emb = rearrange(audio_emb, "b s d -> s b d")
        elif audio.ndim == 3:
            # pre-computed features are passed through unchanged
            assert audio.shape[1] == clip_len, f'Incorrect audio feature length {audio.shape[1]}'
            audio_emb = audio
        else:
            raise ValueError(f'Incorrect audio input shape {audio.shape}')
        return audio_emb
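
    # Shape walkthrough for the waveform branch above (a sketch; clip_len = 100 is assumed and the
    # encoder is assumed to return hidden states of shape (1, clip_len, 768)). With only_last_features
    # the result is (100, 768); otherwise the 12 block outputs are stacked and rearranged into
    # (100, 12, 768), i.e. (n_frames, n_blocks, feature_dim).
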
    def get_audio_embeddings(self, audio_segments: List[Any]) -> Optional[torch.Tensor]:
        audio_embs = []
        for audio_segment in audio_segments:
            if self.is_training:
                audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
            else:
                with torch.no_grad():
                    audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
            audio_emb = audio_emb.cpu() if self.save_to_cpu else audio_emb
            audio_embs.append(audio_emb)
            # log(f"audio segment [{audio_segment['start_idx']}, {audio_segment['end_idx']}) has been processed.")
        if len(audio_embs) == 0:
            return None
        audio_emb = torch.cat(audio_embs, dim=0)
        return audio_emb
    def preprocess(
        self,
        audio_path: str,
        n_frames: Optional[int] = None,
        duration: Optional[float] = None,
        need_seperate: bool = False
    ):
        """Preprocess a WAV audio file: optionally separate the vocals from the background, resample to a
        16 kHz sample rate, and convert the vocal track into wav2vec2/hubert embeddings for further processing.
        """
        if need_seperate:
            vocal_audio_file = self.seperate_audio(audio_path)
        else:
            vocal_audio_file = audio_path
        audio_data, sampling_rate = self.load_audio(vocal_audio_file, duration=duration)
        assert sampling_rate == 16000, "The sample rate of the audio must be 16000"
        audio_segments, n_frames = self.prepare_audio_data(audio_data, n_frames)
        audio_emb = self.get_audio_embeddings(audio_segments)
        if audio_emb is None:
            log(f"{audio_path} has been processed, but there is no audio embedding, so it is set to 'None'.")
        # else:
        #     log(f"{audio_path} has been processed, audio embedding shape {audio_emb.shape}.")
        return audio_emb, n_frames
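
    # A minimal end-to-end sketch (the config path and audio path are assumptions for illustration):
    #
    #     processor = AudioProcessor("configs/audio_processor.yaml")
    #     audio_emb, n_frames = processor.preprocess("data/speech.wav")
    #     # audio_emb: (n_frames, 12, 768), or (n_frames, 768) when only_last_features is set
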
    def preprocess_long(
        self,
        audio_path: str,
        need_seperate: bool = False
    ):
        audio_list = cut_audio(audio_path, self.tmp_dir, length=self.max_length)
        audio_emb_list = []
        audio_length = 0
        for idx, audio_path in enumerate(audio_list):
            padding = (idx + 1) == len(audio_list)  # marks the last segment; currently unused
            emb, length = self.preprocess(audio_path, need_seperate=need_seperate)
            audio_emb_list.append(emb)
            log(f"Processing audio {idx+1}/{len(audio_list)}, path: {audio_path} length: {length}")
            audio_length += length
        audio_emb = torch.cat(audio_emb_list)
        # remove the tmp files
        for audio_path in audio_list:
            os.remove(audio_path)
        return audio_emb, audio_length
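
    # For inputs longer than data_params.max_length seconds, the long-audio path splits the file,
    # embeds each piece, and concatenates the results (the path below is an assumption):
    #
    #     audio_emb, total_frames = processor.preprocess_long("data/long_talk.wav")
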
    def __enter__(self):
        return self