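"""Apnea Detection System (Gradio app).

Pipeline: read an uploaded video, track chest motion inside a user-selected ROI
with Farneback dense optical flow, smooth the motion signal, aggregate features
over 10-second windows, and estimate the breathing rate with RespFusion, a
stacking ensemble of TFT, TCN, ETSformer, and BiLSTM base models topped by an
XGBoost meta-learner. Rates are then categorized as Apnea, Bradypnea, Normal,
or Tachypnea.
"""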
import gradio as gr
import torch
import cv2
import numpy as np
import pandas as pd
from scipy.signal import find_peaks, savgol_filter
from collections import Counter

from models.temporal_fusion_transformer import TemporalFusionTransformer
from models.tcn import TCN
from models.etsformer import ETSformer
from models.bilstm import BiLSTM
from models.respfusion import RespFusion
def process_video(video_path):
    # Parameters
    roi_coordinates = None  # Manual ROI selection on the first frame
    fps = 30                # Assumed capture frame rate
    feature_window = 10     # Window for feature aggregation (seconds)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return "Error opening video file"
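    # NOTE: fps is assumed to be 30. If the source frame rate can vary, it could
    # be read from the container instead; kept commented out here so the feature
    # timebase matches what the models were presumably trained on:
    # fps = int(cap.get(cv2.CAP_PROP_FPS)) or fps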
    # Select ROI on the first frame
    ret, first_frame = cap.read()
    if not ret:
        return "Failed to read first frame"
    if roi_coordinates is None:
        roi_coordinates = cv2.selectROI("Select ROI", first_frame, fromCenter=False, showCrosshair=True)
        cv2.destroyWindow("Select ROI")
    x, y, w, h = map(int, roi_coordinates)

    motion_magnitude, direction_angles = [], []
    frame_skip = 2      # Process every other frame
    scale_factor = 0.5  # Downscale the ROI to speed up optical flow

    roi = cv2.resize(first_frame[y:y+h, x:x+w], None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_AREA)
    prev_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame_index = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
        if frame_index % frame_skip != 0:
            continue
        roi = cv2.resize(frame[y:y+h, x:x+w], None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_AREA)
        gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
        # Dense optical flow (pyr_scale=0.5, levels=3, winsize=15, iterations=2, poly_n=5, poly_sigma=1.2)
        flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None, 0.5, 3, 15, 2, 5, 1.2, 0)
        magnitude, angle = cv2.cartToPolar(flow[..., 0], flow[..., 1])
        motion_magnitude.append(np.mean(magnitude))
        # Keep one direction value per processed frame so the per-window slicing
        # below stays aligned with motion_magnitude (extending with every pixel's
        # angle would break that alignment)
        direction_angles.append(np.mean(angle))
        prev_gray = gray
    cap.release()
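    # With frame_skip = 2, the motion signal is sampled at fps / frame_skip = 15 Hz;
    # that effective rate is the timebase for the peak-interval and
    # dominant-frequency features computed below.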
    # Smooth the motion signal before peak detection (Savitzky-Golay with
    # polyorder 3 needs an odd window of at least 5 samples)
    if len(motion_magnitude) < 5:
        return "Not enough motion frames extracted to analyze breathing"
    window_length = min(len(motion_magnitude) - 1, 31)
    if window_length % 2 == 0:
        window_length += 1
    smoothed_magnitude = savgol_filter(motion_magnitude, window_length=window_length, polyorder=3)
    # Aggregate features over fixed-length windows (feature_window seconds each)
    features = []
    window_size = feature_window * (fps // frame_skip)
    num_windows = len(smoothed_magnitude) // window_size
    for i in range(num_windows):
        start, end = i * window_size, (i + 1) * window_size
        window_magnitude = smoothed_magnitude[start:end]
        peaks, _ = find_peaks(window_magnitude)
        troughs, _ = find_peaks(-window_magnitude)
        window_angles = np.array(direction_angles[start:end])
        features.append([
            np.mean(window_magnitude), np.std(window_magnitude), np.max(window_magnitude), np.min(window_magnitude),
            len(peaks),
            np.mean(window_magnitude[peaks]) if len(peaks) > 0 else 0,
            # Mean peak-to-peak interval in seconds at the effective sampling rate
            np.mean(np.diff(peaks)) / (fps // frame_skip) if len(peaks) > 1 else 0,
            (np.mean(window_magnitude[peaks]) - np.mean(window_magnitude[troughs])) if peaks.size > 0 and troughs.size > 0 else 0,
            # Dominant flow direction, binned into 8 sectors of pi/4
            Counter((window_angles / (np.pi / 4)).astype(int) % 8).most_common(1)[0][0] if window_angles.size > 0 else -1,
            # Dominant frequency; the +1 restores the bin-index offset lost by
            # dropping the DC component before the argmax
            (np.argmax(np.abs(np.fft.fft(window_magnitude))[1:len(window_magnitude)//2]) + 1) / window_size
        ])
    features_df = pd.DataFrame(features, columns=[
        "mean_magnitude", "std_magnitude", "max_magnitude", "min_magnitude",
        "peak_count", "avg_peak_height", "peak_to_peak_interval", "amplitude",
        "dominant_direction", "dominant_frequency"
    ])
print("Features Extracted. Matching Model Keys.") | |
X_inference = np.array([features_df.iloc[i-5:i, :].values for i in range(5, len(features_df))]) | |
X_inference_tensor = torch.tensor(X_inference, dtype=torch.float32) | |
input_size, hidden_size, output_size = X_inference_tensor.shape[2], 64, 1 | |
    tft_model = TemporalFusionTransformer(input_size, hidden_size, output_size)
    tcn_model = TCN(input_size, hidden_size, output_size)
    ets_model = ETSformer(input_size, hidden_size, output_size)
    bilstm_model = BiLSTM(input_size, hidden_size, output_size)

    weights_path = './models/weights/'

    # Load the trained weights for each base model (CPU inference)
    tft_model.load_state_dict(torch.load(weights_path + "tft_loss_optimized.pth", map_location=torch.device("cpu")))
    print("TFT Model Loaded. All keys matched successfully.")
    tcn_model.load_state_dict(torch.load(weights_path + "tcn_loss_optimized.pth", map_location=torch.device("cpu")))
    print("TCN Model Loaded. All keys matched successfully.")
    ets_model.load_state_dict(torch.load(weights_path + "etsformer_loss_optimized.pth", map_location=torch.device("cpu")))
    print("ETSformer Model Loaded. All keys matched successfully.")
    bilstm_model.load_state_dict(torch.load(weights_path + "bilstm_loss_optimized.pth", map_location=torch.device("cpu")))
    print("BiLSTM Model Loaded. All keys matched successfully.")

    tft_model.eval()
    tcn_model.eval()
    ets_model.eval()
    bilstm_model.eval()

    model = RespFusion(
        tft_model, tcn_model, ets_model, bilstm_model,
        strategy='stacking',
        meta_learner_path=weights_path + "respfusion_2_xgboost_meta_learner.json"
    )
    print("RespFusion Model Loaded. Detecting Apnea.")
    with torch.no_grad():
        # reshape(-1) keeps the output a list even when there is a single window
        # (squeeze() alone would yield a 0-d tensor and a bare float)
        predictions = model(X_inference_tensor).reshape(-1).tolist()
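    # The thresholds below read as breaths per minute: < 0.5 is treated as apnea,
    # < 10 as bradypnea (slow breathing), > 20 as tachypnea (rapid breathing).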
    def categorize_breathing_rate(pred):
        if pred < 0.5:
            return "Apnea"
        elif pred > 20:
            return "Tachypnea"
        elif pred < 10:
            return "Bradypnea"
        else:
            return "Normal"
    categories = [categorize_breathing_rate(pred) for pred in predictions]
    avg_rate = sum(predictions) / len(predictions)
    overall_category = categorize_breathing_rate(avg_rate)

    return {
        "Breathing State per 10s": categories,
        "Average Breathing Rate": avg_rate,
        "Overall Breathing State": overall_category,
    }
app = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload Video for Analysis"),
    outputs=gr.JSON(),
    title="Apnea Detection System",
    description="Upload a video to analyze breathing rate and detect conditions such as Apnea, Tachypnea, and Bradypnea."
)

if __name__ == "__main__":
    app.launch()
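# Note: cv2.selectROI opens an interactive OpenCV window, so this script assumes
# a local display; on a headless host (such as a hosted Space) a fixed ROI would
# have to be supplied via roi_coordinates instead.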