DavidCombei commited on
Commit
e066769
·
verified ·
1 Parent(s): e14b89d

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -97
app.py DELETED
@@ -1,97 +0,0 @@
1
- import joblib
2
- from transformers import AutoFeatureExtractor, Wav2Vec2Model
3
- import torch
4
- import librosa
5
- import numpy as np
6
- from sklearn.linear_model import LogisticRegression
7
- import gradio as gr
8
- import os
9
- import torch.nn.functional as F
10
-
11
-
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
-
14
- class CustomWav2Vec2Model(Wav2Vec2Model):
15
- def __init__(self, config):
16
- super().__init__(config)
17
- self.encoder.layers = self.encoder.layers[:9]
18
-
19
- truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
20
-
21
- class HuggingFaceFeatureExtractor:
22
- def __init__(self, model, feature_extractor_name):
23
- self.device = device
24
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
25
- self.model = model
26
- self.model.eval()
27
- self.model.to(self.device)
28
-
29
- def __call__(self, audio, sr):
30
- inputs = self.feature_extractor(
31
- audio,
32
- sampling_rate=sr,
33
- return_tensors="pt",
34
- padding=True,
35
- )
36
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
37
- with torch.no_grad():
38
- outputs = self.model(**inputs, output_hidden_states=True)
39
- return outputs.hidden_states[9]
40
-
41
- FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
42
- classifier,scaler, thresh = joblib.load('logreg_margin_pruning_ALL_with_scaler+threshold.joblib')
43
-
44
- def segment_audio(audio, sr, segment_duration):
45
- segment_samples = int(segment_duration * sr)
46
- total_samples = len(audio)
47
- segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
48
- return segments
49
-
50
- def process_audio(input_data, segment_duration=10):
51
- audio, sr = librosa.load(input_data, sr=16000)
52
- if len(audio.shape) > 1:
53
- audio = audio[0]
54
- segments = segment_audio(audio, sr, segment_duration)
55
- segment_predictions = []
56
- output_lines = []
57
- eer_threshold = thresh - 5e-3 # small margin error due to feature extractor space differences
58
- for idx, segment in enumerate(segments):
59
- features = FEATURE_EXTRACTOR(segment, sr)
60
- features_avg = torch.mean(features, dim=1).cpu().numpy()
61
- features_avg = features_avg.reshape(1, -1)
62
- decision_score = classifier.decision_function(features_avg)
63
- decision_score_scaled = scaler.transform(decision_score.reshape(-1, 1)).flatten()
64
- decision_value = decision_score_scaled[0]
65
- pred = 1 if decision_value >= eer_threshold else 0
66
- p_real = 1.0 / (1.0 + np.exp(-decision_value))
67
- if pred == 1:
68
- confidence_percentage = round(p_real * 100, 2)
69
-
70
- else:
71
- confidence_percentage = round((1 - p_real) * 100, 2)
72
-
73
- segment_predictions.append(pred)
74
- line = f"Segment {idx + 1}: {'Real' if pred == 1 else 'Fake'} (Confidence: {round(confidence_percentage, 2)}%)"
75
- output_lines.append(line)
76
- overall_prediction = 1 if sum(segment_predictions) > (len(segment_predictions) / 2) else 0
77
- overall_line = f"Overall Prediction: {'Real' if overall_prediction == 1 else 'Fake'}"
78
- output_str = overall_line + "\n" + "\n".join(output_lines)
79
- return output_str
80
-
81
- def gradio_interface(audio):
82
- if audio:
83
- return process_audio(audio)
84
- else:
85
- return "please upload an audio file"
86
-
87
- interface = gr.Interface(
88
- fn=gradio_interface,
89
- inputs=[gr.Audio(type="filepath", label="Upload Audio")],
90
- outputs="text",
91
- title="SOL2 Audio Deepfake Detection Demo",
92
- description="Upload an audio file to check if it's AI-generated",
93
- )
94
-
95
- interface.launch(share=True)
96
-
97
- #print(process_audio('SSL_scripts/4.wav'))