import os
import torch
import torchaudio
import numpy as np
from omegaconf import OmegaConf
from torch.nn.functional import pad, normalize, softmax
from manipulate_model.model import Model


def get_config_and_model(model_root="manipulate_model/demo-model/audio"):
    """Load the model config from model_root and build the Model."""
    config_path = os.path.join(model_root, "config.yaml")
    config = OmegaConf.load(config_path)

    # Encoder/decoder entries may be paths to separate config files.
    if isinstance(config.model.encoder, str):
        config.model.encoder = OmegaConf.load(config.model.encoder)
    if isinstance(config.model.decoder, str):
        config.model.decoder = OmegaConf.load(config.model.decoder)

    model = Model(config)

    # Checkpoint loading is currently disabled; re-enable to restore trained weights.
    # weights = torch.load(os.path.join(model_root, "weights.pt"))
    # model.load_state_dict(weights["model_state_dict"])

    return config, model
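

# The fields read from config.yaml elsewhere in this module are config.data.sr,
# config.data.fps, config.data.window_size, config.model.encoder and
# config.model.decoder. A minimal sketch of such a file (the values below are
# illustrative assumptions, not the shipped demo config):
#
#   data:
#     sr: 16000          # audio sample rate in Hz
#     fps: 25            # frame rate used for windowing and timestamps
#     window_size: 5     # window length in frames
#   model:
#     encoder: manipulate_model/demo-model/audio/encoder.yaml
#     decoder: manipulate_model/demo-model/audio/decoder.yaml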


def load_audio(file_path, config):
    """Load an audio file and return the preprocessed (windowed) signal.

    Parameters
    ----------
    file_path : str
        Path to the audio file.
    config : OmegaConf
        Configuration object.

    Returns
    -------
    torch.Tensor
        Stack of normalized audio windows.
    """
    audio = None
    if file_path.endswith((".wav", ".flac", ".mp3")):
        # mp3 decoding depends on the installed torchaudio backend.
        audio, sample_rate = torchaudio.load(file_path)
    elif file_path.endswith(".mp4"):
        # Extracting the audio track from video is not implemented yet.
        # _, audio, _ = read_video(file_path)
        pass
    if audio is None:
        raise ValueError(f"Unsupported or unreadable audio file: {file_path}")
    return preprocess_audio(audio, config)


def preprocess_audio(audio, config, step_size=1):
    """Slice the audio into overlapping, normalized windows.

    Parameters
    ----------
    audio : torch.Tensor
        Audio signal of shape (channels, samples).
    config : OmegaConf
        Configuration object.
    step_size : int
        Hop between consecutive windows, in frames.

    Returns
    -------
    torch.Tensor
        Stack of normalized audio windows.
    """
    window_size = config.data.window_size
    sr = config.data.sr
    fps = config.data.fps

    audio_len = audio.shape[1]

    # Convert frame counts to sample counts.
    step_size = step_size * (sr // fps)
    window_size = window_size * (sr // fps)

    # Pad with one window on each side so edge frames get full context.
    audio = pad(audio, (window_size, window_size), "constant", 0)

    sliced_audio = []
    for i in range(0, audio_len + window_size, step_size):
        audio_slice = audio[:, i : i + window_size]
        if audio_slice.shape[1] < window_size:
            audio_slice = pad(
                audio_slice, (0, window_size - audio_slice.shape[1]), "constant", 0
            )
        audio_slice = normalize(audio_slice, dim=1)
        sliced_audio.append(audio_slice)

    # Shape: (num_windows, window_size) after squeezing the channel dimension.
    sliced_audio = torch.stack(sliced_audio).squeeze()
    return sliced_audio


def infere(model, x, config, device="cpu", bs=8):
    """Run windowed inference on an audio file and return per-window scores."""
    print(x)
    model = model.to(device)  # keep model and inputs on the same device
    model.eval()
    x = load_audio(x, config)

    # Inference (x is a stack of windows); process the windows in mini-batches.
    frame_predictions = []
    with torch.no_grad():
        n_iter = x.shape[0]
        for i in range(0, n_iter, bs):
            input_batch = x[i : i + bs]
            input_batch = input_batch.to(device)
            output = softmax(model(input_batch), dim=1)
            frame_predictions.append(output.cpu().numpy())

    # Keep only the probability of class 0 for each window.
    frame_predictions = np.concatenate(frame_predictions, axis=0)[:, 0]
    return frame_predictions


def convert_frame_predictions_to_timestamps(frame_predictions, fps, window_size):
    """Convert per-frame predictions to timestamps (in seconds).

    Parameters
    ----------
    frame_predictions : np.ndarray
        1D array of per-frame scores, as returned by infere().
    fps : int
        Frames per second.
    window_size : int
        Window size in frames, used to strip the padding added in preprocessing.

    Returns
    -------
    list of float
        Timestamps (in seconds) of frames predicted as positive.
    """
    frame_predictions = (
        frame_predictions[
            int(window_size / 2) : -int(window_size / 2)
        ]  # removes the padding; does not account for step size as of now
        .round()
        .astype(int)
    )

    timestamps = []
    for i, frame_prediction in enumerate(frame_predictions):
        if frame_prediction == 1:
            timestamps.append(i / fps)
    return timestamps
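

# A minimal usage sketch of the pipeline above. The audio path and device are
# illustrative assumptions; note that checkpoint loading in
# get_config_and_model() is still commented out, so predictions come from an
# uninitialized model unless the weights are restored.
if __name__ == "__main__":
    config, model = get_config_and_model()
    scores = infere(model, "example.wav", config, device="cpu", bs=8)
    timestamps = convert_frame_predictions_to_timestamps(
        scores, config.data.fps, config.data.window_size
    )
    print(timestamps)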