import os
import cv2
import io
import tempfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from facenet_pytorch import MTCNN
import gradio as gr

class EmotionModel(torch.nn.Module):
    """Small CNN that maps a 1x48x48 grayscale face crop to 7 emotion logits."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),                 # 48x48 -> 24x24
            torch.nn.Conv2d(32, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),                 # 24x24 -> 12x12
            torch.nn.Flatten(),
            torch.nn.Linear(64 * 12 * 12, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 7)                # 7 emotion classes
        )

    def forward(self, x):
        return self.net(x)
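
# Shape sanity check (illustrative only, not part of the app): with 3x3 convs
# (padding=1) and two 2x max-pools, a 1x1x48x48 input shrinks 48 -> 24 -> 12
# spatially, so the flattened feature size is 64 * 12 * 12 and the output is a
# 1x7 logit tensor:
#   EmotionModel()(torch.zeros(1, 1, 48, 48)).shape  ->  torch.Size([1, 7])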

class EmotionDetector:
    def __init__(self, device='cpu'):
        self.device = device
        self.model = EmotionModel().to(self.device)
        self.model.eval()
        self.emotions = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
        self.face_detector = MTCNN(keep_all=False, device=self.device)
        self.transform = transforms.Compose([
            transforms.Resize((48, 48))
        ])
        self.softmax = torch.nn.Softmax(dim=1)
        # Load pre-trained weights here if available; without them the randomly
        # initialized model produces essentially meaningless predictions.
        # self.model.load_state_dict(torch.load("emotion_model.pt", map_location=self.device))

    def detect_emotions_video(self, video_path, sample_rate=30, max_size_mb=50):
        try:
            if video_path is None:
                return None, "No video provided"
            if os.path.getsize(video_path) / (1024 * 1024) > max_size_mb:
                return None, f"File too large (>{max_size_mb} MB)."

            cap = cv2.VideoCapture(video_path)
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 fps if metadata is missing
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if frame_count == 0:
                return None, "Invalid video file"

            # Analyze every `sample_rate`-th frame (cast to int: Gradio sliders may pass floats)
            frame_indices = range(0, frame_count, int(sample_rate))
            emotions_over_time = []
            for frame_idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if not ret:
                    continue

                # OpenCV decodes to BGR; MTCNN expects an RGB PIL image
                img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img_pil = Image.fromarray(img_rgb)
                face_tensor = self.face_detector(img_pil)
                if face_tensor is None:
                    continue  # no face detected in this frame

                face_tensor = self.transform(face_tensor)               # resize to 48x48
                face_tensor = face_tensor.mean(dim=0, keepdim=True)     # RGB -> single grayscale channel
                face_tensor = face_tensor.unsqueeze(0).to(self.device)  # add batch dimension
                with torch.no_grad():
                    output = self.model(face_tensor)
                    probs = self.softmax(output).cpu().numpy()[0]
                emotion_data = {self.emotions[i]: float(probs[i]) * 100 for i in range(len(self.emotions))}
                emotion_data['timestamp'] = frame_idx / fps
                emotions_over_time.append(emotion_data)
            cap.release()

            if not emotions_over_time:
                return None, "No emotions detected."
            df = pd.DataFrame(emotions_over_time)

            # Plot each emotion's confidence over time
            plt.figure(figsize=(12, 8))
            for emotion in self.emotions:
                if emotion in df.columns:
                    plt.plot(df['timestamp'], df[emotion], label=emotion.title(), linewidth=2)
            plt.xlabel('Time (seconds)')
            plt.ylabel('Confidence (%)')
            plt.title('Emotions Over Time')
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True)
            plt.tight_layout()

            # Render the chart to an in-memory PNG and return it as a PIL image
            img_buf = io.BytesIO()
            plt.savefig(img_buf, format='png', dpi=150, bbox_inches='tight')
            img_buf.seek(0)
            plt.close()
            chart_image = Image.open(img_buf)

            # Summarize average confidence per emotion across the sampled frames
            avg_emotions = df[self.emotions].mean().sort_values(ascending=False)
            result_text = "**Video Analysis Complete**\n"
            result_text += f"**Frames Analyzed:** {len(emotions_over_time)}\n"
            result_text += f"**Duration:** {df['timestamp'].max():.1f} seconds\n\n"
            result_text += "**Average Emotions:**\n"
            for emotion, confidence in avg_emotions.items():
                result_text += f"• {emotion.title()}: {confidence:.1f}%\n"
            return chart_image, result_text
        except Exception as e:
            return None, f"Error: {str(e)}"
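
    # Illustrative helper (not part of the original interface): runs the same
    # face-detect -> resize -> grayscale -> CNN pipeline on a single PIL image,
    # which is handy for debugging the model without a video file.
    def detect_emotion_image(self, img_pil):
        face_tensor = self.face_detector(img_pil)
        if face_tensor is None:
            return None  # no face detected
        face_tensor = self.transform(face_tensor)
        face_tensor = face_tensor.mean(dim=0, keepdim=True)
        face_tensor = face_tensor.unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = self.softmax(self.model(face_tensor)).cpu().numpy()[0]
        return {self.emotions[i]: float(probs[i]) * 100 for i in range(len(self.emotions))}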

def create_interface():
    detector = EmotionDetector()

    def process(video_path, sample_rate):
        if video_path is None or not os.path.exists(video_path):
            return None, "Invalid video path or no video uploaded."
        return detector.detect_emotions_video(video_path, sample_rate)

    return gr.Interface(
        fn=process,
        inputs=[
            gr.Video(label="Upload Video"),
            gr.Slider(minimum=1, maximum=60, step=1, value=30, label="Sample Rate (analyze every Nth frame)")
        ],
        outputs=[
            gr.Image(type="pil", label="Emotion Chart"),
            gr.Textbox(label="Analysis Summary")
        ],
        title="AI Emotion Detection",
        description="Upload a video to analyze emotions over time."
    )


if __name__ == "__main__":
    create_interface().launch()
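
# Example (illustrative, not part of the Gradio app): analyze a local file
# directly without the UI. "sample.mp4" is a placeholder path.
#
#   detector = EmotionDetector()
#   chart, summary = detector.detect_emotions_video("sample.mp4", sample_rate=15)
#   print(summary)
#   chart.save("emotions.png")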