DavidCombei committed (verified)
Commit 4613f1e · 1 Parent(s): 15d6143

Upload 2 files

Files changed (2)
  1. app.py +150 -0
  2. logreg_margin_pruning_ALL_best.joblib +3 -0
app.py ADDED
@@ -0,0 +1,150 @@
+ import joblib
+ from transformers import AutoFeatureExtractor, Wav2Vec2Model
+ import torch
+ import librosa
+ import numpy as np
+ from sklearn.linear_model import LogisticRegression
+ import gradio as gr
+ import os
+ from scipy.stats import mode
+
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ # truncate the SSL model after the 9th transformer layer, since we only need the first 9 layers
+ class CustomWav2Vec2Model(Wav2Vec2Model):
+     def __init__(self, config):
+         super().__init__(config)
+         self.encoder.layers = self.encoder.layers[:9]
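+         # note (added): keeping only the first 9 encoder layers saves memory and compute;
+         # with output_hidden_states=True the truncated model returns 10 hidden states
+         # (the feature-projection output plus one per remaining layer), so index 9 below
+         # is the output of the last retained layer.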
+
+
+ truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
+
+
+ # wrapper that calls the SSL model for feature extraction
+ class HuggingFaceFeatureExtractor:
+     def __init__(self, model, feature_extractor_name):
+         self.device = device
+         self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
+         self.model = model
+         self.model.eval()
+         self.model.to(self.device)
+
+     def __call__(self, audio, sr):
+         inputs = self.feature_extractor(
+             audio,
+             sampling_rate=sr,
+             return_tensors="pt",
+             padding=True,
+         )
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         with torch.no_grad():
+             outputs = self.model(**inputs, output_hidden_states=True)
+         return outputs.hidden_states[9]
+
+ FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
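+ # note (added): for a 30 s segment at 16 kHz, hidden_states[9] is a tensor of shape
+ # (1, n_frames, hidden_dim), roughly one frame per 20 ms of audio; hidden_dim is
+ # assumed to be 1920 for the xls-r-2b checkpoint.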
+
+ # load our best classifier
+ classifier = joblib.load('logreg_margin_pruning_ALL_best.joblib')
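+ # note (added): the joblib file is assumed to contain a scikit-learn LogisticRegression
+ # (hence the import above) trained on mean-pooled layer-9 xls-r-2b features.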
+
+ # segment audio and return the segments
+ def segment_audio(audio, sr, segment_duration):
+     segment_samples = int(segment_duration * sr)
+     total_samples = len(audio)
+     segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
+     return segments
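+ # example (added): a 70 s clip at 16 kHz (1,120,000 samples) with the default 30 s
+ # window yields three segments of 480,000, 480,000 and 160,000 samples; the last
+ # segment is simply shorter, no padding is applied here.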
+
+
+ # classification using the EER threshold
+ def classify_with_eer_threshold(probabilities, eer_thresh):
+     return (probabilities >= eer_thresh).astype(int)
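+ # note (added): the EER threshold is the operating point where false acceptance and
+ # false rejection rates are equal (presumably measured on a development set); scores
+ # at or above it map to class 1, which process_audio below reports as "Real".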
+
+
+ def process_audio(input_data, segment_duration=30):
+     # resample to 16 kHz, since xls-r-2b is trained on 16 kHz audio
+     audio, sr = librosa.load(input_data, sr=16000)
+
+     # keep only the first channel if the audio is multi-channel (xls-r-2b expects mono input)
+     if len(audio.shape) > 1:
+         audio = audio[0]
+
+     # segment the audio into 30 s chunks so xls-r-2b does not crash on long inputs
+     print('loaded file')
+     segments = segment_audio(audio, sr, segment_duration)
+     final_features = []
+     print('segments')
+
+     # extract the features from each 30 s segment and mean-pool them over time
+     for segment in segments:
+         features = FEATURE_EXTRACTOR(segment, sr)
+         features_avg = torch.mean(features, dim=1).cpu().numpy()
+         final_features.append(features_avg)
+     print('features extracted')
+     inference_prob = []
+     for feature in final_features:
+         # reshape to a single (1, n_features) row for the classifier
+         feature = feature.reshape(1, -1)
+         # run the classification
+         print(classifier.classes_)
+         probability = classifier.predict_proba(feature)[:, 1]
+         inference_prob.append(probability[0])
+     print('classifier predicted')
+     eer_threshold = 0.9999999996754046
+
+     # per-segment prediction based on the probability score and the EER threshold
+     y_pred_inference = classify_with_eer_threshold(np.array(inference_prob), eer_threshold)
+     print('inference done for segments')
+     # FINAL PREDICTION based on majority voting
+     mode_result = mode(y_pred_inference, keepdims=True)
+     final_prediction = mode_result.mode[0] if mode_result.mode.size > 0 else 0
+
+     print('majority voting done')
+     # confidence score (proportion of segments agreeing with the majority prediction)
+     confidence_score = np.mean(y_pred_inference == final_prediction) if len(y_pred_inference) > 0 else 1.0
+     confidence_percentage = confidence_score * 100
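+     # example (added): for a file split into 5 segments with per-segment predictions
+     # [1, 1, 0, 1, 1], the majority vote is 1 -> "Real" and the confidence is
+     # 4/5 = 80.0 (the percentage of segments that agree with the majority).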
+
+     return {
+         "Final classification": "Real" if final_prediction == 1 else "Fake",
+         "Confidence (%)": round(confidence_percentage, 2)
+     }
+
+
+ def gradio_interface(audio):
+     if audio:
+         return process_audio(audio)
+     else:
+         return "Please upload an audio file."
+
+ interface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[gr.Audio(type="filepath", label="Upload Audio")],
+     outputs="text",
+     title="SOL2 Audio Deepfake Detection Demo",
+     description="Upload an audio file to check whether it is AI-generated",
+ )
+
+ interface.launch(share=True)
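+ # note (added): share=True asks Gradio for a temporary public gradio.live link; when
+ # this app runs as a Hugging Face Space it is already served publicly, so the flag is
+ # generally unnecessary there.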
logreg_margin_pruning_ALL_best.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:255607c7582302af4e325bec91fb4a7de563880ad3f8f8832d9f65e288818cd6
+ size 16223