File size: 3,932 Bytes
19f420a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# drive_paddy/detection/strategies/hybrid.py
from src.detection.base_processor import BaseProcessor
from src.detection.strategies.geometric import GeometricProcessor
from src.detection.strategies.cnn_model import CnnProcessor
import cv2
import concurrent.futures

class HybridProcessor(BaseProcessor):
    """
    Combines outputs from multiple detection strategies (Geometric and CNN)
    concurrently to make a more robust and efficient drowsiness decision.
    This version includes frame skipping for the CNN model to improve performance.
    """
    def __init__(self, config):
        self.geometric_processor = GeometricProcessor(config)
        self.cnn_processor = CnnProcessor(config)
        self.weights = config['hybrid_settings']['weights']
        self.alert_threshold = config['hybrid_settings']['alert_threshold']
        self.active_alerts = {}
        
        # --- Performance Optimization ---
        self.frame_counter = 0
        self.cnn_process_interval = 10  # Run CNN every 10 frames
        self.last_cnn_indicators = {"cnn_prediction": False} # Cache the last CNN result

        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)

    def process_frame(self, frame):
        self.frame_counter += 1

        # --- Concurrent Execution ---
        # The geometric processor runs on every frame.
        geo_future = self.executor.submit(self.geometric_processor.process_frame, frame.copy())

        # The CNN processor only runs on specified intervals.
        if self.frame_counter % self.cnn_process_interval == 0:
            cnn_future = self.executor.submit(self.cnn_processor.process_frame, frame.copy())
        
        # Get the result from the geometric processor.
        geo_frame, geo_indicators = geo_future.result()

        # Get the CNN result if it was run, otherwise use the cached result.
        if self.frame_counter % self.cnn_process_interval == 0:
            _, self.last_cnn_indicators = cnn_future.result()
        
        cnn_indicators = self.last_cnn_indicators
        
        # Calculate weighted drowsiness score from the combined results.
        score = 0
        self.active_alerts.clear()

        if geo_indicators.get("eye_closure"):
            score += self.weights['eye_closure']
            self.active_alerts['Eyes Closed'] = geo_indicators['details'].get('EAR', 0)
        if geo_indicators.get("yawning"):
            score += self.weights['yawning']
            self.active_alerts['Yawning'] = geo_indicators['details'].get('MAR', 0)
        if geo_indicators.get("head_nod"):
            score += self.weights['head_nod']
            self.active_alerts['Head Nod'] = geo_indicators['details'].get('Pitch', 0)
        if geo_indicators.get("looking_away"):
            score += self.weights['looking_away']
            self.active_alerts['Looking Away'] = geo_indicators['details'].get('Yaw', 0)
        if cnn_indicators.get("cnn_prediction"):
            score += self.weights['cnn_prediction']
            self.active_alerts['CNN Alert'] = 'Active'

        # --- Visualization ---
        output_frame = geo_frame
        y_pos = 30
        for alert, value in self.active_alerts.items():
            text = f"{alert}: {value:.2f}" if isinstance(value, float) else alert
            cv2.putText(output_frame, text, (10, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
            y_pos += 25
            
        cv2.putText(output_frame, f"Score: {score:.2f}", (output_frame.shape[1] - 150, 30),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        alert_triggered = score >= self.alert_threshold
        if alert_triggered:
            cv2.rectangle(output_frame, (0, 0), (output_frame.shape[1], output_frame.shape[0]), (0, 0, 255), 5)

        # Return the processed frame, the alert trigger, and the active alert details
        return output_frame, alert_triggered, self.active_alerts