DavidCombei committed
Commit 85a8087 · verified · 1 Parent(s): 0f8a520

Upload app.py

Files changed (1): app.py +95 -0
app.py ADDED
import joblib
import librosa
import torch
import gradio as gr
from transformers import AutoFeatureExtractor, Wav2Vec2Model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CustomWav2Vec2Model(Wav2Vec2Model):
    """wav2vec2-xls-r-2b truncated to its first 9 transformer layers;
    deeper layers are unused because features are read from layer 9."""

    def __init__(self, config):
        super().__init__(config)
        self.encoder.layers = self.encoder.layers[:9]


truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")


class HuggingFaceFeatureExtractor:
    def __init__(self, model, feature_extractor_name):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio, sr):
        inputs = self.feature_extractor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # hidden_states[9] is the output of the 9th (last remaining) encoder layer
        return outputs.hidden_states[9]


FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
classifier, scaler, thresh = joblib.load('logreg_margin_pruning_ALL_with_scaler+threshold.joblib')


def segment_audio(audio, sr, segment_duration):
    """Split audio into consecutive chunks of segment_duration seconds."""
    segment_samples = int(segment_duration * sr)
    total_samples = len(audio)
    segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
    return segments


def process_audio(input_data, segment_duration=10):
    audio, sr = librosa.load(input_data, sr=16000)
    if len(audio.shape) > 1:  # keep only the first channel of multi-channel audio
        audio = audio[0]
    segments = segment_audio(audio, sr, segment_duration)
    segment_predictions = []
    output_lines = []
    eer_threshold = thresh - 5e-3  # small margin for feature-extractor space differences
    for idx, segment in enumerate(segments):
        features = FEATURE_EXTRACTOR(segment, sr)
        features_avg = torch.mean(features, dim=1).cpu().numpy()  # mean-pool over time
        features_avg = features_avg.reshape(1, -1)
        decision_score = classifier.decision_function(features_avg)
        decision_score_scaled = scaler.transform(decision_score.reshape(-1, 1)).flatten()
        decision_value = decision_score_scaled[0]
        pred = 1 if decision_value >= eer_threshold else 0
        # the displayed "confidence" is derived from the raw (unscaled) decision margin
        if pred == 1:
            confidence_percentage = decision_score.item()
        else:
            confidence_percentage = 1 - decision_score.item()
        segment_predictions.append(pred)
        line = f"Segment {idx + 1}: {'Real' if pred == 1 else 'Fake'} (Confidence: {round(confidence_percentage, 2)}%)"
        output_lines.append(line)
    # majority vote across segments decides the overall label
    overall_prediction = 1 if sum(segment_predictions) > (len(segment_predictions) / 2) else 0
    overall_line = f"Overall Prediction: {'Real' if overall_prediction == 1 else 'Fake'}"
    output_str = overall_line + "\n" + "\n".join(output_lines)
    return output_str


def gradio_interface(audio):
    if audio:
        return process_audio(audio)
    return "Please upload an audio file."


interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs="text",
    title="SOL2 Audio Deepfake Detection Demo",
    description="Upload an audio file to check if it's AI-generated",
)

interface.launch(share=True)

# print(process_audio('SSL_scripts/4.wav'))
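
For reference, app.py unpacks the joblib artifact into a (classifier, scaler, threshold) triple. The training code is not part of this commit, so the snippet below is only a minimal sketch of how such an artifact could be produced: the embeddings.npy / labels.npy inputs, the MinMaxScaler choice, and the placeholder threshold are all assumptions, not the repository's actual pipeline.

# Minimal sketch (assumptions throughout): build the (classifier, scaler,
# threshold) triple that app.py unpacks with joblib.load.
import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler

X_train = np.load('embeddings.npy')  # hypothetical mean-pooled layer-9 features
y_train = np.load('labels.npy')      # hypothetical labels: 1 = real, 0 = fake

classifier = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# app.py applies scaler.transform() to classifier.decision_function() output,
# so the scaler is fitted on the training-set decision margins here.
margins = classifier.decision_function(X_train).reshape(-1, 1)
scaler = MinMaxScaler().fit(margins)

# Placeholder threshold; in practice it would be chosen on a validation set
# (e.g. at the equal error rate point).
thresh = 0.5

joblib.dump((classifier, scaler, thresh),
            'logreg_margin_pruning_ALL_with_scaler+threshold.joblib')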