jsd219's picture
Update app.py
1f7d3e4 verified
raw
history blame contribute delete
752 Bytes
import gradio as gr
from transformers import pipeline
import torch
# Hub checkpoint handed to the audio-classification pipeline.
# NOTE(review): this id looks like the base DistilHuBERT checkpoint, while the
# UI description says "fine-tuned on GTZAN" -- confirm it ships a genre head.
model_id = "ntu-spml/distilhubert"

# Pipeline device index: 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Built once at import time so every request reuses the same pipeline object.
pipe = pipeline("audio-classification", model=model_id, device=device)
def classify_audio(filepath):
    """Classify an audio file and report how long inference took.

    Args:
        filepath: Path of the audio file to classify, as supplied by the
            ``gr.Audio(type="filepath")`` input. May be ``None`` or empty
            when the form is submitted without any audio.

    Returns:
        A ``(scores, elapsed)`` tuple. ``scores`` maps each predicted label
        to its score rounded to 3 decimals; ``elapsed`` is the time spent in
        the pipeline call, in seconds, rounded to 2 decimals.

    Raises:
        gr.Error: If no audio file was provided.
    """
    import time

    # Fail with a readable message in the UI instead of letting the pipeline
    # raise an opaque error on an empty input.
    if not filepath:
        raise gr.Error("Please upload an audio file first.")

    # perf_counter is monotonic, so the measurement cannot go negative or
    # jump if the system clock is adjusted while inference is running.
    start = time.perf_counter()
    preds = pipe(filepath)
    elapsed = time.perf_counter() - start

    result = {p["label"]: round(p["score"], 3) for p in preds}
    return result, round(elapsed, 2)
# Input widget: hands the uploaded clip to the callback as a file path.
audio_input = gr.Audio(type="filepath", label="Upload Audio")

# Output widgets, in the order classify_audio returns its two values.
genre_output = gr.Label(label="Top Genres")
timing_output = gr.Number(label="Time (s)")

demo = gr.Interface(
    fn=classify_audio,
    inputs=audio_input,
    outputs=[genre_output, timing_output],
    title="🎵 Music Genre Classifier",
    description="Classifies the genre of uploaded audio using DistilHuBERT fine-tuned on GTZAN.",
)

# Start the Gradio server with its default settings when the module runs.
demo.launch()