import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.io.wavfile as wav
from scipy.fftpack import idct
import soundfile as sf
import gradio as gr
import os
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# CNN model: three conv/pool stages over a (1, 13, ~496) MFCC map, then two fully connected layers
class modele_CNN(nn.Module):
    def __init__(self, num_classes=8, dropout=0.3):
        super(modele_CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Three 2x2 poolings shrink the 13x496 map to 1x62, hence 64 * 1 * 62 flattened features.
        self.fc1 = nn.Linear(64 * 1 * 62, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * 1 * 62)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x
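
# A quick shape sanity check (an illustrative sketch, not part of the app flow): it confirms that
# the hard-coded fc1 input size matches the MFCC map produced by AudioProcessor below, assuming
# a 2-second clip at 16 kHz (496 frames of 13 coefficients):
#
#   dummy = torch.zeros(1, 1, 13, 496)
#   assert modele_CNN(num_classes=7)(dummy).shape == (1, 7)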

# Audio processor
class AudioProcessor:
    # Standard mel-scale conversions plus a helper mapping frequencies to FFT bin indices.
    def Mel2Hz(self, mel): return 700 * (np.power(10, mel/2595)-1)
    def Hz2Mel(self, freq): return 2595 * np.log10(1+freq/700)
    def Hz2Ind(self, freq, fs, Tfft): return (freq*Tfft/fs).astype(int)
    
    def hamming(self, T): 
        if T <= 1:
            return np.ones(T)
        return 0.54-0.46*np.cos(2*np.pi*np.arange(T)/(T-1))

    def FiltresMel(self, fs, nf=36, Tfft=512, fmin=100, fmax=8000):
        # nf Hamming-shaped bandpass filters spaced uniformly on the mel scale.
        Indices = self.Hz2Ind(self.Mel2Hz(np.linspace(self.Hz2Mel(fmin), self.Hz2Mel(min(fmax, fs/2)), nf+2)), fs, Tfft)
        filtres = np.zeros((int(Tfft/2), nf))
        for i in range(nf):
            filtres[Indices[i]:Indices[i+2], i] = self.hamming(Indices[i+2]-Indices[i])
        return filtres

    def spectrogram(self, x, T, p, Tfft):
        # Sliding Hamming windows of length T with hop p, then an FFT of size Tfft per frame.
        S = []
        for i in range(0, len(x)-T, p):
            S.append(x[i:i+T]*self.hamming(T))
        S = np.fft.fft(S, Tfft)
        return np.abs(S), np.angle(S)
    
    def mfcc(self, data, filtres, nc=13, T=256, p=64, Tfft=512):
        # data is a (fs, signal) pair: normalise the signal, take log mel-filterbank energies,
        # then reduce each frame to nc coefficients with an inverse DCT (as written).
        data = (data[1]-np.mean(data[1]))/np.std(data[1])
        amp, ph = self.spectrogram(data, T, p, Tfft)
        amp_f = np.log10(np.dot(amp[:, :int(Tfft/2)], filtres)+1)
        return idct(amp_f, n=nc, norm='ortho')

    def process_audio(self, audio_data, sr, audio_length=32000):
        # Stereo input would break the 1-D resampling below, so average the channels first.
        audio_data = np.asarray(audio_data)
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        
        # Naive linear-interpolation resampling to 16 kHz.
        if sr != 16000:
            audio_resampled = np.interp(
                np.linspace(0, len(audio_data), int(16000 * len(audio_data) / sr)),
                np.arange(len(audio_data)),
                audio_data
            )
            sgn = audio_resampled
            fs = 16000
        else:
            sgn = audio_data
            fs = sr
        
        sgn = np.array(sgn, dtype=np.float32)
        
        # Trim or zero-pad to exactly 2 seconds at 16 kHz.
        if len(sgn) > audio_length:
            sgn = sgn[:audio_length]
        else:
            sgn = np.pad(sgn, (0, audio_length - len(sgn)), mode='constant')
        
        filtres = self.FiltresMel(fs)
        sgn_features = self.mfcc([fs, sgn], filtres)
        
        # (frames, nc) -> (1, 1, nc, frames), the layout expected by the CNN.
        mfcc_tensor = torch.tensor(sgn_features.T, dtype=torch.float32)
        mfcc_tensor = mfcc_tensor.unsqueeze(0).unsqueeze(0)
        
        return mfcc_tensor
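
# A minimal, self-contained sketch of the feature pipeline on synthetic audio
# (illustrative only; the 440 Hz tone and its one-second length are arbitrary choices):
#
#   proc = AudioProcessor()
#   tone = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000).astype(np.float32)
#   feats = proc.process_audio(tone, 16000)
#   print(feats.shape)  # expected: torch.Size([1, 1, 13, 496])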

# Prediction function
def predict_speaker(audio, model, processor):
    if audio is None:
        return "No audio detected.", None
    
    try:
        # Gradio passes a file path; read the recording directly with soundfile.
        audio_data, sr = sf.read(audio)
        input_tensor = processor.process_audio(audio_data, sr)
        
        device = next(model.parameters()).device
        input_tensor = input_tensor.to(device)
        
        with torch.no_grad():
            output = model(input_tensor)
            print(output)
            probabilities = F.softmax(output, dim=1)
            confidence, predicted_class = torch.max(probabilities, 1)
        
        speakers = ["George", "Jackson", "Lucas", "Nicolas", "Theo", "Yweweler", "Narimene"]
        predicted_speaker = speakers[predicted_class.item()]
        
        result = f"Recognized speaker: {predicted_speaker} (confidence: {confidence.item()*100:.2f}%)"
        
        probs_dict = {speakers[i]: float(p) for i, p in enumerate(probabilities[0].cpu().numpy())}
        
        return result, probs_dict
    
    except Exception as e:
        return f"Erreur : {str(e)}", None

# Load a model checkpoint from the Hugging Face Hub
def load_model(model_id="nareauow/my_speech_recognition", model_filename="model_3.pth"):
    try:
        model_path = hf_hub_download(repo_id=model_id, filename=model_filename)
        model = modele_CNN(num_classes=7, dropout=0.)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
        print("Model loaded successfully!")
        return model
    except Exception as e:
        print(f"Loading error: {e}")
        return None
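
# End-to-end smoke test (a sketch, not wired into the app): load a checkpoint and classify a
# local recording. "sample.wav" is a placeholder path, and hf_hub_download needs network
# access on first use:
#
#   model = load_model(model_filename="model_3.pth")
#   if model is not None:
#       result, probs = predict_speaker("sample.wav", model, AudioProcessor())
#       print(result)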

# Gradio interface
def create_interface():
    processor = AudioProcessor()
    
    with gr.Blocks(title="Speaker Recognition") as interface:
        gr.Markdown("# 🗣️ Speaker Recognition")
        gr.Markdown("Record your voice for 2 seconds to identify who is speaking.")
        
        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(
                    choices=["model_1.pth", "model_2.pth", "model_3.pth"],
                    value="model_3.pth",
                    label="Choose the model"
                )
                audio_input = gr.Audio(sources=["microphone"], type="filepath", label="🎙️ Speak here")
                record_btn = gr.Button("Recognize")
            with gr.Column():
                result_text = gr.Textbox(label="Result")
                plot_output = gr.Plot(label="Confidence per speaker")
        
        def recognize(audio, selected_model):
            model = load_model(model_filename=selected_model)  # load the selected checkpoint
            res, probs = predict_speaker(audio, model, processor)
            fig = None
            if probs:
                fig, ax = plt.subplots()
                ax.bar(probs.keys(), probs.values(), color='skyblue')
                ax.set_ylim([0, 1])
                ax.set_ylabel("Confidence")
                ax.set_xlabel("Speakers")
                plt.xticks(rotation=45)
            return res, fig
        
        record_btn.click(fn=recognize, inputs=[audio_input, model_selector], outputs=[result_text, plot_output])
        
        gr.Markdown("""### How to use
        - Choose a model.
        - Click 🎙️ to record your voice.
        - Click **Recognize** to get the prediction.
        """)
    
    return interface

# Launch the app
if __name__ == "__main__":
    app = create_interface()
    app.launch(share=True)