File size: 1,957 Bytes
3ba39f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import tempfile

import soundfile as sf
import streamlit as st
import torch
import torchaudio

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification

@st.cache_resource
def _load_ced(name):
    """Load the CED feature extractor and classifier for checkpoint *name*.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one shared copy of the pretrained objects
    per process, so they are not reloaded from disk on each rerun.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    return (
        CedFeatureExtractor.from_pretrained(name),
        CedForAudioClassification.from_pretrained(name),
    )


# Checkpoint id passed to ``from_pretrained``.
model_name = "mispeech/ced-tiny"
feature_extractor, model = _load_ced(model_name)

# Sampling rate the audio is resampled to before feature extraction.
TARGET_SAMPLING_RATE = 16000


def _load_audio(path):
    """Read the audio file at *path*.

    Tries ``torchaudio.load`` first and falls back to ``soundfile``.

    Returns:
        ``(waveform, sampling_rate)``, where ``waveform`` is a float32 tensor
        shaped ``(channels, frames)`` on both code paths.
    """
    try:
        return torchaudio.load(path)
    except Exception:
        st.warning("Fallback to soundfile for audio loading.")
        # always_2d guarantees (frames, channels) even for mono files; float32
        # matches torchaudio's output dtype and the resampler's kernel dtype.
        data, sampling_rate = sf.read(path, dtype="float32", always_2d=True)
        # Transpose to the (channels, frames) layout torchaudio.load returns.
        return torch.tensor(data.T), sampling_rate


def _classify(audio, sampling_rate):
    """Return the predicted label for a ``(channels, frames)`` waveform."""
    if audio.shape[0] > 1:
        # Downmix to mono so exactly one prediction comes out; `.item()` below
        # only works on a single-element result. (Presumably the extractor
        # treats dim 0 as a batch axis — confirm against ced_model.)
        audio = audio.mean(dim=0, keepdim=True)

    if sampling_rate != TARGET_SAMPLING_RATE:
        st.warning(f"Resampling audio to {TARGET_SAMPLING_RATE} Hz...")
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=TARGET_SAMPLING_RATE
        )
        audio = resampler(audio)
        sampling_rate = TARGET_SAMPLING_RATE

    inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class_id]


st.title("Audio Classification App")
st.subheader("Trained on 50 classes of ESC 50 dataset")
st.write("Upload an audio file to predict its class.")

audio_file = st.file_uploader("Upload Audio File", type=["wav","mp3","m4a"])

if audio_file is not None:
    st.write(f"Uploaded file: {audio_file.name}")

    # Keep the real extension so decoders that look at it pick the right format.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    temp_file_path = None
    try:
        # A unique temp file per upload avoids collisions between concurrent
        # sessions that a fixed "temp.wav" would cause. delete=False lets the
        # file be reopened by name after closing (required on Windows).
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            temp_file_path = tmp.name
            # getvalue() returns the full upload regardless of stream position.
            tmp.write(audio_file.getvalue())

        audio, sampling_rate = _load_audio(temp_file_path)
        predicted_label = _classify(audio, sampling_rate)

        st.success(f"Predicted Class: {predicted_label}")
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        # Always clean up, including when loading or inference failed.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
else:
    st.info("Please upload a .wav, .mp3, or .m4a audio file to continue.")