DavidCombei committed
Commit 879c4b9 · verified · Parent: 95e4f8f

Update app.py

Files changed (1): app.py (+100, -95)
app.py CHANGED
@@ -1,95 +1,100 @@
  import joblib
  from transformers import AutoFeatureExtractor, Wav2Vec2Model
  import torch
  import librosa
  import numpy as np
  from sklearn.linear_model import LogisticRegression
  import gradio as gr
  import os
  import torch.nn.functional as F
  from scipy.special import expit


  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  class CustomWav2Vec2Model(Wav2Vec2Model):
      def __init__(self, config):
          super().__init__(config)
          self.encoder.layers = self.encoder.layers[:9]

  truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")

  class HuggingFaceFeatureExtractor:
      def __init__(self, model, feature_extractor_name):
          self.device = device
          self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
          self.model = model
          self.model.eval()
          self.model.to(self.device)

      def __call__(self, audio, sr):
          inputs = self.feature_extractor(
              audio,
              sampling_rate=sr,
              return_tensors="pt",
              padding=True,
          )
          inputs = {k: v.to(self.device) for k, v in inputs.items()}
          with torch.no_grad():
              outputs = self.model(**inputs, output_hidden_states=True)
          return outputs.hidden_states[9]

  FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
  classifier, scaler, thresh = joblib.load('logreg_margin_pruning_ALL_with_scaler+threshold.joblib')

  def segment_audio(audio, sr, segment_duration):
      segment_samples = int(segment_duration * sr)
      total_samples = len(audio)
      segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
-     return segments
+     segments_check = []
+     for seg in segments:
+         # if the segment is shorter than 0.7 s, skip it to avoid complications inside wav2vec2
+         if len(seg) > 0.7 * sr:
+             segments_check.append(seg)
+     return segments_check

  def process_audio(input_data, segment_duration=10):
      audio, sr = librosa.load(input_data, sr=16000)
      if len(audio.shape) > 1:
          audio = audio[0]
      segments = segment_audio(audio, sr, segment_duration)
      segment_predictions = []
      output_lines = []
      eer_threshold = thresh - 5e-3  # small margin to absorb feature-extractor numerical differences
      for idx, segment in enumerate(segments):
          features = FEATURE_EXTRACTOR(segment, sr)
          features_avg = torch.mean(features, dim=1).cpu().numpy()
          features_avg = features_avg.reshape(1, -1)
          decision_score = classifier.decision_function(features_avg)
          decision_score_scaled = scaler.transform(decision_score.reshape(-1, 1)).flatten()
          decision_value = decision_score_scaled[0]
          pred = 1 if decision_value >= eer_threshold else 0
          if pred == 1:
              confidence_percentage = expit(decision_score).item()
          else:
              confidence_percentage = 1 - expit(decision_score).item()
          segment_predictions.append(pred)
          line = f"Segment {idx + 1}: {'Real' if pred == 1 else 'Fake'} (Confidence: {np.round(confidence_percentage*100, 2)}%)"
          output_lines.append(line)
      overall_prediction = 1 if sum(segment_predictions) > (len(segment_predictions) / 2) else 0
      overall_line = f"Overall Prediction: {'Real' if overall_prediction == 1 else 'Fake'}"
      output_str = overall_line + "\n" + "\n".join(output_lines)
      return output_str

  def gradio_interface(audio):
      if audio:
          return process_audio(audio)
      else:
          return "Please upload an audio file."

  interface = gr.Interface(
      fn=gradio_interface,
      inputs=[gr.Audio(type="filepath", label="Upload Audio")],
      outputs="text",
      title="SOL2 Audio Deepfake Detection Demo",
      description="Upload an audio file to check if it's AI-generated",
  )

  interface.launch(share=True)
  #
  #print(process_audio('SSL_scripts/1.wav'))
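
The functional change in this commit is the tail filter in `segment_audio`: chunks of 0.7 s or less are now dropped before feature extraction rather than being passed to wav2vec2. A minimal standalone sketch of the new behavior, using a hypothetical 20.5 s input:

```python
import numpy as np

def segment_audio(audio, sr, segment_duration):
    # same logic as the committed function, reproduced here for illustration
    segment_samples = int(segment_duration * sr)
    segments = [audio[i:i + segment_samples]
                for i in range(0, len(audio), segment_samples)]
    # new in this commit: keep only segments longer than 0.7 s
    return [seg for seg in segments if len(seg) > 0.7 * sr]

sr = 16000
audio = np.zeros(int(20.5 * sr))  # hypothetical 20.5 s clip
durations = [len(s) / sr for s in segment_audio(audio, sr, segment_duration=10)]
print(durations)  # [10.0, 10.0] -- the 0.5 s tail is dropped
```

Previously the 0.5 s tail would also have been scored; the "complications inside wav2vec2" the new comment mentions are presumably failures of the convolutional front-end on very short inputs.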
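For reference, the per-segment scoring path is unchanged: raw classifier margin, then a scaled margin compared against the EER threshold, with a sigmoid of the raw margin reported as confidence. A self-contained toy of that pipeline, using stand-in artifacts (a freshly fitted LogisticRegression, a StandardScaler, and a zero threshold, not the real joblib contents):

```python
import numpy as np
from scipy.special import expit
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))            # dummy features standing in for wav2vec2 embeddings
y = (X[:, 0] > 0).astype(int)             # dummy real/fake labels
clf = LogisticRegression().fit(X, y)
scaler = StandardScaler().fit(clf.decision_function(X).reshape(-1, 1))
thresh = 0.0                              # placeholder for the stored EER threshold

score = clf.decision_function(X[:1])                          # raw margin
scaled = scaler.transform(score.reshape(-1, 1)).flatten()[0]  # margin in scaler space
pred = 1 if scaled >= thresh - 5e-3 else 0                    # same small margin as app.py
conf = expit(score).item() if pred == 1 else 1 - expit(score).item()
print(pred, round(conf * 100, 2))
```

Note that, as in the app, the vote thresholds the scaled margin while the confidence is derived from the raw one; the toy mirrors that behavior.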