import gradio as gr
import torch
import numpy as np
import librosa
import os
from transformers import Wav2Vec2BertModel, AutoFeatureExtractor, HubertModel
import torch.nn as nn
from typing import Optional, Tuple
from transformers.file_utils import ModelOutput
from dataclasses import dataclass
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model
)
class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
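
# Note: the head is applied after temporal pooling in the models below; it maps
# an utterance-level embedding of shape (batch, hidden_size) to logits of shape
# (batch, num_labels).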
class Wav2Vec2ForSpeechClassification(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.num_labels = 2
        self.pooling_mode = 'mean'
        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained(model_name)
        self.config = self.wav2vec2bert.config
        self.classifier = Wav2Vec2ClassificationHead(self.wav2vec2bert.config)

    def merged_strategy(self, hidden_states, mode="mean"):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
        return outputs

    def forward(self, input_features, attention_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None, labels=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs.last_hidden_state
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )
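
# Minimal usage sketch (illustrative only; assumes a 16 kHz mono waveform `wav`
# as a 1-D numpy array, mirroring what AudioDeepfakeDetector does below):
#   extractor = AutoFeatureExtractor.from_pretrained('facebook/w2v-bert-2.0')
#   model = Wav2Vec2ForSpeechClassification('facebook/w2v-bert-2.0').eval()
#   inputs = extractor(wav, sampling_rate=16000, return_tensors="pt")
#   with torch.no_grad():
#       probs = model(**inputs).logits.softmax(dim=-1)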
class HuBERT(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.num_labels = 2
        self.pooling_mode = 'mean'
        self.wav2vec2 = HubertModel.from_pretrained(model_name)
        self.config = self.wav2vec2.config
        self.classifier = Wav2Vec2ClassificationHead(self.wav2vec2.config)

    def merged_strategy(self, hidden_states, mode="mean"):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
        return outputs

    def forward(self, input_values, attention_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None, labels=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs.last_hidden_state
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )
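
# Note: the HuBERT variant consumes raw 16 kHz waveforms as `input_values`
# (Wav2Vec2FeatureExtractor output), while Wav2Vec2ForSpeechClassification above
# consumes `input_features` from the w2v-BERT feature extractor. In both cases
# AutoFeatureExtractor returns the matching keyword, so model(**features) works.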
def pad(x, max_len=64000):
    x_len = x.shape[0]
    if x_len > max_len:
        stt = np.random.randint(x_len - max_len)
        return x[stt:stt + max_len]
        # return x[:max_len]
    # num_repeats = int(max_len / x_len) + 1
    # padded_x = np.tile(x, (num_repeats))[:max_len]
    pad_length = max_len - x_len
    padded_x = np.concatenate([x, np.zeros(pad_length)], axis=0)
    return padded_x
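
# max_len=64000 samples corresponds to a 4-second analysis window at 16 kHz:
# longer clips are randomly cropped and shorter clips are zero-padded.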
class AudioDeepfakeDetector:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.feature_extractors = {}
        self.current_model = None
        # model_name = 'facebook/w2v-bert-2.0'
        # self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        # self.model = Wav2Vec2ForSpeechClassification(model_name).to(self.device)
        # ckpt = torch.load("wave2vec2bert_wavefake.pth", map_location=self.device)
        # self.model.load_state_dict(ckpt)
        print(f"Using device: {self.device}")
        print("Audio deepfake detector initialized")
    def load_model(self, model_type):
        """Load the specified model type."""
        if model_type in self.models:
            self.current_model = model_type
            return
        try:
            print(f"Loading {model_type} model...")
            if model_type == "Wave2Vec2BERT":
                model_name = 'facebook/w2v-bert-2.0'
                self.feature_extractors[model_type] = AutoFeatureExtractor.from_pretrained(model_name)
                self.models[model_type] = Wav2Vec2ForSpeechClassification(model_name).to(self.device)
                # checkpoint_path = "wave2vec2bert_wavefake.pth"
                # if os.path.exists(checkpoint_path):
                #     ckpt = torch.load(checkpoint_path, map_location=self.device)
                #     self.models[model_type].load_state_dict(ckpt)
                #     print(f"Loaded checkpoint for {model_type}")
                # else:
                #     print(f"Checkpoint not found for {model_type}, using pretrained weights only")
                try:
                    from huggingface_hub import hf_hub_download
                    checkpoint_path = hf_hub_download(
                        repo_id="TrustSafeAI/AudioDeepfakeDetectors",
                        filename="wave2vec2bert_wavefake.pth",
                        cache_dir="./models"
                    )
                    ckpt = torch.load(checkpoint_path, map_location=self.device)
                    self.models[model_type].load_state_dict(ckpt)
                    print(f"Loaded checkpoint for {model_type}")
                except Exception as e:
                    print(f"Warning: could not load checkpoint for {model_type}: {e}")
                    print("Using pretrained weights only")
            elif model_type == "HuBERT":
                model_name = 'facebook/hubert-large-ls960-ft'
                self.feature_extractors[model_type] = AutoFeatureExtractor.from_pretrained(model_name)
                self.models[model_type] = HuBERT(model_name).to(self.device)
                # checkpoint_path = "hubert_large_wavefake.pth"
                # if os.path.exists(checkpoint_path):
                #     ckpt = torch.load(checkpoint_path, map_location=self.device)
                #     self.models[model_type].load_state_dict(ckpt)
                #     print(f"Loaded checkpoint for {model_type}")
                # else:
                #     print(f"Checkpoint not found for {model_type}, using pretrained weights only")
                try:
                    from huggingface_hub import hf_hub_download
                    checkpoint_path = hf_hub_download(
                        repo_id="TrustSafeAI/AudioDeepfakeDetectors",  # replace with your model repository if needed
                        filename="hubert_large_wavefake.pth",
                        cache_dir="./models"
                    )
                    ckpt = torch.load(checkpoint_path, map_location=self.device)
                    self.models[model_type].load_state_dict(ckpt)
                    print(f"Loaded checkpoint for {model_type}")
                except Exception as e:
                    print(f"Warning: could not load checkpoint for {model_type}: {e}")
                    print("Using pretrained weights only")
            # switch to inference mode so dropout layers are disabled
            self.models[model_type].eval()
            self.current_model = model_type
            print(f"{model_type} model loaded successfully")
        except Exception as e:
            print(f"Error loading {model_type} model: {str(e)}")
            raise
    def preprocess_audio(self, audio_path, target_sr=16000, max_length=4):
        try:
            print(f"Loading audio file: {os.path.basename(audio_path)}")
            audio, sr = librosa.load(audio_path, sr=target_sr)
            original_duration = len(audio) / sr
            audio = pad(audio).reshape(-1)
            audio = audio[np.newaxis, :]
            print(f"Audio loaded successfully: {original_duration:.2f}s, {sr}Hz")
            return audio, sr, original_duration
        except Exception as e:
            print(f"Audio processing error: {str(e)}")
            raise
    def extract_features(self, audio, sr, model_type):
        print("Extracting audio features...")
        feature_extractor = self.feature_extractors[model_type]
        inputs = feature_extractor(audio, sampling_rate=sr, return_attention_mask=True, padding_value=0, return_tensors="pt").to(self.device)
        print("Feature extraction completed")
        return inputs
    def classifier(self, features, model_type):
        model = self.models[model_type]
        with torch.no_grad():
            outputs = model(**features)
            prob = outputs.logits.softmax(dim=-1)
            fake_prob = prob[0][0].item()
        return fake_prob
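
    # Class index 0 of the softmax output is taken as the "fake" probability here,
    # matching the label convention assumed for the fine-tuned checkpoints.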
    def predict(self, audio_path, model_type):
        try:
            print("Starting analysis...")
            self.load_model(model_type)
            audio, sr, duration = self.preprocess_audio(audio_path)
            features = self.extract_features(audio, sr, model_type)
            fake_probability = self.classifier(features, model_type)
            real_probability = 1 - fake_probability
            threshold = 0.5
            if fake_probability > threshold:
                status = "SUSPICIOUS"
                prediction = "Likely fake audio"
                confidence = fake_probability
                color = "red"
            else:
                status = "LIKELY_REAL"
                prediction = "Likely real audio"
                confidence = real_probability
                color = "green"
            print(f"\n{'='*50}")
            print(f"Result: {prediction}")
            print(f"Confidence: {confidence:.1%}")
            print(f"Real Probability: {real_probability:.1%}")
            print(f"Fake Probability: {fake_probability:.1%}")
            print(f"{'='*50}")
            file_size = os.path.getsize(audio_path) / 1024
            result_data = {
                "status": status,
                "prediction": prediction,
                "confidence": confidence,
                "real_probability": real_probability,
                "fake_probability": fake_probability,
                "duration": duration,
                "sample_rate": sr,
                "file_size_kb": file_size,
                "model_used": model_type
            }
            return result_data
        except Exception as e:
            print(f"Analysis failed: {str(e)}")
            return {"error": str(e)}
detector = AudioDeepfakeDetector()
def analyze_uploaded_audio(audio_file, model_choice):
    if audio_file is None:
        return "Please upload an audio file", {}
    try:
        result = detector.predict(audio_file, model_choice)
        if "error" in result:
            return f"Error: {result['error']}", {}
        status_color = "#ff4444" if result['status'] == "SUSPICIOUS" else "#44ff44"
        result_html = f"""
        <div style="padding: 20px; border-radius: 10px; background-color: {status_color}20; border: 2px solid {status_color};">
            <h3 style="color: {status_color}; margin-top: 0;">{result['prediction']}</h3>
            <p><strong>Status:</strong> {result['status']}</p>
            <p><strong>Confidence:</strong> {result['confidence']:.1%}</p>
        </div>
        """
        analysis_data = {
            "status": result['status'],
            "real_probability": f"{result['real_probability']:.1%}",
            "fake_probability": f"{result['fake_probability']:.1%}",
        }
        return result_html, analysis_data
    except Exception as e:
        error_html = f"""
        <div style="padding: 20px; border-radius: 10px; background-color: #ff444420; border: 2px solid #ff4444;">
            <h3 style="color: #ff4444;">Processing error</h3>
            <p>{str(e)}</p>
        </div>
        """
        return error_html, {"error": str(e)}
def create_audio_interface():
    with gr.Blocks(title="Audio Deepfake Detection", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        <div style="text-align: center; margin-bottom: 30px;">
            <h1 style="font-size: 28px; font-weight: bold; margin-bottom: 20px; color: #333;">
                Measuring the Robustness of Audio Deepfake Detection under Real-World Corruptions
            </h1>
            <p style="font-size: 16px; color: #666; margin-bottom: 15px;">
                Audio deepfake detectors based on Wave2Vec2BERT and HuBERT speech foundation models (fine-tuned on the WaveFake dataset).
            </p>
            <div style="font-size: 14px; color: #555; line-height: 1.8; text-align: left;">
                <p><strong>Paper:</strong> <a href="https://arxiv.org/pdf/2503.17577" target="_blank" style="color: #4285f4; text-decoration: none;">https://arxiv.org/pdf/2503.17577</a></p>
                <p><strong>Project Page:</strong> <a href="https://huggingface.co/spaces/TrustSafeAI/AudioPerturber" target="_blank" style="color: #4285f4; text-decoration: none;">https://huggingface.co/spaces/TrustSafeAI/AudioPerturber</a></p>
                <p><strong>Model Checkpoints:</strong> <a href="https://huggingface.co/TrustSafeAI/AudioDeepfakeDetectors" target="_blank" style="color: #4285f4; text-decoration: none;">https://huggingface.co/TrustSafeAI/AudioDeepfakeDetectors</a></p>
                <p><strong>GitHub Codebase:</strong> <a href="https://github.com/Jessegator/Audio_robustness_evaluation" target="_blank" style="color: #4285f4; text-decoration: none;">https://github.com/Jessegator/Audio_robustness_evaluation</a></p>
            </div>
        </div>
        <hr style="margin: 30px 0; border: none; border-top: 1px solid #e0e0e0;">
        """)
        gr.Markdown("""
        # Audio Deepfake Detection
        **Supported formats**: .wav, .mp3, .flac, .m4a, etc.
        """)
        with gr.Row():
            # model_choice = gr.Dropdown(
            #     choices=["Wave2Vec2BERT", "HuBERT"],
            #     value="Wave2Vec2BERT",
            #     label="Select Model",
            #     info="Choose the foundation model for detection"
            # )
            with gr.Column(scale=1):
                model_choice = gr.Dropdown(
                    choices=["Wave2Vec2BERT", "HuBERT"],
                    value="Wave2Vec2BERT",
                    label="Select Model",
                    info="Choose the foundation model for detection"
                )
                audio_input = gr.Audio(
                    label="Upload audio file",
                    type="filepath",
                    show_label=True,
                    interactive=True
                )
                analyze_btn = gr.Button(
                    "Start analyzing",
                    variant="primary",
                    size="lg"
                )
                gr.Markdown("### Play uploaded audio")
                audio_player = gr.Audio(
                    label="Audio Player",
                    interactive=False,
                    show_label=False
                )
            with gr.Column(scale=1):
                result_display = gr.HTML(
                    label="Results",
                    value="<p style='text-align: center; color: #666;'>Waiting for upload...</p>"
                )
                analysis_json = gr.JSON(
                    label="Detailed analysis",
                    value={}
                )
        def update_player_and_analyze(audio_file, model_type):
            if audio_file is not None:
                result_html, result_data = analyze_uploaded_audio(audio_file, model_type)
                return audio_file, result_html, result_data
            else:
                return None, "<p style='text-align: center; color: #666;'>Waiting for upload...</p>", {}

        audio_input.change(
            fn=update_player_and_analyze,
            inputs=[audio_input, model_choice],
            outputs=[audio_player, result_display, analysis_json]
        )
        analyze_btn.click(
            fn=analyze_uploaded_audio,
            inputs=[audio_input, model_choice],
            outputs=[result_display, analysis_json]
        )
        model_choice.change(
            fn=lambda audio_file, model_type: analyze_uploaded_audio(audio_file, model_type) if audio_file is not None else ("Please upload an audio file first", {}),
            inputs=[audio_input, model_choice],
            outputs=[result_display, analysis_json]
        )
    return interface
if __name__ == "__main__": | |
print("π Create interface...") | |
demo = create_audio_interface() | |
print("π± Launching...") | |
demo.launch( | |
share=False, | |
debug=True, | |
show_error=True | |
) |
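
# Note: share=False keeps the demo local to the host; on Hugging Face Spaces the
# Space URL already exposes the app. Setting share=True would additionally create
# a temporary public gradio.live link.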