Model Details
Model Name: Fine-Tuned HuBERT for Music Genre Classification
Base Model: facebook/hubert-base-ls960
Dataset: ccmusic-database/music_genre
Quantization: Applied for optimized inference
Training Device: CUDA (GPU)
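For orientation, the base checkpoint named above is hosted on the Hugging Face Hub; a minimal, hedged sketch of loading it with the transformers library (not part of the original training code) might look like this:

from transformers import HubertModel

# Load the pretrained backbone named in Model Details; the fine-tuning setup is described below.
base_model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
print(base_model.config.hidden_size)  # 768-dimensional hidden states for the base checkpoint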
Dataset Information
Dataset Structure:
DatasetDict({
    train: Dataset({
        features: ['mel', 'cqt', 'chroma', 'fst_level_label', 'sec_level_label', 'thr_level_label'],
        num_rows: 29100
    })
    validation: Dataset({
        features: ['mel', 'cqt', 'chroma', 'fst_level_label', 'sec_level_label', 'thr_level_label'],
        num_rows: 3637
    })
    test: Dataset({
        features: ['mel', 'cqt', 'chroma', 'fst_level_label', 'sec_level_label', 'thr_level_label'],
        num_rows: 3638
    })
})
Available Splits:
Train: 29,100 examples
Validation: 3,637 examples
Test: 3,638 examples
Feature Representation (a loading sketch follows this list):
mel: Mel spectrogram representation of the audio
cqt: Constant-Q transform (CQT) representation
chroma: Chroma feature representation
fst_level_label: First-level classification (Classic vs. Non-classic)
sec_level_label: Second-level classification (9 genres)
thr_level_label: Third-level classification (15 subgenres)
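The splits and features above can be inspected with the Hugging Face datasets library. This is a hedged sketch; the dataset may require a specific config name, so check the dataset card before relying on it:

from datasets import load_dataset

ds = load_dataset("ccmusic-database/music_genre")  # a config name may need to be passed explicitly
print(ds)                                          # DatasetDict with train/validation/test splits
sample = ds["train"][0]
print(sample.keys())                               # mel, cqt, chroma and the three label levels
print(ds["train"].features["sec_level_label"])     # 9-class second-level label used by this model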
Training Details
Number of Classes (sec_level_label): 9
Class Names:
Symphony, Opera, Solo, Chamber, Pop, Dance_and_house, Indie, Soul_or_RnB, Rock
Training Process:
Fine-tuned for 4 additional epochs (a minimal training-loop sketch follows)
Training loss decreased steadily across epochs
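The original training script is not included in this card. The following is a minimal fine-tuning sketch, assuming the same ResNet-18 image-classifier setup used in the inference example below; the data loading, optimizer, and learning rate are illustrative assumptions, not the original recipe:

import torch
import torch.nn as nn
from torchvision import models

def fine_tune(train_loader, num_classes=9, epochs=4, lr=1e-4, device="cuda"):
    # Image classifier over spectrogram images; labels are the 9-class sec_level_label.
    model = models.resnet18(weights="IMAGENET1K_V1")
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:  # batches of spectrogram images and label ids
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"epoch {epoch + 1}: train loss {running_loss / len(train_loader):.4f}")
    return model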
Performance Metrics
Epochs: 4
Training Loss: 0.2371
Validation Loss: 0.3278
Accuracy: 0.8812
F1 Score: 0.8809
Classification Report:
precision recall f1-score support
Symphony 0.9372 0.9894 0.9626 377
Opera 0.9923 0.9149 0.9520 141
Solo 0.9778 0.9600 0.9688 275
Chamber 0.9548 0.9651 0.9599 372
Pop 0.9283 0.8240 0.8730 534
Dance_and_house 0.8742 0.9350 0.9036 446
Indie 0.9309 0.8115 0.8671 382
Soul_or_RnB 0.9366 0.8105 0.8690 401
Rock 0.8067 0.9465 0.8710 710
accuracy 0.9041 3638
macro avg 0.9265 0.9063 0.9141 3638
weighted avg 0.9090 0.9041 0.9038 3638
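For reference, the accuracy, F1 score, and per-class report above can be produced with scikit-learn given arrays of true and predicted second-level labels; a minimal sketch (the evaluation loop and label encoding are assumptions):

from sklearn.metrics import accuracy_score, f1_score, classification_report

# Class names as listed in Training Details; order assumed to match the label ids.
class_names = ["Symphony", "Opera", "Solo", "Chamber", "Pop",
               "Dance_and_house", "Indie", "Soul_or_RnB", "Rock"]

def evaluate(y_true, y_pred):
    print("Accuracy:", accuracy_score(y_true, y_pred))
    print("Weighted F1:", f1_score(y_true, y_pred, average="weighted"))
    print(classification_report(y_true, y_pred, target_names=class_names, digits=4))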
Inference Example
import torch
import torch.nn as nn
from torchvision import models, transforms
import librosa
import librosa.display
import numpy as np
import io
import time
from PIL import Image
import matplotlib.pyplot as plt


def load_model(model_path, num_classes=9):
    """Load the fine-tuned classifier from a checkpoint containing a 'model_state_dict' key."""
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.half()  # half-precision weights; on CPU this may require a recent PyTorch, otherwise use .float()
    model.eval()
    return model


def audio_to_mel_spectrogram(audio_path, sr=22050, n_mels=128, hop_length=512):
    """Convert an audio file to a mel spectrogram rendered as an RGB image."""
    y, sr = librosa.load(audio_path, sr=sr)
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, hop_length=hop_length)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(mel_spec_db, y_axis='mel', x_axis='time', sr=sr, hop_length=hop_length)
    plt.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close()
    buf.seek(0)
    return Image.open(buf).convert('RGB')


def predict_genre(model_path, audio_path, class_mapping):
    """Predict the genre of an audio file using the quantized model; returns the top prediction, the top 3, and timing."""
    model = load_model(model_path, num_classes=len(class_mapping))
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    spectrogram = audio_to_mel_spectrogram(audio_path)
    image_tensor = transform(spectrogram).unsqueeze(0).half()  # match the model's half precision
    start_time = time.time()
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]
        top_probs, top_idx = torch.topk(probabilities, 3)
        predictions = [(class_mapping[idx.item()], prob.item()) for idx, prob in zip(top_idx, top_probs)]
    inference_time = time.time() - start_time
    return predictions[0][0], predictions, inference_time
# Example usage
if __name__ == "__main__":
    # Index-to-name mapping for the 9 second-level classes; the order must match the
    # label ids used during training (assumed here to follow the class list above).
    class_mapping = {0: 'Symphony', 1: 'Opera', 2: 'Solo', 3: 'Chamber', 4: 'Pop',
                     5: 'Dance_and_house', 6: 'Indie', 7: 'Soul_or_RnB', 8: 'Rock'}
    model_path = "path/to/quantized_model.pth"
    audio_path = "path/to/audio.wav"
    genre, predictions, inference_time = predict_genre(model_path, audio_path, class_mapping)
    print(f"Predicted Genre: {genre}")
    print(f"Top Predictions: {predictions}")
    print(f"Inference Time: {inference_time:.4f} seconds")
Quantization & Optimization
Quantization applied for faster inference with minimal loss in accuracy.
Optimized for deployment on edge devices with reduced model size.
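The exact quantization recipe is not included in this card. Two common PyTorch options consistent with the notes above are half-precision conversion (as used in load_model in the inference example) and dynamic int8 quantization of Linear layers; a hedged sketch:

import torch

def to_half(model):
    # fp16 weights, matching the approach in load_model() above
    return model.half().eval()

def to_dynamic_int8(model):
    # int8 dynamic quantization; for a ResNet this only affects the final Linear (fc) layer
    return torch.quantization.quantize_dynamic(
        model.eval(), {torch.nn.Linear}, dtype=torch.qint8
    )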
Usage
Input: Audio samples converted to mel, cqt, and chroma representations (a feature-extraction sketch follows this list)
Output: Predicted music genre label (this checkpoint targets the 9-class second-level labels; the dataset also provides first- and third-level labels)
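A hedged sketch of computing the three input representations with librosa; the parameter values are illustrative defaults, not necessarily those used to build the dataset:

import librosa
import numpy as np

def extract_features(audio_path, sr=22050):
    y, sr = librosa.load(audio_path, sr=sr)
    # Mel spectrogram in dB
    mel = librosa.power_to_db(librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128), ref=np.max)
    # Constant-Q transform magnitude in dB
    cqt = librosa.amplitude_to_db(np.abs(librosa.cqt(y=y, sr=sr)), ref=np.max)
    # Chromagram from the short-time Fourier transform
    chroma = librosa.feature.chroma_stft(y=y, sr=sr)
    return mel, cqt, chroma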
Limitations
Model performance may vary for underrepresented genres.
Requires sufficient high-quality input features for accurate predictions.
Future Improvements
Further fine-tuning on diverse datasets
Expansion to additional subgenres for better granularity
Implementation of real-time inference support