# tabs/deception_detection.py
import gradio as gr
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, find_peaks
from typing import Tuple, Optional, Dict
import logging
from dataclasses import dataclass
from enum import Enum
import librosa
import moviepy.editor as mp
import os
import tempfile
import torch
import torch.nn as nn
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
import mediapipe as mp_mediapipe
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Define Enums and DataClasses
class DeceptionLevel(Enum):
    LOW = 'Low'
    MODERATE = 'Moderate'
    HIGH = 'High'


@dataclass
class Metric:
    name: str
    threshold: float
    value: float = 0.0
    detected: bool = False
    invert: bool = False  # when True, values *below* the threshold count as detected

    def analyze(self, new_value: float):
        self.value = new_value
        if self.invert:
            self.detected = self.value < self.threshold
        else:
            self.detected = self.value > self.threshold


class SignalProcessor:
    def __init__(self, fs: float):
        self.fs = fs  # Sampling frequency

    def bandpass_filter(self, data: np.ndarray, lowcut: float = 0.75, highcut: float = 3.0) -> np.ndarray:
        """Apply bandpass filter to signal."""
        nyq = 0.5 * self.fs
        low = lowcut / nyq
        high = highcut / nyq
        b, a = butter(2, [low, high], btype='band')
        filtered = filtfilt(b, a, data)
        logger.debug("Applied bandpass filter.")
        return filtered

    def find_peaks_in_signal(self, signal: np.ndarray) -> np.ndarray:
        """Find peaks in the signal."""
        min_distance = int(60 / 180 * self.fs)  # Minimum peak spacing for a 180 BPM upper bound on heart rate
        peaks, _ = find_peaks(signal, distance=min_distance)
        logger.debug(f"Detected {len(peaks)} peaks in the signal.")
        return peaks


class DeceptionAnalyzer:
    def __init__(self):
        self.metrics = {
            # RMSSD below ~30 ms suggests suppressed HRV, so this metric is inverted
            "HRV Suppression": Metric("HRV Suppression", threshold=30.0, invert=True),
            "Heart Rate Elevation": Metric("Heart Rate Elevation", threshold=100.0),
            "Rhythm Irregularity": Metric("Rhythm Irregularity", threshold=0.1),
            "Blink Rate": Metric("Blink Rate", threshold=25.0),
            "Head Movements": Metric("Head Movements", threshold=10.0),
            "Speech Stress": Metric("Speech Stress", threshold=0.5),
            "Speech Pitch Variation": Metric("Speech Pitch Variation", threshold=50.0),
            "Pauses and Hesitations": Metric("Pauses and Hesitations", threshold=2.0),
            "Filler Words": Metric("Filler Words", threshold=5.0),
        }

    def analyze_signals(self, heart_rate: np.ndarray, rr_intervals: np.ndarray, hrv_rmssd: float,
                        speech_features: Dict[str, float], facial_features: Dict[str, float]) -> Tuple[Dict[str, Dict], float, DeceptionLevel]:
        """
        Analyze the extracted signals and compute deception probability.
        """
        # Analyze HRV Suppression
        self.metrics["HRV Suppression"].analyze(hrv_rmssd)
        # Analyze Heart Rate Elevation
        avg_heart_rate = np.mean(heart_rate)
        self.metrics["Heart Rate Elevation"].analyze(avg_heart_rate)
        # Analyze Rhythm Irregularity (coefficient of variation of RR intervals)
        mean_rr = np.mean(rr_intervals)
        rhythm_irregularity = np.std(rr_intervals) / mean_rr if mean_rr > 0 else 0.0
        self.metrics["Rhythm Irregularity"].analyze(rhythm_irregularity)
        # Analyze Speech Features
        for key in ["Speech Stress", "Speech Pitch Variation", "Pauses and Hesitations", "Filler Words"]:
            if key in speech_features:
                self.metrics[key].analyze(speech_features[key])
        # Analyze Facial Features (only if provided; per-frame analysis may already have set these metrics)
        for key in ["Blink Rate", "Head Movements"]:
            if key in facial_features:
                self.metrics[key].analyze(facial_features[key])
        # Calculate deception probability
        detected_indicators = sum(1 for m in self.metrics.values() if m.detected)
        total_indicators = len(self.metrics)
        probability = (detected_indicators / total_indicators) * 100
        # Determine deception level
        if probability < 30:
            level = DeceptionLevel.LOW
        elif probability < 70:
            level = DeceptionLevel.MODERATE
        else:
            level = DeceptionLevel.HIGH
        # Prepare metrics for visualization
        metrics_data = {name: {
            "value": m.value,
            "threshold": m.threshold,
            "detected": m.detected
        } for name, m in self.metrics.items()}
        return metrics_data, probability, level


def load_transcription_model(model_name: str) -> Optional[torch.nn.Module]:
    """
    Load the speech-to-text transcription model.
    """
    try:
        model = Wav2Vec2ForCTC.from_pretrained(
            model_name,
            ignore_mismatched_sizes=True
        )
        model.eval()
        logger.info("Transcription model loaded successfully.")
        return model
    except Exception as e:
        logger.error(f"Error loading transcription model: {e}")
        return None


def load_models() -> Dict[str, torch.nn.Module]:
    """
    Load all necessary models for the deception detection system.
    """
    models_dict = {}
    try:
        # Load Transcription Model
        transcription_model_name = 'facebook/wav2vec2-base-960h'
        transcription_model = load_transcription_model(transcription_model_name)
        if transcription_model:
            models_dict['transcription_model'] = transcription_model
    except Exception as e:
        logger.error(f"Error loading models: {e}")
    return models_dict


def transcribe_audio(audio_path: str, transcription_model: nn.Module) -> str:
    """
    Transcribe audio to text using Wav2Vec2 model.
    """
    try:
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
        y, sr = librosa.load(audio_path, sr=16000)
        input_values = tokenizer(y, return_tensors="pt", padding="longest").input_values
        with torch.no_grad():
            logits = transcription_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = tokenizer.decode(predicted_ids[0])
        # Clean transcription: lowercase and keep letters/whitespace only
        transcription = transcription.lower()
        transcription = re.sub(r'[^a-z\s]', '', transcription)
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing audio: {str(e)}")
        return ""


def detect_silence(y: np.ndarray, sr: int, top_db: int = 30) -> float:
    """
    Detect total duration of silence in the audio.
    """
    try:
        intervals = librosa.effects.split(y, top_db=top_db)
        silence_duration = 0.0
        prev_end = 0
        for start, end in intervals:
            silence = (start - prev_end) / sr
            silence_duration += silence
            prev_end = end
        # Add silence after the last interval
        silence_duration += (len(y) - prev_end) / sr
        return silence_duration
    except Exception as e:
        logger.error(f"Error detecting silence: {str(e)}")
        return 0.0


def count_filler_words(transcription: str) -> int:
    """
    Count the number of filler words in the transcription.
    """
    filler_words_list = ['um', 'uh', 'er', 'ah', 'like', 'you know', 'so']
    # Use word-boundary matching so multi-word fillers such as 'you know' are counted too
    return sum(len(re.findall(r'\b' + re.escape(word) + r'\b', transcription)) for word in filler_words_list)


def analyze_speech(audio_path: str, transcription_model: nn.Module) -> Dict[str, float]:
    """
    Analyze speech from the audio file and extract features.
    """
    if not audio_path:
        logger.warning("No audio path provided.")
        return {}
    try:
        # Load audio file
        y, sr = librosa.load(audio_path, sr=16000)  # Ensure consistent sampling rate
        logger.info(f"Loaded audio file with sampling rate: {sr} Hz")
        # Extract prosodic features
        pitches, magnitudes = librosa.piptrack(y=y, sr=sr)
        pitch_values = pitches[magnitudes > np.median(magnitudes)]
        avg_pitch = np.mean(pitch_values) if len(pitch_values) > 0 else 0.0
        pitch_variation = np.std(pitch_values) if len(pitch_values) > 0 else 0.0
        # Calculate speech stress based on pitch variation
        speech_stress = pitch_variation / (avg_pitch if avg_pitch != 0 else 1)
        # Calculate speech rate (words per minute)
        transcription = transcribe_audio(audio_path, transcription_model)
        words = transcription.split()
        duration_minutes = librosa.get_duration(y=y, sr=sr) / 60
        speech_rate = len(words) / duration_minutes if duration_minutes > 0 else 0.0
        # Detect pauses and hesitations
        silence_duration = detect_silence(y, sr)
        filler_words = count_filler_words(transcription)
        logger.info(f"Speech Analysis - Avg Pitch: {avg_pitch:.2f} Hz, Pitch Variation: {pitch_variation:.2f} Hz")
        logger.info(f"Speech Stress Level: {speech_stress:.2f}")
        logger.info(f"Speech Rate: {speech_rate:.2f} WPM")
        logger.info(f"Silence Duration: {silence_duration:.2f} seconds")
        logger.info(f"Filler Words Count: {filler_words}")
        # Return extracted features
        return {
            "Speech Stress": speech_stress,
            "Speech Pitch Variation": pitch_variation,
            "Pauses and Hesitations": silence_duration,
            "Filler Words": filler_words
        }
    except Exception as e:
        logger.error(f"Error analyzing speech: {str(e)}")
        return {}


def extract_audio_from_video(video_path: str) -> Optional[str]:
    """
    Extract audio from the video file and save it as a temporary WAV file.
    """
    if not video_path:
        logger.warning("No video path provided for audio extraction.")
        return None
    try:
        video_clip = mp.VideoFileClip(video_path)
        if video_clip.audio is None:
            logger.warning("No audio track found in the video.")
            video_clip.close()
            return None
        temp_audio_fd, temp_audio_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_audio_fd)  # Close the file descriptor
        video_clip.audio.write_audiofile(temp_audio_path, logger=None)
        video_clip.close()
        logger.info(f"Extracted audio to temporary file: {temp_audio_path}")
        return temp_audio_path
    except Exception as e:
        logger.error(f"Error extracting audio from video: {str(e)}")
        return None


def detect_blink(face_landmarks, frame: np.ndarray) -> float:
    """
    Detect blink rate from facial landmarks.
    Placeholder implementation.
    """
    # Implement Eye Aspect Ratio (EAR) or other blink detection methods
    return np.random.uniform(10, 20)  # Example blink rate
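

# Illustrative sketch only, not part of the original pipeline: a 6-point Eye Aspect
# Ratio (EAR) helper that could replace the random placeholder above. The MediaPipe
# FaceMesh landmark indices below are the commonly used left-eye points and should be
# treated as an assumption to verify against the mesh topology before relying on them.
def eye_aspect_ratio(face_landmarks, frame: np.ndarray) -> float:
    h, w = frame.shape[:2]
    idx = [33, 160, 158, 133, 153, 144]  # assumed p1..p6 indices for the left eye
    pts = np.array([(face_landmarks.landmark[i].x * w, face_landmarks.landmark[i].y * h) for i in idx])
    vertical = np.linalg.norm(pts[1] - pts[5]) + np.linalg.norm(pts[2] - pts[4])
    horizontal = np.linalg.norm(pts[0] - pts[3])
    return float(vertical / (2.0 * horizontal)) if horizontal > 0 else 0.0
# An EAR dropping below roughly 0.2 for a few consecutive frames is typically counted
# as one blink; blinks per minute would then feed the "Blink Rate" metric.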


def estimate_head_movement(face_landmarks) -> float:
    """
    Estimate head movements based on facial landmarks.
    Placeholder implementation.
    """
    # Implement head pose estimation to detect nods/shakes
    return np.random.uniform(5, 15)  # Example head movements
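

# Illustrative sketch only (an assumption, not the original method): approximate head
# movement by tracking frame-to-frame displacement of the nose-tip landmark. A stateful
# tracker like this would be created once per video and updated with every frame.
class NoseTipTracker:
    NOSE_TIP_IDX = 1  # assumed MediaPipe FaceMesh nose-tip index

    def __init__(self):
        self.prev: Optional[Tuple[float, float]] = None

    def update(self, face_landmarks, frame: np.ndarray) -> float:
        h, w = frame.shape[:2]
        lm = face_landmarks.landmark[self.NOSE_TIP_IDX]
        cur = (lm.x * w, lm.y * h)
        movement = 0.0 if self.prev is None else float(np.hypot(cur[0] - self.prev[0], cur[1] - self.prev[1]))
        self.prev = cur
        return movement  # pixels moved since the previous frame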


def create_visualization(metrics: Dict, probability: float, heart_rate: np.ndarray,
                         duration: float, level: DeceptionLevel, speech_features: Dict[str, float]) -> plt.Figure:
    """
    Create visualization of analysis results.
    """
    # Set figure style parameters
    plt.style.use('default')
    plt.rcParams.update({
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
        'grid.color': '#E0E0E0',
        'grid.linestyle': '-',
        'grid.alpha': 0.3,
        'font.size': 10,
        'axes.labelsize': 10,
        'axes.titlesize': 12,
        'figure.titlesize': 14,
        'font.family': ['DejaVu Sans', 'Arial', 'sans-serif']
    })
    # Create figure and axes
    fig = plt.figure(figsize=(12, 20))
    # Create polar plot for deception probability gauge
    ax1 = fig.add_subplot(4, 1, 1, projection='polar')
    # Create other subplots
    ax2 = fig.add_subplot(4, 1, 2)
    ax3 = fig.add_subplot(4, 1, 3)
    ax4 = fig.add_subplot(4, 1, 4)
    # Plot 1: Deception Probability Gauge
    theta = np.linspace(0, np.pi, 100)
    radius = np.ones(100)
    ax1.plot(theta, radius, color='#E0E0E0', linewidth=30, alpha=0.3)
    current_angle = (probability / 100) * np.pi
    ax1.plot([0, current_angle], [0, 0.7], color='red', linewidth=5)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_title(f'Deception Probability: {probability:.1f}% ({level.value})', pad=20, color='#333333')
    ax1.set_theta_zero_location('N')
    ax1.set_facecolor('white')
    ax1.grid(False)
    ax1.spines['polar'].set_visible(False)
    # Plot 2: Metrics Bar Chart
    names = list(metrics.keys())
    values = [m["value"] for m in metrics.values()]
    thresholds = [m["threshold"] for m in metrics.values()]
    detected = [m["detected"] for m in metrics.values()]
    x = np.arange(len(names))
    width = 0.35
    bar_colors = ['#FF6B6B' if d else '#4BB543' for d in detected]
    ax2.bar(x - width / 2, values, width, label='Current', color=bar_colors)
    ax2.bar(x + width / 2, thresholds, width, label='Threshold', color='#E0E0E0', alpha=0.7)
    ax2.set_ylabel('Value')
    ax2.set_title('Physiological, Facial, and Speech Indicators', pad=20)
    ax2.set_xticks(x)
    ax2.set_xticklabels(names, rotation=45, ha='right')
    ax2.grid(True, axis='y', alpha=0.3)
    ax2.legend(loc='upper right', framealpha=0.9)
    # Plot 3: Heart Rate Over Time
    time_axis = np.linspace(0, duration, len(heart_rate))
    ax3.plot(time_axis, heart_rate, color='#3498db')
    ax3.set_xlabel('Time (s)')
    ax3.set_ylabel('Heart Rate (BPM)')
    ax3.set_title('Heart Rate Over Time', pad=20)
    ax3.grid(True, alpha=0.3)
    # Plot 4: Speech Features
    pauses = speech_features.get("Pauses and Hesitations", 0)
    filler_words = speech_features.get("Filler Words", 0)
    labels = ['Pauses (s)', 'Filler Words (count)']
    values = [pauses, filler_words]
    colors = ['#FFC300', '#FF5733']
    ax4.bar(labels, values, color=colors)
    ax4.set_ylabel('Count / Duration')
    ax4.set_title('Pauses and Hesitations in Speech', pad=20)
    ax4.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    return fig


def process_video_and_audio(video_path: str, models: Dict[str, torch.nn.Module]) -> Tuple[Optional[np.ndarray], Optional[plt.Figure]]:
    """
    Process video and audio, perform deception analysis.
    """
    logger.info("Starting video and audio processing.")
    if not video_path:
        logger.warning("No video path provided.")
        return None, None
    try:
        # Extract audio from video
        audio_path = extract_audio_from_video(video_path)
        if not audio_path:
            logger.warning("No audio available for speech analysis.")
        # Initialize video capture
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error("Failed to open video file.")
            return None, None
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0 or fps != fps:  # 'fps != fps' is True only when fps is NaN
            logger.error("Invalid frame rate detected.")
            cap.release()
            return None, None
        logger.info(f"Video FPS: {fps}")
        # Initialize processors
        signal_processor = SignalProcessor(fps)
        analyzer = DeceptionAnalyzer()
        ppg_signal = []
        last_frame = None
        # Initialize Mediapipe for real-time facial feature extraction
        mp_face_mesh = mp_mediapipe.solutions.face_mesh
        face_mesh = mp_face_mesh.FaceMesh(static_image_mode=False, max_num_faces=1)
        frame_counter = 0
        # Process video frames
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_counter += 1
            # Extract PPG signal from green channel
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            green_channel = frame_rgb[:, :, 1]
            ppg_signal.append(np.mean(green_channel))
            # Extract facial features
            results = face_mesh.process(frame_rgb)
            if results.multi_face_landmarks:
                face_landmarks = results.multi_face_landmarks[0]
                # Blink Detection
                blink = detect_blink(face_landmarks, frame)
                analyzer.metrics["Blink Rate"].analyze(blink)
                # Head Movement Detection
                head_movement = estimate_head_movement(face_landmarks)
                analyzer.metrics["Head Movements"].analyze(head_movement)
            else:
                analyzer.metrics["Blink Rate"].analyze(0.0)
                analyzer.metrics["Head Movements"].analyze(0.0)
            # Store last frame
            last_frame = cv2.resize(frame_rgb, (320, 240))
            # Optional: Log progress every 100 frames
            if frame_counter % 100 == 0:
                logger.info(f"Processed {frame_counter} frames.")
        cap.release()
        face_mesh.close()
        logger.info(f"Total frames processed: {frame_counter}")
        if not ppg_signal or last_frame is None:
            logger.error("No PPG signal extracted or last frame missing.")
            return last_frame, None
        # Convert PPG signal to numpy array
        ppg_signal = np.array(ppg_signal)
        logger.debug("PPG signal extracted.")
        # Apply bandpass filter
        filtered_signal = signal_processor.bandpass_filter(ppg_signal)
        logger.debug("Filtered PPG signal.")
        # Find peaks in the filtered signal
        peaks = signal_processor.find_peaks_in_signal(filtered_signal)
        if len(peaks) < 2:
            logger.warning("Insufficient peaks detected. Signal quality may be poor.")
            return last_frame, None  # Return last_frame but no analysis
        # Calculate RR intervals in milliseconds and instantaneous heart rate in BPM
        rr_intervals = np.diff(peaks) / fps * 1000  # ms
        heart_rate = 60 * fps / np.diff(peaks)  # BPM
        if len(rr_intervals) == 0 or len(heart_rate) == 0:
            logger.error("Failed to calculate RR intervals or heart rate.")
            return last_frame, None
        # Calculate RMSSD (Root Mean Square of Successive Differences)
        hrv_rmssd = np.sqrt(np.mean(np.diff(rr_intervals) ** 2))
        logger.debug(f"Calculated RMSSD: {hrv_rmssd:.2f} ms")
        # Analyze speech
        if audio_path and 'transcription_model' in models:
            speech_features = analyze_speech(audio_path, models['transcription_model'])
        else:
            speech_features = {}
        # Analyze signals
        metrics, probability, level = analyzer.analyze_signals(
            heart_rate, rr_intervals, hrv_rmssd, speech_features,
            {}
        )
        # Create visualization
        duration = len(ppg_signal) / fps  # seconds
        fig = create_visualization(
            metrics, probability, heart_rate,
            duration, level, speech_features
        )
        # Clean up temporary audio file if it was extracted
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
                logger.info(f"Deleted temporary audio file: {audio_path}")
            except Exception as e:
                logger.error(f"Error deleting temporary audio file: {str(e)}")
        logger.info("Video and audio processing completed successfully.")
        return last_frame, fig
    except Exception as e:
        logger.error(f"Error processing video and audio: {str(e)}")
        return None, None


def create_deception_detection_tab(models: Dict[str, torch.nn.Module]) -> gr.Blocks:
    """
    Create the deception detection interface tab using Gradio.
    """
    def analyze(video):
        try:
            if video is None:
                return None, None
            video_path = video
            logger.info(f"Received video for analysis: {video_path}")
            if not os.path.exists(video_path):
                logger.error("Video file does not exist.")
                return None, None
            last_frame, fig = process_video_and_audio(video_path, models)
            if fig:
                return last_frame, fig
            else:
                return last_frame, None
        except Exception as e:
            logger.error(f"Error in analyze function: {str(e)}")
            return None, None

    with gr.Blocks() as deception_interface:
        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video()
                gr.Examples(
                    ["./assets/videos/fitness.mp4", "./assets/videos/vladirmir.mp4", "./assets/videos/lula.mp4"],
                    inputs=[input_video]
                )
                gr.Markdown("""
                ### Deception Level Analysis
                This analysis evaluates physiological, facial, and speech indicators
                that may suggest deceptive behavior.

                **Physiological Indicators:**
                - HRV Suppression
                - Heart Rate Elevation
                - Rhythm Irregularity

                **Facial Indicators:**
                - Blink Rate
                - Head Movements

                **Speech Indicators:**
                - Speech Stress
                - Speech Pitch Variation
                - Pauses and Hesitations
                - Filler Words

                **Interpretation:**
                - **Low (0-30%):** Minimal indicators
                - **Moderate (30-70%):** Some indicators
                - **High (>70%):** Strong indicators

                **Important Note:**
                This analysis is for research purposes only.
                Results should not be used as definitive proof
                of deception or truthfulness.
                """)
            with gr.Column(scale=2):
                output_frame = gr.Image(label="Last Frame of Video", height=240)
                analysis_plot = gr.Plot(label="Deception Analysis")
        # Configure automatic analysis upon video upload
        input_video.change(
            fn=analyze,
            inputs=[input_video],
            outputs=[output_frame, analysis_plot]
        )
    return deception_interface