Spaces:
Running
Running
import os | |
import torch | |
import librosa | |
import functools | |
import scipy.stats | |
import numpy as np | |
# Width of one pitch bin in cents (used by the bin <-> frequency conversions).
CENTS_PER_BIN = 20
# Default upper frequency bound (Hz) for decoding.
MAX_FMAX = 2006
# Number of pitch bins produced by the classifier.
PITCH_BINS = 360
# Sample rate (Hz) the audio is resampled to before framing.
SAMPLE_RATE = 16000
# Number of samples in each analysis frame.
WINDOW_SIZE = 1024
class Crepe(torch.nn.Module):
    """Six-layer convolutional pitch classifier.

    The network stacks six conv -> ReLU -> batch-norm -> max-pool layers and a
    final linear layer that maps the flattened features to ``PITCH_BINS``
    sigmoid activations per frame.

    Args:
        model: Capacity name; one of ``'full'``, ``'large'``, ``'medium'``,
            ``'small'`` or ``'tiny'``.

    Raises:
        ValueError: If ``model`` is not a known capacity name.
    """

    # Per capacity: output channels of the six conv layers and the width of
    # the flattened feature vector fed to the classifier.
    _CAPACITIES = {
        'full': ([1024, 128, 128, 128, 256, 512], 2048),
        'large': ([768, 96, 96, 96, 192, 384], 1536),
        'medium': ([512, 64, 64, 64, 128, 256], 1024),
        'small': ([256, 32, 32, 32, 64, 128], 512),
        'tiny': ([128, 16, 16, 16, 32, 64], 256),
    }

    def __init__(self, model='full'):
        super().__init__()
        # Fail early with a clear message instead of an UnboundLocalError
        # further down when the capacity name is unknown.
        if model not in self._CAPACITIES:
            raise ValueError(
                f"Unknown model capacity {model!r}; expected one of {sorted(self._CAPACITIES)}"
            )
        out_channels, self.in_features = self._CAPACITIES[model]
        # Each layer consumes the previous layer's output; the first layer
        # sees a single input channel.
        in_channels = [1] + out_channels[:-1]
        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]
        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)
        # Attribute names (conv1, conv1_BN, ..., conv6_BN) double as the
        # state_dict keys, so they must stay exactly as they are.
        for index in range(6):
            setattr(self, f"conv{index + 1}", torch.nn.Conv2d(
                in_channels=in_channels[index],
                out_channels=out_channels[index],
                kernel_size=kernel_sizes[index],
                stride=strides[index],
            ))
            setattr(self, f"conv{index + 1}_BN", batch_norm_fn(num_features=out_channels[index]))
        self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)

    def forward(self, x, embed=False):
        """Run the network on a batch of frames.

        Args:
            x: 2-D tensor; indexed as ``x[:, None, :, None]`` so it is treated
                as (frames, samples).
            embed: When true, return the output of the fifth layer instead of
                the classifier activations.

        Returns:
            The fifth-layer feature map when ``embed`` is true, otherwise the
            sigmoid activations with ``PITCH_BINS`` values per row.
        """
        x = self.embed(x)
        if embed:
            return x
        x = self.layer(x, self.conv6, self.conv6_BN)
        # Flatten each frame's features into a vector of in_features values.
        x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)
        return torch.sigmoid(self.classifier(x))

    def embed(self, x):
        """Apply the first five layers to a (frames, samples) tensor."""
        # Add channel and trailing width axes: (frames, 1, samples, 1).
        x = x[:, None, :, None]
        # The first layer uses its own padding (254 on each side of the
        # sample axis); the remaining layers use the default of ``layer``.
        x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
        for conv, batch_norm in (
            (self.conv2, self.conv2_BN),
            (self.conv3, self.conv3_BN),
            (self.conv4, self.conv4_BN),
            (self.conv5, self.conv5_BN),
        ):
            x = self.layer(x, conv, batch_norm)
        return x

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        """Pad, convolve, apply ReLU and batch norm, then max-pool by two.

        ``padding`` is passed to ``torch.nn.functional.pad``, so the first
        pair applies to the last axis and the second pair to the sample axis.
        """
        x = torch.nn.functional.pad(x, padding)
        x = torch.nn.functional.relu(conv(x))
        x = batch_norm(x)
        return torch.nn.functional.max_pool2d(x, (2, 1), (2, 1))
def viterbi(logits):
    """Decode the most likely pitch-bin path for each sequence with Viterbi.

    Args:
        logits: Tensor whose dim 1 is the pitch-bin axis (softmax is taken
            over it); each item along dim 0 is decoded independently as a
            (bins, frames) observation matrix.

    Returns:
        Tuple ``(bins, frequencies)``: the decoded int64 bin indices (on the
        same device as the input) and the result of ``bins_to_frequency`` on
        them.
    """
    # The transition matrix is built once and cached on the function object.
    if not hasattr(viterbi, 'transition'):
        # Size the matrix from PITCH_BINS rather than a literal so it cannot
        # drift out of sync with the classifier output width.
        xx, yy = np.meshgrid(range(PITCH_BINS), range(PITCH_BINS))
        # Triangular band: transitions further than 11 bins get zero weight.
        transition = np.maximum(12 - abs(xx - yy), 0)
        # Normalise each row so it is a probability distribution.
        viterbi.transition = transition / transition.sum(axis=1, keepdims=True)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)
        # librosa works on NumPy arrays, so decoding happens on the CPU.
        paths = [
            librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64)
            for sequence in probs.cpu().numpy()
        ]
        bins = torch.tensor(np.array(paths), device=probs.device)
    return bins, bins_to_frequency(bins)
def predict(audio, sample_rate, hop_length=None, fmin=50, fmax=MAX_FMAX, model='full', return_periodicity=False, batch_size=None, device='cpu', pad=True, providers=None, onnx=False):
    """Estimate pitch (and optionally periodicity) for an audio tensor.

    The audio is framed by ``preprocess``, each batch of frames is scored by
    either an ONNX session or the cached PyTorch model, and the scores are
    decoded by ``postprocess``. Per-batch results are concatenated along
    dim 1.

    Args:
        audio: Audio tensor; ``audio.size(0)`` is used as the batch size and
            dim 1 as the sample axis.
        sample_rate: Sample rate of ``audio``.
        hop_length: Hop between frames; forwarded to ``preprocess``.
        fmin: Lower frequency bound forwarded to ``postprocess``.
        fmax: Upper frequency bound forwarded to ``postprocess``.
        model: Capacity name; selects the weights file to load.
        return_periodicity: When true, also return the periodicity.
        batch_size: Number of frames per batch; forwarded to ``preprocess``.
        device: Device the frames (and the PyTorch model) are placed on.
        pad: Forwarded to ``preprocess``.
        providers: Execution providers handed to ``onnxruntime``.
        onnx: When true, run inference with onnxruntime instead of PyTorch.

    Returns:
        The pitch tensor, or ``(pitch, periodicity)`` when
        ``return_periodicity`` is true.
    """
    collected = []
    if onnx:
        import onnxruntime as ort

        options = ort.SessionOptions()
        options.log_severity_level = 3
        model_path = os.path.join("assets", "models", "predictors", f"crepe_{model}.onnx")
        session = ort.InferenceSession(model_path, sess_options=options, providers=providers)
        for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
            output_name = session.get_outputs()[0].name
            input_name = session.get_inputs()[0].name
            scores = session.run([output_name], {input_name: frames.cpu().numpy()})[0]
            # Rearrange to (1, bins, frames), the layout postprocess expects.
            result = postprocess(torch.tensor(scores.transpose(1, 0)[None]), fmin, fmax, return_periodicity)
            if isinstance(result, tuple):
                collected.append((result[0], result[1]))
            else:
                collected.append(result)
        del session
    else:
        with torch.no_grad():
            for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
                scores = infer(frames, model, device, embed=False)
                # Regroup the flat frame axis per batch item: (batch, bins, frames).
                scores = scores.reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2)
                result = postprocess(scores, fmin, fmax, return_periodicity)
                # Results are moved back to the device of the input audio.
                if isinstance(result, tuple):
                    collected.append((result[0].to(audio.device), result[1].to(audio.device)))
                else:
                    collected.append(result.to(audio.device))
    if return_periodicity:
        pitches, periodicities = zip(*collected)
        return torch.cat(pitches, 1), torch.cat(periodicities, 1)
    return torch.cat(collected, 1)
def bins_to_frequency(bins):
    """Convert pitch-bin indices to frequencies, with random dither.

    Bins are mapped to cents (``CENTS_PER_BIN`` per bin plus a fixed offset),
    a triangular-distributed random offset within +/- ``CENTS_PER_BIN`` is
    added, and the result is converted with ``10 * 2 ** (cents / 1200)``.
    Because of the dither, repeated calls return different values.
    """
    cents = CENTS_PER_BIN * bins + 1997.3794084376191
    # Symmetric triangular noise on [-CENTS_PER_BIN, +CENTS_PER_BIN].
    dither = scipy.stats.triang.rvs(
        c=0.5,
        loc=-CENTS_PER_BIN,
        scale=2 * CENTS_PER_BIN,
        size=cents.size(),
    )
    dithered_cents = cents + cents.new_tensor(dither)
    return 10 * 2 ** (dithered_cents / 1200)
def frequency_to_bins(frequency, quantize_fn=torch.floor):
    """Convert a frequency tensor to integer pitch-bin indices.

    Args:
        frequency: Tensor of frequencies.
        quantize_fn: Rounding function applied to the fractional bin
            position before the integer cast (``torch.floor`` by default).

    Returns:
        Int tensor of bin indices; values are not clamped to the valid
        bin range.
    """
    cents = 1200 * torch.log2(frequency / 10)
    fractional_bins = (cents - 1997.3794084376191) / CENTS_PER_BIN
    return quantize_fn(fractional_bins).int()
def infer(frames, model='full', device='cpu', embed=False):
    """Run the cached model on a batch of frames.

    The model is cached on the function object (``infer.model`` and
    ``infer.capacity``) and is (re)loaded through ``load_model`` whenever no
    model is cached yet or the cached capacity differs from ``model``.

    Args:
        frames: Input passed to the model.
        model: Capacity name of the model to use.
        device: Device the cached model is moved to before the call.
        embed: Forwarded to the model's ``embed`` argument.

    Returns:
        Whatever the cached model returns for ``frames``.
    """
    cache_is_valid = (
        hasattr(infer, 'model')
        and hasattr(infer, 'capacity')
        and infer.capacity == model
    )
    if not cache_is_valid:
        load_model(device, model)
    infer.model = infer.model.to(device)
    return infer.model(frames, embed=embed)
def load_model(device, capacity='full'):
    """Load the weights for ``capacity`` and cache the model on ``infer``.

    The model is built and loaded into a local variable first and only
    published to ``infer.model`` / ``infer.capacity`` once loading succeeded.
    This way a failed load (e.g. a missing weights file) cannot leave a
    randomly initialised model in the cache that later ``infer`` calls would
    silently reuse.

    Args:
        device: Device the weights are mapped to and the model is moved to.
        capacity: Capacity name; selects ``crepe_<capacity>.pth``.

    Raises:
        Whatever ``torch.load`` or ``load_state_dict`` raise when the file is
        missing or does not match the architecture.
    """
    model = Crepe(capacity)
    weights_path = os.path.join("assets", "models", "predictors", f"crepe_{capacity}.pth")
    # NOTE(review): torch.load unpickles the file; only load weights from a
    # trusted location.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model = model.to(torch.device(device))
    model.eval()
    # Publish to the cache last, so it is never left half-initialised.
    infer.model = model
    infer.capacity = capacity
def postprocess(probabilities, fmin=0, fmax=MAX_FMAX, return_periodicity=False):
    """Mask out-of-range bins and decode pitch from bin probabilities.

    Bins below ``fmin`` and from ``fmax`` upwards are set to ``-inf`` before
    Viterbi decoding. The masking is done in place on the detached tensor,
    which shares storage with the input.

    Args:
        probabilities: Tensor with the pitch-bin axis on dim 1.
        fmin: Lowest frequency to allow.
        fmax: Highest frequency to allow.
        return_periodicity: When true, also return the periodicity.

    Returns:
        The pitch tensor, or ``(pitch, periodicity)`` when
        ``return_periodicity`` is true.
    """
    probabilities = probabilities.detach()
    # Clamp both boundaries to the valid bin range. Without this, an fmin
    # below the lowest bin gives a negative index, and ``[:, :negative]``
    # would mask almost every bin instead of none.
    min_bin = _mask_boundary(fmin, torch.floor)
    max_bin = _mask_boundary(fmax, torch.ceil)
    probabilities[:, :min_bin] = -float('inf')
    probabilities[:, max_bin:] = -float('inf')
    bins, pitch = viterbi(probabilities)
    if not return_periodicity:
        return pitch
    return pitch, periodicity(probabilities, bins)


def _mask_boundary(frequency, quantize_fn):
    """Return the bin index for ``frequency`` clamped to ``[0, PITCH_BINS]``."""
    # log2 of a non-positive value is -inf/NaN and its integer cast is not
    # well defined, so map those frequencies straight to bin 0.
    if frequency <= 0:
        return 0
    index = int(frequency_to_bins(torch.tensor(frequency), quantize_fn))
    return min(max(index, 0), PITCH_BINS)
def preprocess(audio, sample_rate, hop_length=None, batch_size=None, device='cpu', pad=True):
    """Yield batches of normalised, fixed-size frames cut from ``audio``.

    Args:
        audio: Audio tensor with the sample axis on dim 1. When resampling
            is needed, dim 0 is squeezed, so it must have size 1 there.
        sample_rate: Sample rate of ``audio``; audio is resampled to
            ``SAMPLE_RATE`` when it differs.
        hop_length: Hop between frames in input samples; defaults to
            ``sample_rate // 100``.
        batch_size: Frames per yielded batch; defaults to all frames.
        device: Device the frames are moved to.
        pad: When true, pad half a window on both ends before framing.

    Yields:
        Tensors of shape (frames, ``WINDOW_SIZE``), each frame with zero mean
        and scaled by its standard deviation.
    """
    if hop_length is None:
        hop_length = sample_rate // 100
    if sample_rate != SAMPLE_RATE:
        samples = audio.detach().cpu().numpy().squeeze(0)
        resampled = librosa.resample(samples, orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq")
        audio = torch.tensor(resampled, device=audio.device).unsqueeze(0)
        # Rescale the hop so it covers the same duration at the new rate.
        hop_length = int(hop_length * SAMPLE_RATE / sample_rate)
    if pad:
        total_frames = 1 + int(audio.size(1) // hop_length)
        half_window = WINDOW_SIZE // 2
        audio = torch.nn.functional.pad(audio, (half_window, half_window))
    else:
        total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
    if batch_size is None:
        batch_size = total_frames
    for first_frame in range(0, total_frames, batch_size):
        # Sample span covering this batch's frames, clipped to the audio.
        start = max(0, first_frame * hop_length)
        stop = min(audio.size(1), (first_frame + batch_size - 1) * hop_length + WINDOW_SIZE)
        segment = audio[:, None, None, start:stop]
        frames = torch.nn.functional.unfold(segment, kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
        frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(device)
        frames -= frames.mean(dim=1, keepdim=True)
        # Floor the divisor so silent frames do not divide by zero.
        frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))
        yield frames
def periodicity(probabilities, bins):
    """Look up the probability at the decoded bin of every frame.

    Args:
        probabilities: Tensor of shape (batch, ``PITCH_BINS``, frames).
        bins: Decoded bin index per frame; flattened before the lookup.

    Returns:
        Tensor of shape (batch, frames) holding the selected values.
    """
    per_frame = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
    chosen_bins = bins.reshape(-1, 1).to(torch.int64)
    selected = per_frame.gather(1, chosen_bins)
    return selected.reshape(probabilities.size(0), probabilities.size(2))
def mean(signals, win_length=9):
    """Apply a NaN-aware moving average along the last axis.

    NaN samples are ignored: every output is the mean of the non-NaN samples
    in its window. A window without any valid sample yields NaN. Genuine
    zero averages are kept as zero (they are not turned into NaN).

    Args:
        signals: 2-D tensor (batch, width).
        win_length: Window size in samples.

    Returns:
        Tensor of filtered signals, with the dtype of the input.
    """
    assert signals.dim() == 2
    signals = signals.unsqueeze(1)
    valid = ~torch.isnan(signals)
    padding = win_length // 2
    # Build the kernel in the input dtype so non-float32 inputs work too.
    ones_kernel = torch.ones(signals.size(1), 1, win_length, dtype=signals.dtype, device=signals.device)
    window_sum = torch.nn.functional.conv1d(
        torch.where(valid, signals, torch.zeros_like(signals)),
        ones_kernel,
        stride=1,
        padding=padding,
    )
    valid_count = torch.nn.functional.conv1d(valid.to(signals.dtype), ones_kernel, stride=1, padding=padding)
    # Clamp only to avoid dividing by zero; empty windows are fixed up below.
    avg_pooled = window_sum / valid_count.clamp(min=1)
    # Mark windows with no valid sample via the count, not via ``avg == 0``,
    # which would also wipe out legitimate zero means.
    avg_pooled[valid_count == 0] = float("nan")
    return avg_pooled.squeeze(1)
def median(signals, win_length):
    """Apply a NaN-aware moving median along the last axis.

    NaN samples and the padded border are excluded from each window; the
    lower median of the remaining samples is returned. A window without any
    valid sample yields NaN.

    Args:
        signals: 2-D tensor (batch, width).
        win_length: Window size in samples.

    Returns:
        Tensor of filtered signals.
    """
    assert signals.dim() == 2
    batch = signals.unsqueeze(1)
    valid = ~torch.isnan(batch)
    half = win_length // 2
    # Replace NaNs with zeros so padding and sorting never see NaN.
    filled = torch.where(valid, batch, torch.zeros_like(batch))
    windows = torch.nn.functional.pad(filled, (half, half), mode="reflect").unfold(2, win_length, 1)
    # The validity mask is zero-padded, so reflected border samples are ignored.
    valid_windows = torch.nn.functional.pad(valid.float(), (half, half), mode="constant", value=0).unfold(2, win_length, 1)
    windows = windows.contiguous().view(windows.size()[:3] + (-1,))
    valid_windows = valid_windows.contiguous().view(valid_windows.size()[:3] + (-1,))
    # Push invalid entries to +inf so they sort to the end of each window.
    candidates = torch.where(valid_windows.bool(), windows.float(), float("inf")).to(windows)
    ordered = torch.sort(candidates, dim=-1).values
    # Index of the (lower) median among the valid entries of each window.
    middle = ((valid_windows.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()
    pooled = ordered.gather(-1, middle).squeeze(-1)
    # An infinite median means the window held no valid sample.
    pooled[torch.isinf(pooled)] = float("nan")
    return pooled.squeeze(1)