Upload with huggingface_hub
- DESCRIPTION.md +1 -0
- README.md +6 -7
- data_setups.py +80 -0
- requirements.txt +6 -0
- run.py +50 -0
DESCRIPTION.md
ADDED
@@ -0,0 +1 @@
This demo identifies musical instruments from an audio file. It uses Gradio's Audio and Label components.
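That one-sentence description maps directly onto the two Gradio pieces it names: an Audio input component and a Label output component. As a rough sketch of the wiring (the classify function here is a hypothetical stand-in; run.py further down is the real app):

import gradio as gr

def classify(audio_path):
    # hypothetical stand-in: a real model would map the uploaded clip to class probabilities
    return {"Piano": 0.9, "Violin": 0.1}

demo = gr.Interface(fn=classify,
                    inputs=gr.Audio(type="filepath"),     # upload or record a clip
                    outputs=gr.Label(num_top_classes=3))  # show the top predicted instruments
demo.launch()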
README.md
CHANGED
@@ -1,12 +1,11 @@
+
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: musical_instrument_identification_main
+emoji: 🔥
+colorFrom: indigo
+colorTo: indigo
 sdk: gradio
 sdk_version: 3.6
-app_file:
+app_file: run.py
 pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
data_setups.py
ADDED
@@ -0,0 +1,80 @@
# Audio preprocessing utilities for the instrument classification demo
import os
import librosa
import torch
import numpy as np
from torchaudio.transforms import Resample

SAMPLE_RATE = 44100
AUDIO_LEN = 2.90

# Parameters to control the MelSpec generation
N_MELS = 128
F_MIN = 20
F_MAX = 16000
N_FFT = 1024
HOP_LEN = 512

# Make function to find classes in target directory
def find_classes(directory: str):
    # 1. Get the class names by scanning the target directory
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    # 2. Raise an error if class names not found
    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
    # 3. Create a dictionary of index labels (computers prefer numerical rather than string labels)
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

def resample(wav, sample_rate, new_sample_rate):
    # Mix multi-channel audio down to mono
    if wav.shape[0] >= 2:
        wav = torch.mean(wav, dim=0)
    else:
        wav = wav.squeeze(0)
    # Only downsample; clips already at or below the target rate are left unchanged
    if sample_rate > new_sample_rate:
        resampler = Resample(sample_rate, new_sample_rate)
        wav = resampler(wav)
    return wav

def mono_to_color(X, eps=1e-6, mean=None, std=None):
    # Stack the spectrogram into 3 identical channels so it looks like an RGB image
    X = np.stack([X, X, X], axis=-1)
    # Standardize
    mean = mean or X.mean()
    std = std or X.std()
    X = (X - mean) / (std + eps)
    # Normalize to [0, 255]
    _min, _max = X.min(), X.max()
    if (_max - _min) > eps:
        V = np.clip(X, _min, _max)
        V = 255 * (V - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        V = np.zeros_like(X, dtype=np.uint8)
    return V

def normalize(image, mean=None, std=None):
    # Scale to [0, 1], optionally standardize, and move channels first (C, H, W)
    image = image / 255.0
    if mean is not None and std is not None:
        image = (image - mean) / std
    return np.moveaxis(image, 2, 0).astype(np.float32)

def compute_melspec(wav, sample_rate=SAMPLE_RATE):
    # Mel spectrogram converted to decibel scale
    melspec = librosa.feature.melspectrogram(
        y=wav,
        sr=sample_rate,
        n_fft=N_FFT,
        fmin=F_MIN,
        fmax=F_MAX,
        n_mels=N_MELS,
        hop_length=HOP_LEN
    )
    melspec = librosa.power_to_db(melspec).astype(np.float32)
    return melspec

def audio_preprocess(wav, sample_rate):
    # Waveform tensor -> mel spectrogram -> normalized 3-channel (C, H, W) tensor
    wav = wav.numpy()
    melspec = compute_melspec(wav, sample_rate)
    image = mono_to_color(melspec)
    image = normalize(image, mean=None, std=None)
    image = torch.from_numpy(image)
    return image
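Taken together, audio_preprocess is the whole inference-time pipeline: mix a clip down to mono at 44.1 kHz, compute a mel spectrogram, and pack it into a normalized 3-channel, channels-first tensor. A short sketch of driving these helpers directly, assuming some local clip.wav exists:

import torchaudio
from data_setups import resample, audio_preprocess, SAMPLE_RATE, AUDIO_LEN

waveform, sr = torchaudio.load("clip.wav")             # assumed local file
wav = resample(waveform, sr, SAMPLE_RATE)              # mono, downsampled to 44.1 kHz if needed
wav = wav[:int(AUDIO_LEN * SAMPLE_RATE)]               # fixed 2.90 s window
img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0)  # add a batch dimension
print(img.shape)                                       # (1, 3, n_mels, time frames)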
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch==1.12.0
torchvision==0.13.0
torchaudio==0.12.0
gradio==3.1.4
librosa==0.9.2
gdown
https://gradio-main-build.s3.amazonaws.com/c3bec6153737855510542e8154391f328ac72606/gradio-3.6-py3-none-any.whl
run.py
ADDED
@@ -0,0 +1,50 @@
import gradio as gr
import torch, torchaudio
from timeit import default_timer as timer
from data_setups import audio_preprocess, resample
import gdown

# Download an example clip and the trained model checkpoint from Google Drive
url = 'https://drive.google.com/uc?id=1X5CR18u0I-ZOi_8P0cNptCe5JGk9Ro0C'
output = 'piano.wav'
gdown.download(url, output, quiet=False)
url = 'https://drive.google.com/uc?id=1W-8HwmGR5SiyDbUcGAZYYDKdCIst07__'
output = 'torch_efficientnet_fold2_CNN.pth'
gdown.download(url, output, quiet=False)

device = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 44100
AUDIO_LEN = 2.90
model = torch.load("torch_efficientnet_fold2_CNN.pth", map_location=torch.device('cpu'))
LABELS = [
    "Cello", "Clarinet", "Flute", "Acoustic Guitar", "Electric Guitar", "Organ", "Piano", "Saxophone", "Trumpet", "Violin", "Voice"
]
example_list = [
    ["piano.wav"]
]


def predict(audio_path):
    start_time = timer()
    waveform, sample_rate = torchaudio.load(audio_path)
    # Mix down to mono and resample to the 44.1 kHz rate the model expects
    wav = resample(waveform, sample_rate, SAMPLE_RATE)
    # Trim to a fixed 2.90 s window; bail out on clips that are too short
    if len(wav) > int(AUDIO_LEN * SAMPLE_RATE):
        wav = wav[:int(AUDIO_LEN * SAMPLE_RATE)]
    else:
        print(f"input length {len(wav)} too small! need over {int(AUDIO_LEN * SAMPLE_RATE)} samples")
        return
    # Convert the waveform to a spectrogram "image" batch and run the CNN
    img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0)
    model.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(model(img), dim=1)
    pred_labels_and_probs = {LABELS[i]: float(pred_probs[0][i]) for i in range(len(LABELS))}
    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time

demo = gr.Interface(fn=predict,
                    inputs=gr.Audio(type="filepath"),
                    outputs=[gr.Label(num_top_classes=11, label="Predictions"),
                             gr.Number(label="Prediction time (s)")],
                    examples=example_list,
                    cache_examples=False
                    )

demo.launch(debug=False)
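Outside the Spaces UI, the same predict function doubles as a quick smoke test of the whole path (downloads, preprocessing, model). A minimal sketch, assuming the gdown calls above have already fetched piano.wav and the checkpoint:

# quick local smoke test, after the downloads above have completed
probs, seconds = predict("piano.wav")
top = max(probs, key=probs.get)
print(f"top prediction: {top} ({probs[top]:.2f}), inference time {seconds}s")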