# tabs/deception_detection.py import gradio as gr import cv2 import numpy as np import matplotlib.pyplot as plt from scipy.signal import butter, filtfilt, find_peaks from typing import Tuple, Optional, Dict import logging from dataclasses import dataclass from enum import Enum import librosa import moviepy.editor as mp import os import tempfile import torch import torch.nn as nn from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer import mediapipe as mp_mediapipe import re # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Define Enums and DataClasses class DeceptionLevel(Enum): LOW = 'Low' MODERATE = 'Moderate' HIGH = 'High' @dataclass class Metric: name: str threshold: float value: float = 0.0 detected: bool = False def analyze(self, new_value: float): self.value = new_value self.detected = self.value > self.threshold class SignalProcessor: def __init__(self, fs: float): self.fs = fs # Sampling frequency def bandpass_filter(self, data: np.ndarray, lowcut: float = 0.75, highcut: float = 3.0) -> np.ndarray: """Apply bandpass filter to signal.""" nyq = 0.5 * self.fs low = lowcut / nyq high = highcut / nyq b, a = butter(2, [low, high], btype='band') filtered = filtfilt(b, a, data) logger.debug("Applied bandpass filter.") return filtered def find_peaks_in_signal(self, signal: np.ndarray) -> np.ndarray: """Find peaks in the signal.""" min_distance = int(60 / 180 * self.fs) # At least 60 BPM (180 BPM max) peaks, _ = find_peaks(signal, distance=min_distance) logger.debug(f"Detected {len(peaks)} peaks in the signal.") return peaks class DeceptionAnalyzer: def __init__(self): self.metrics = { "HRV Suppression": Metric("HRV Suppression", threshold=30.0), "Heart Rate Elevation": Metric("Heart Rate Elevation", threshold=100.0), "Rhythm Irregularity": Metric("Rhythm Irregularity", threshold=0.1), "Blink Rate": Metric("Blink Rate", threshold=25.0), "Head Movements": Metric("Head Movements", threshold=10.0), "Speech Stress": Metric("Speech Stress", threshold=0.5), "Speech Pitch Variation": Metric("Speech Pitch Variation", threshold=50.0), "Pauses and Hesitations": Metric("Pauses and Hesitations", threshold=2.0), "Filler Words": Metric("Filler Words", threshold=5.0), } def analyze_signals(self, heart_rate: np.ndarray, rr_intervals: np.ndarray, hrv_rmssd: float, speech_features: Dict[str, float], facial_features: Dict[str, float]) -> Tuple[Dict[str, Dict], float, DeceptionLevel]: """ Analyze the extracted signals and compute deception probability. """ # Analyze HRV Suppression self.metrics["HRV Suppression"].analyze(hrv_rmssd) # Analyze Heart Rate Elevation avg_heart_rate = np.mean(heart_rate) self.metrics["Heart Rate Elevation"].analyze(avg_heart_rate) # Analyze Rhythm Irregularity rhythm_irregularity = np.std(rr_intervals) / np.mean(rr_intervals) self.metrics["Rhythm Irregularity"].analyze(rhythm_irregularity) # Analyze Speech Features for key in ["Speech Stress", "Speech Pitch Variation", "Pauses and Hesitations", "Filler Words"]: if key in speech_features: self.metrics[key].analyze(speech_features[key]) # Analyze Facial Features # Placeholder values; in actual implementation, replace with real values self.metrics["Blink Rate"].analyze(facial_features.get("Blink Rate", 0)) self.metrics["Head Movements"].analyze(facial_features.get("Head Movements", 0)) # Calculate deception probability detected_indicators = sum(1 for m in self.metrics.values() if m.detected) total_indicators = len(self.metrics) probability = (detected_indicators / total_indicators) * 100 # Determine deception level if probability < 30: level = DeceptionLevel.LOW elif probability < 70: level = DeceptionLevel.MODERATE else: level = DeceptionLevel.HIGH # Prepare metrics for visualization metrics_data = {name: { "value": m.value, "threshold": m.threshold, "detected": m.detected } for name, m in self.metrics.items()} return metrics_data, probability, level def load_transcription_model(model_name: str) -> Optional[torch.nn.Module]: """ Load the speech-to-text transcription model. """ try: model = Wav2Vec2ForCTC.from_pretrained( model_name, ignore_mismatched_sizes=True ) model.eval() logger.info("Transcription model loaded successfully.") return model except Exception as e: logger.error(f"Error loading transcription model: {e}") return None def load_models() -> Dict[str, torch.nn.Module]: """ Load all necessary models for the deception detection system. """ models_dict = {} try: # Load Transcription Model transcription_model_name = 'facebook/wav2vec2-base-960h' transcription_model = load_transcription_model(transcription_model_name) if transcription_model: models_dict['transcription_model'] = transcription_model except Exception as e: logger.error(f"Error loading models: {e}") return models_dict def transcribe_audio(audio_path: str, transcription_model: nn.Module) -> str: """ Transcribe audio to text using Wav2Vec2 model. """ try: tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") y, sr = librosa.load(audio_path, sr=16000) input_values = tokenizer(y, return_tensors="pt", padding="longest").input_values with torch.no_grad(): logits = transcription_model(input_values).logits predicted_ids = torch.argmax(logits, dim=-1) transcription = tokenizer.decode(predicted_ids[0]) # Clean transcription transcription = transcription.lower() transcription = re.sub(r'[^a-z\s]', '', transcription) return transcription except Exception as e: logger.error(f"Error transcribing audio: {str(e)}") return "" def detect_silence(y: np.ndarray, sr: int, top_db: int = 30) -> float: """ Detect total duration of silence in the audio. """ try: intervals = librosa.effects.split(y, top_db=top_db) silence_duration = 0.0 prev_end = 0 for start, end in intervals: silence = (start - prev_end) / sr silence_duration += silence prev_end = end # Add silence after the last interval silence_duration += (len(y) - prev_end) / sr return silence_duration except Exception as e: logger.error(f"Error detecting silence: {str(e)}") return 0.0 def count_filler_words(transcription: str) -> int: """ Count the number of filler words in the transcription. """ filler_words_list = ['um', 'uh', 'er', 'ah', 'like', 'you know', 'so'] return sum(transcription.split().count(word) for word in filler_words_list) def analyze_speech(audio_path: str, transcription_model: nn.Module) -> Dict[str, float]: """ Analyze speech from the audio file and extract features. """ if not audio_path: logger.warning("No audio path provided.") return {} try: # Load audio file y, sr = librosa.load(audio_path, sr=16000) # Ensure consistent sampling rate logger.info(f"Loaded audio file with sampling rate: {sr} Hz") # Extract prosodic features pitches, magnitudes = librosa.piptrack(y=y, sr=sr) pitch_values = pitches[magnitudes > np.median(magnitudes)] avg_pitch = np.mean(pitch_values) if len(pitch_values) > 0 else 0.0 pitch_variation = np.std(pitch_values) if len(pitch_values) > 0 else 0.0 # Calculate speech stress based on pitch variation speech_stress = pitch_variation / (avg_pitch if avg_pitch != 0 else 1) # Calculate speech rate (words per minute) transcription = transcribe_audio(audio_path, transcription_model) words = transcription.split() duration_minutes = librosa.get_duration(y=y, sr=sr) / 60 speech_rate = len(words) / duration_minutes if duration_minutes > 0 else 0.0 # Detect pauses and hesitations silence_duration = detect_silence(y, sr) filler_words = count_filler_words(transcription) logger.info(f"Speech Analysis - Avg Pitch: {avg_pitch:.2f} Hz, Pitch Variation: {pitch_variation:.2f} Hz") logger.info(f"Speech Stress Level: {speech_stress:.2f}") logger.info(f"Speech Rate: {speech_rate:.2f} WPM") logger.info(f"Silence Duration: {silence_duration:.2f} seconds") logger.info(f"Filler Words Count: {filler_words}") # Return extracted features return { "Speech Stress": speech_stress, "Speech Pitch Variation": pitch_variation, "Pauses and Hesitations": silence_duration, "Filler Words": filler_words } except Exception as e: logger.error(f"Error analyzing speech: {str(e)}") return {} def extract_audio_from_video(video_path: str) -> Optional[str]: """ Extract audio from the video file and save it as a temporary WAV file. """ if not video_path: logger.warning("No video path provided for audio extraction.") return None try: video_clip = mp.VideoFileClip(video_path) if video_clip.audio is None: logger.warning("No audio track found in the video.") video_clip.close() return None temp_audio_fd, temp_audio_path = tempfile.mkstemp(suffix=".wav") os.close(temp_audio_fd) # Close the file descriptor video_clip.audio.write_audiofile(temp_audio_path, logger=None) video_clip.close() logger.info(f"Extracted audio to temporary file: {temp_audio_path}") return temp_audio_path except Exception as e: logger.error(f"Error extracting audio from video: {str(e)}") return None def detect_blink(face_landmarks, frame: np.ndarray) -> float: """ Detect blink rate from facial landmarks. Placeholder implementation. """ # Implement Eye Aspect Ratio (EAR) or other blink detection methods return np.random.uniform(10, 20) # Example blink rate def estimate_head_movement(face_landmarks) -> float: """ Estimate head movements based on facial landmarks. Placeholder implementation. """ # Implement head pose estimation to detect nods/shakes return np.random.uniform(5, 15) # Example head movements def create_visualization(metrics: Dict, probability: float, heart_rate: np.ndarray, duration: float, level: DeceptionLevel, speech_features: Dict[str, float]) -> plt.Figure: """ Create visualization of analysis results. """ # Set figure style parameters plt.style.use('default') plt.rcParams.update({ 'figure.facecolor': 'white', 'axes.facecolor': 'white', 'grid.color': '#E0E0E0', 'grid.linestyle': '-', 'grid.alpha': 0.3, 'font.size': 10, 'axes.labelsize': 10, 'axes.titlesize': 12, 'figure.titlesize': 14, 'font.family': ['DejaVu Sans', 'Arial', 'sans-serif'] }) # Create figure and axes fig = plt.figure(figsize=(12, 20)) # Create polar plot for deception probability gauge ax1 = fig.add_subplot(4, 1, 1, projection='polar') # Create other subplots ax2 = fig.add_subplot(4, 1, 2) ax3 = fig.add_subplot(4, 1, 3) ax4 = fig.add_subplot(4, 1, 4) # Plot 1: Deception Probability Gauge # Create gauge plot theta = np.linspace(0, np.pi, 100) radius = np.ones(100) ax1.plot(theta, radius, color='#E0E0E0', linewidth=30, alpha=0.3) current_angle = (probability / 100) * np.pi ax1.plot([0, current_angle], [0, 0.7], color='red', linewidth=5) ax1.set_xticks([]) ax1.set_yticks([]) ax1.set_title(f'Deception Probability: {probability:.1f}% ({level.value})', pad=20, color='#333333') ax1.set_theta_zero_location('N') ax1.set_facecolor('white') ax1.grid(False) ax1.spines['polar'].set_visible(False) # Plot 2: Metrics Bar Chart names = list(metrics.keys()) values = [m["value"] for m in metrics.values()] thresholds = [m["threshold"] for m in metrics.values()] detected = [m["detected"] for m in metrics.values()] x = np.arange(len(names)) width = 0.35 bar_colors = ['#FF6B6B' if d else '#4BB543' for d in detected] ax2.bar(x - width/2, values, width, label='Current', color=bar_colors) ax2.bar(x + width/2, thresholds, width, label='Threshold', color='#E0E0E0', alpha=0.7) ax2.set_ylabel('Value') ax2.set_title('Physiological, Facial, and Speech Indicators', pad=20) ax2.set_xticks(x) ax2.set_xticklabels(names, rotation=45, ha='right') ax2.grid(True, axis='y', alpha=0.3) ax2.legend(loc='upper right', framealpha=0.9) # Plot 3: Heart Rate Over Time time_axis = np.linspace(0, duration, len(heart_rate)) ax3.plot(time_axis, heart_rate, color='#3498db') ax3.set_xlabel('Time (s)') ax3.set_ylabel('Heart Rate (BPM)') ax3.set_title('Heart Rate Over Time', pad=20) ax3.grid(True, alpha=0.3) # Plot 4: Speech Features pauses = speech_features.get("Pauses and Hesitations", 0) filler_words = speech_features.get("Filler Words", 0) labels = ['Pauses (s)', 'Filler Words (count)'] values = [pauses, filler_words] colors = ['#FFC300', '#FF5733'] ax4.bar(labels, values, color=colors) ax4.set_ylabel('Count / Duration') ax4.set_title('Pauses and Hesitations in Speech', pad=20) ax4.grid(True, axis='y', alpha=0.3) plt.tight_layout() return fig def process_video_and_audio(video_path: str, models: Dict[str, torch.nn.Module]) -> Tuple[Optional[np.ndarray], Optional[plt.Figure]]: """ Process video and audio, perform deception analysis. """ logger.info("Starting video and audio processing.") if not video_path: logger.warning("No video path provided.") return None, None try: # Extract audio from video audio_path = extract_audio_from_video(video_path) if not audio_path: logger.warning("No audio available for speech analysis.") # Initialize video capture cap = cv2.VideoCapture(video_path) if not cap.isOpened(): logger.error("Failed to open video file.") return None, None fps = cap.get(cv2.CAP_PROP_FPS) if fps <= 0 or fps != fps: logger.error("Invalid frame rate detected.") cap.release() return None, None logger.info(f"Video FPS: {fps}") # Initialize processors signal_processor = SignalProcessor(fps) analyzer = DeceptionAnalyzer() ppg_signal = [] last_frame = None # Initialize Mediapipe for real-time facial feature extraction mp_face_mesh = mp_mediapipe.solutions.face_mesh face_mesh = mp_face_mesh.FaceMesh(static_image_mode=False, max_num_faces=1) frame_counter = 0 # Process video frames while True: ret, frame = cap.read() if not ret: break frame_counter += 1 # Extract PPG signal from green channel frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) green_channel = frame_rgb[:, :, 1] ppg_signal.append(np.mean(green_channel)) # Extract facial features results = face_mesh.process(frame_rgb) if results.multi_face_landmarks: face_landmarks = results.multi_face_landmarks[0] # Blink Detection blink = detect_blink(face_landmarks, frame) analyzer.metrics["Blink Rate"].analyze(blink) # Head Movement Detection head_movement = estimate_head_movement(face_landmarks) analyzer.metrics["Head Movements"].analyze(head_movement) else: analyzer.metrics["Blink Rate"].analyze(0.0) analyzer.metrics["Head Movements"].analyze(0.0) # Store last frame last_frame = cv2.resize(frame_rgb, (320, 240)) # Optional: Log progress every 100 frames if frame_counter % 100 == 0: logger.info(f"Processed {frame_counter} frames.") cap.release() face_mesh.close() logger.info(f"Total frames processed: {frame_counter}") if not ppg_signal or last_frame is None: logger.error("No PPG signal extracted or last frame missing.") return last_frame, None # Convert PPG signal to numpy array ppg_signal = np.array(ppg_signal) logger.debug("PPG signal extracted.") # Apply bandpass filter filtered_signal = signal_processor.bandpass_filter(ppg_signal) logger.debug("Filtered PPG signal.") # Find peaks in the filtered signal peaks = signal_processor.find_peaks_in_signal(filtered_signal) if len(peaks) < 2: logger.warning("Insufficient peaks detected. Signal quality may be poor.") return last_frame, None # Return last_frame but no analysis # Calculate RR intervals in milliseconds rr_intervals = np.diff(peaks) / fps * 1000 # ms heart_rate = 60 * fps / np.diff(peaks) # BPM if len(rr_intervals) == 0 or len(heart_rate) == 0: logger.error("Failed to calculate RR intervals or heart rate.") return last_frame, None # Calculate RMSSD (Root Mean Square of Successive Differences) hrv_rmssd = np.sqrt(np.mean(np.diff(rr_intervals) ** 2)) logger.debug(f"Calculated RMSSD: {hrv_rmssd:.2f} ms") # Analyze speech if audio_path and 'transcription_model' in models: speech_features = analyze_speech(audio_path, models['transcription_model']) else: speech_features = {} # Analyze signals metrics, probability, level = analyzer.analyze_signals( heart_rate, rr_intervals, hrv_rmssd, speech_features, {} ) # Create visualization duration = len(ppg_signal) / fps # seconds fig = create_visualization( metrics, probability, heart_rate, duration, level, speech_features ) # Clean up temporary audio file if it was extracted if audio_path and os.path.exists(audio_path): try: os.remove(audio_path) logger.info(f"Deleted temporary audio file: {audio_path}") except Exception as e: logger.error(f"Error deleting temporary audio file: {str(e)}") logger.info("Video and audio processing completed successfully.") return last_frame, fig except Exception as e: logger.error(f"Error processing video and audio: {str(e)}") return None, None def create_deception_detection_tab(models: Dict[str, torch.nn.Module]) -> gr.Blocks: """ Create the deception detection interface tab using Gradio. """ def analyze(video): try: if video is None: return None, None video_path = video logger.info(f"Received video for analysis: {video_path}") if not os.path.exists(video_path): logger.error("Video file does not exist.") return None, None last_frame, fig = process_video_and_audio(video_path, models) if fig: return last_frame, fig else: return last_frame, None except Exception as e: logger.error(f"Error in analyze function: {str(e)}") return None, None with gr.Blocks() as deception_interface: with gr.Row(): with gr.Column(scale=1): input_video = gr.Video() gr.Examples(["./assets/videos/fitness.mp4", "./assets/videos/vladirmir.mp4", "./assets/videos/lula.mp4"], inputs=[input_video]) gr.Markdown(""" ### Deception Level Analysis This analysis evaluates physiological, facial, and speech indicators that may suggest deceptive behavior. **Physiological Indicators:** - ◇ HRV Suppression - ◇ Heart Rate Elevation - ◇ Rhythm Irregularity **Facial Indicators:** - ◇ Blink Rate - ◇ Head Movements **Speech Indicators:** - ◇ Speech Stress - ◇ Speech Pitch Variation - ◇ Pauses and Hesitations - ◇ Filler Words **Interpretation:** - **Low (0-30%):** Minimal indicators - **Moderate (30-70%):** Some indicators - **High (>70%):** Strong indicators **Important Note:** This analysis is for research purposes only. Results should not be used as definitive proof of deception or truthfulness. """) with gr.Column(scale=2): output_frame = gr.Image(label="Last Frame of Video", height=240) analysis_plot = gr.Plot(label="Deception Analysis") # Configure automatic analysis upon video upload input_video.change( fn=analyze, inputs=[input_video], outputs=[output_frame, analysis_plot] ) return deception_interface