
Model Details

Model Name: Fine-Tuned HuBERT for Music Genre Classification

Base Model: facebook/hubert-base-ls960

Dataset: ccmusic-database/music_genre

Quantization: FP16 (half-precision) conversion, applied for optimized inference

Training Device: CUDA (GPU)

Dataset Information

Dataset Structure:

DatasetDict({
    train: Dataset({
        features: ['mel', 'cqt', 'chroma', 'fst_level_label', 'sec_level_label', 'thr_level_label'],
        num_rows: 29100
    })
    validation: Dataset({
        features: ['mel', 'cqt', 'chroma', 'fst_level_label', 'sec_level_label', 'thr_level_label'],
        num_rows: 3637
    })
    test: Dataset({
        features: ['mel', 'cqt', 'chroma', 'fst_level_label', 'sec_level_label', 'thr_level_label'],
        num_rows: 3638
    })
})

Available Splits:

Train: 29,100 examples

Validation: 3,637 examples

Test: 3,638 examples
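The split sizes above add up to 36,375 examples, which is an 80/10/10 train/validation/test partition. A quick check using only the numbers listed in this card:

```python
# Split sizes as listed in this card
splits = {"train": 29100, "validation": 3637, "test": 3638}

total = sum(splits.values())
fractions = {name: n / total for name, n in splits.items()}

print(f"Total examples: {total}")
for name, frac in fractions.items():
    print(f"{name:>10}: {frac:.2%}")
```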

Feature Representation:

mel: Mel spectrogram representation of the audio

cqt: Constant-Q transform (CQT) representation

chroma: Chroma feature representation

fst_level_label: First-level classification (Classic vs. Non-classic)

sec_level_label: Second-level classification (9 genres)

thr_level_label: Third-level classification (15 subgenres)

Training Details

Number of Classes (sec_level_label): 9

Class Names:

Symphony, Opera, Solo, Chamber, Pop, Dance_and_house, Indie, Soul_or_RnB, Rock
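Integer label mappings can be built from the class list above. This is useful for a `class_mapping` dictionary or for `id2label`/`label2id` config entries. This sketch assumes the integer indices follow the order in which the classes are listed, which is also the order of the classification report. Confirm this against the dataset. If `sec_level_label` is a `datasets.ClassLabel` feature, its `.names` attribute gives the authoritative order.

```python
# Class names as listed in this card; the index order is an assumption
CLASS_NAMES = [
    "Symphony", "Opera", "Solo", "Chamber", "Pop",
    "Dance_and_house", "Indie", "Soul_or_RnB", "Rock",
]

# Integer id -> name (usable as `class_mapping` in the inference example)
id2label = {i: name for i, name in enumerate(CLASS_NAMES)}
# Name -> integer id (inverse mapping)
label2id = {name: i for i, name in id2label.items()}

print(id2label[0], label2id["Rock"])  # Symphony 8
```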

Training Process:

Fine-tuned for 4 additional epochs

Loss reduced progressively across epochs

Performance Metrics

Epochs: 4

Training Loss: 0.2371

Validation Loss: 0.3278

Accuracy: 0.8812

F1 Score: 0.8809

Classification Report (test split, 3,638 examples):
                 precision    recall  f1-score   support

       Symphony     0.9372    0.9894    0.9626       377
          Opera     0.9923    0.9149    0.9520       141
           Solo     0.9778    0.9600    0.9688       275
        Chamber     0.9548    0.9651    0.9599       372
            Pop     0.9283    0.8240    0.8730       534
Dance_and_house     0.8742    0.9350    0.9036       446
          Indie     0.9309    0.8115    0.8671       382
    Soul_or_RnB     0.9366    0.8105    0.8690       401
           Rock     0.8067    0.9465    0.8710       710

       accuracy                         0.9041      3638
      macro avg     0.9265    0.9063    0.9141      3638
   weighted avg     0.9090    0.9041    0.9038      3638
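The summary rows of the report follow from the per-class rows:

- The macro average is the unweighted mean over the 9 classes.
- The weighted average weights each class by its support.
- Accuracy equals the support-weighted recall.

The supports sum to 3,638, the size of the test split. The following check uses the values above. Because they are rounded to 4 decimals, agreement is only to about 1e-4.

```python
# (precision, recall, f1, support) per class, copied from the report above
report = {
    "Symphony":        (0.9372, 0.9894, 0.9626, 377),
    "Opera":           (0.9923, 0.9149, 0.9520, 141),
    "Solo":            (0.9778, 0.9600, 0.9688, 275),
    "Chamber":         (0.9548, 0.9651, 0.9599, 372),
    "Pop":             (0.9283, 0.8240, 0.8730, 534),
    "Dance_and_house": (0.8742, 0.9350, 0.9036, 446),
    "Indie":           (0.9309, 0.8115, 0.8671, 382),
    "Soul_or_RnB":     (0.9366, 0.8105, 0.8690, 401),
    "Rock":            (0.8067, 0.9465, 0.8710, 710),
}

total_support = sum(s for *_, s in report.values())
# Macro average: plain mean over classes
macro_f1 = sum(f for _, _, f, _ in report.values()) / len(report)
# Weighted average: each class weighted by its support
weighted_f1 = sum(f * s for _, _, f, s in report.values()) / total_support
# Accuracy = correct / total = sum(recall * support) / total
accuracy = sum(r * s for _, r, _, s in report.values()) / total_support

print(total_support)                     # 3638
print(f"macro F1    {macro_f1:.4f}")     # 0.9141
print(f"weighted F1 {weighted_f1:.4f}")  # 0.9038
print(f"accuracy    {accuracy:.4f}")     # 0.9041
```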

Inference Example

import io
import time

import librosa
import librosa.display
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render spectrograms without a display
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

def load_model(model_path, num_classes=9, device=None):
    """Load the ResNet-18 checkpoint; run in FP16 on GPU and in FP32 on CPU."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    # FP16 inference on CPU is slow or unsupported in some PyTorch versions,
    # so the half-precision cast is only applied on CUDA devices
    if device.type == "cuda":
        model = model.half()
    model = model.to(device)
    model.eval()
    return model

def audio_to_mel_spectrogram(audio_path, sr=22050, n_mels=128, hop_length=512):
    """Convert an audio file to a mel spectrogram rendered as an RGB image."""
    y, sr = librosa.load(audio_path, sr=sr)
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, hop_length=hop_length)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    fig = plt.figure(figsize=(10, 4))
    librosa.display.specshow(mel_spec_db, y_axis='mel', x_axis='time', sr=sr, hop_length=hop_length)
    plt.axis('off')
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert('RGB')

def predict_genre(model_path, audio_path, class_mapping):
    """Return the top genre, the top-3 (genre, probability) pairs and the inference time."""
    model = load_model(model_path, num_classes=len(class_mapping))
    param = next(model.parameters())  # used to match the model's device and dtype
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    spectrogram = audio_to_mel_spectrogram(audio_path)
    image_tensor = transform(spectrogram).unsqueeze(0).to(device=param.device, dtype=param.dtype)

    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model(image_tensor)
        # Softmax in FP32 for numerical stability
        probabilities = torch.softmax(outputs.float(), dim=1)[0]
        top_probs, top_idx = torch.topk(probabilities, 3)
        predictions = [(class_mapping[idx.item()], prob.item()) for idx, prob in zip(top_idx, top_probs)]
    inference_time = time.perf_counter() - start_time

    return predictions[0][0], predictions, inference_time

# Example usage
if __name__ == "__main__":
    # Class indices assumed to follow the order of the class list above
    class_names = ["Symphony", "Opera", "Solo", "Chamber", "Pop",
                   "Dance_and_house", "Indie", "Soul_or_RnB", "Rock"]
    class_mapping = dict(enumerate(class_names))
    model_path = "path/to/quantized_model.pth"
    audio_path = "path/to/audio.wav"
    genre, predictions, inference_time = predict_genre(model_path, audio_path, class_mapping)
    print(f"Predicted Genre: {genre}")
    print(f"Top Predictions: {predictions}")
    print(f"Inference Time: {inference_time:.4f} seconds")

Quantization & Optimization

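Quantization is applied for faster inference with minimal loss in accuracy: weights are cast to half precision (FP16) with model.half().

FP16 storage halves the memory used by the model's parameters relative to FP32, which suits deployment on edge devices.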

Usage

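Input: Audio converted to a mel spectrogram image (as in the inference example); the dataset additionally provides cqt and chroma representations

Output: Predicted second-level genre (one of the 9 classes listed above), with top-3 class probabilities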

Limitations

Model performance may vary for underrepresented genres.

Requires sufficient high-quality input features for accurate predictions.

Future Improvements

Further fine-tuning on diverse datasets

Expansion to additional subgenres for better granularity

Implementation of real-time inference support
