Testys committed on
Commit
3a175cd
·
verified ·
1 Parent(s): d0ff3a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -76
app.py CHANGED
@@ -2,153 +2,158 @@
2
  import gradio as gr
3
  import numpy as np
4
  import torch
 
5
  import os
6
  import yaml
7
  from dotenv import load_dotenv
8
  from threading import Thread
9
- from gradio_webrtc import WebRTC
10
- from twilio.rest import Client
11
 
12
  # --- TTS & AI Imports ---
13
  from parler_tts import ParlerTTSForConditionalGeneration
14
- from transformers import AutoTokenizer, AutoFeatureExtractor
15
- from streamer import ParlerTTSStreamer
16
 
17
- # --- Local Project Imports ---
18
  from src.detection.factory import get_detector
19
  from src.alerting.alert_system import get_alerter
20
 
21
  # --- Load Configuration and Environment Variables ---
 
22
# Pull variables from a local .env file into the process environment.
load_dotenv()

# Application settings live in a YAML file next to this script.
config_path = "config.yaml"
with open(config_path, mode="r") as f:
    config = yaml.safe_load(f)

# Only one secret is needed here: the key handed to the alerter below.
secrets = dict(gemini_api_key=os.getenv("GEMINI_API_KEY"))

# --- Backend components ---
print("Initializing detector and alerter...")
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])
print("Initialization complete.")

# --- WebRTC ICE servers ---
# Credentials are read from TURN_USERNAME / TURN_CREDENTIAL and passed to the
# Twilio client; a token is then requested to obtain ICE servers.
account_sid = os.environ.get("TURN_USERNAME")
auth_token = os.environ.get("TURN_CREDENTIAL")
rtc_configuration = None
if account_sid and auth_token:
    try:
        client = Client(account_sid, auth_token)
        token = client.tokens.create()
        rtc_configuration = {"iceServers": token.ice_servers}
        print("Twilio TURN server configured successfully.")
    except Exception as e:
        # Best effort: any failure here falls through to the STUN fallback.
        print(f"Warning: Failed to create Twilio token. Using public STUN server. Error: {e}")

# No Twilio configuration (missing credentials or a failed request):
# fall back to a public STUN server.
if rtc_configuration is None:
    print("Using public STun server.")
    rtc_configuration = {
        "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}],
    }

# --- Parler-TTS model ---
print("Loading Parler-TTS model...")
if torch.cuda.is_available():
    # Half precision is used whenever a CUDA device is present.
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    print("\nWARNING: Running Parler-TTS on a CPU is slow. A GPU is highly recommended.\n")
    torch_dtype = torch.float32

repo_id = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch_dtype,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# Loaded alongside the tokenizer; not referenced by the code visible here.
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
print("Parler-TTS model loaded.")
64
 
65
- # --- Audio Streaming Generator ---
66
def stream_alert_audio(text_prompt):
    """Generate speech for ``text_prompt`` and yield it chunk by chunk.

    Yields:
        ``(sampling_rate, chunk)`` tuples, one per item produced by the
        streamer while ``model.generate`` runs on a background thread.

    When the generator ends (exhausted, closed early, or interrupted by an
    exception) the alerter is reset via ``alerter.reset_alert()``.
    """

    def _to_ids(text):
        # Tokenize and move the ids to the device the model lives on.
        return tokenizer(text, return_tensors="pt").input_ids.to(device)

    rate = model.config.sampling_rate
    spoken_ids = _to_ids(text_prompt)
    voice_ids = _to_ids("A female speaker with a clear and urgent voice.")

    # NOTE(review): play_steps is derived from the audio sampling rate here;
    # confirm that is the unit ParlerTTSStreamer expects.
    audio_chunks = ParlerTTSStreamer(model, device, play_steps=int(rate * 2.0))

    # Generation happens on a worker thread so this generator can hand out
    # chunks while the model is still producing audio.
    worker = Thread(
        target=model.generate,
        kwargs={
            "input_ids": voice_ids,
            "prompt_input_ids": spoken_ids,
            "streamer": audio_chunks,
            "do_sample": True,
            "temperature": 1.0,
            "repetition_penalty": 1.2,
        },
    )

    try:
        worker.start()
        print(f"Audio stream started for: '{text_prompt}'")
        for chunk in audio_chunks:
            yield (rate, chunk)
    finally:
        print("Audio stream finished. Resetting alerter state.")
        alerter.reset_alert()
83
-
84
- # --- Decoupled Processing Functions ---
85
-
86
def process_video_and_update_state(frame_dict: dict, state: dict):
    """Run detection on one incoming video frame and record the result.

    Args:
        frame_dict: Payload expected to carry the frame under the ``"video"`` key.
        state: Mutable dict shared with the UI updater; its ``'indicators'``
            entry is overwritten with the latest detector output.

    Returns:
        ``(frame, state)``. The frame is the detector's processed frame, or
        a black 480x640 RGB image when no usable frame was supplied.
    """
    has_frame = (
        bool(frame_dict)
        and "video" in frame_dict
        and frame_dict["video"] is not None
    )
    if not has_frame:
        # Nothing to analyse: return a blank frame and leave state untouched.
        blank = np.zeros((480, 640, 3), dtype=np.uint8)
        return blank, state

    annotated, latest_indicators, _ = detector.process_frame(frame_dict["video"])
    state['indicators'] = latest_indicators
    return annotated, state
98
-
99
def update_ui_from_state(state: dict):
    """Build the status text and, if needed, an alert audio stream from state.

    Args:
        state: Dict whose ``'indicators'`` entry holds the most recent
            detector output (missing entries fall back to defaults).

    Returns:
        ``(status_text, audio_output)`` where ``audio_output`` is either
        ``None`` or the generator returned by ``stream_alert_audio``.
    """
    latest = state.get('indicators', {})
    level = latest.get("drowsiness_level", "Awake")
    light = latest.get("lighting", "Good")
    score_value = latest.get("details", {}).get("Score", 0)

    # The full status line is always formatted first; low light replaces it.
    status = f"Lighting: {light}\nStatus: {level}\nScore: {score_value:.2f}"
    if light == "Low":
        status = "Lighting: Low\nDetection paused due to low light."

    if level == "Awake":
        return status, None

    # A falsy result from the alerter means no audio is started this time.
    message = alerter.trigger_alert(level=level)
    if not message:
        return status, None
    return status, stream_alert_audio(message)
 
 
 
 
 
 
 
 
 
119
 
120
  # --- Gradio UI Definition ---
121
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
    gr.Markdown("# 🚗 Drive Paddy - Drowsiness Detection (WebRTC)")
    gr.Markdown(
        "Low-latency video processing via WebRTC, with decoupled UI updates for smooth performance."
    )

    # State handed to both callbacks below; starts with no indicators.
    shared_state = gr.State(value={'indicators': {}})

    with gr.Row():
        with gr.Column(scale=2):
            # The same WebRTC component is both the input and the output of
            # the video callback registered below.
            webcam = WebRTC(
                label="Live Camera Feed",
                rtc_configuration=rtc_configuration,
            )
        with gr.Column(scale=1):
            status_output = gr.Textbox(
                label="Live Status",
                lines=3,
                interactive=False,
            )
            audio_alert_output = gr.Audio(
                label="Alert System",
                autoplay=True,
                visible=False,
                streaming=True,
            )

    # Video path: each streamed frame goes through the detector and the
    # processed frame is sent back to the same component.
    webcam.stream(
        fn=process_video_and_update_state,
        inputs=[webcam, shared_state],
        outputs=[webcam, shared_state],
    )

    # Status/audio path: reads the shared state and updates the text box and
    # the audio player.
    # NOTE(review): no interval argument is passed to load(), so confirm how
    # often this callback is actually expected to run.
    app.load(
        fn=update_ui_from_state,
        inputs=[shared_state],
        outputs=[status_output, audio_alert_output],
    )


if __name__ == "__main__":
    print("Starting Drive Paddy WebRTC Application...")
    app.launch(debug=True)
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
+ import soundfile as sf
6
  import os
7
  import yaml
8
  from dotenv import load_dotenv
9
  from threading import Thread
 
 
10
 
11
  # --- TTS & AI Imports ---
12
  from parler_tts import ParlerTTSForConditionalGeneration
13
+ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
14
+ from streamer import ParlerTTSStreamer # Make sure streamer.py is available
15
 
 
16
  from src.detection.factory import get_detector
17
  from src.alerting.alert_system import get_alerter
18
 
19
  # --- Load Configuration and Environment Variables ---
20
+ # This part is the same as our Streamlit app
21
# Pull variables from a local .env file into the process environment.
load_dotenv()

# Application settings live in a YAML file next to this script.
config_path = "config.yaml"
with open(config_path, mode="r") as f:
    config = yaml.safe_load(f)

# Only one secret is needed here: the key handed to the alerter below.
secrets = dict(gemini_api_key=os.getenv("GEMINI_API_KEY"))

# --- Backend components ---
print("Initializing detector and alerter...")
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])
print("Initialization complete. Launching UI...")

# --- Parler-TTS model ---
# A CUDA device is used when available; otherwise the model runs on the CPU
# in full precision and a warning is printed.
print("Loading Parler-TTS model. This may take a moment...")
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    print("\nWARNING: Running Parler-TTS on a CPU will be extremely slow. A GPU is highly recommended.\n")
    torch_dtype = torch.float32

# Checkpoint used to synthesize the spoken alerts.
repo_id = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch_dtype,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# Loaded alongside the tokenizer; not referenced by the code visible here.
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
print("Parler-TTS model loaded.")
49
 
50
+ # --- Audio Streaming Generator Function ---
51
def stream_alert_audio(text_prompt):
    """Synthesize ``text_prompt`` as speech and yield the audio incrementally.

    Yields:
        ``(sampling_rate, chunk)`` tuples, one for every item the streamer
        produces while ``model.generate`` runs on a background thread.

    The ``finally`` clause resets the alerter whenever the generator ends,
    whether it is exhausted, closed early, or interrupted by an exception.
    """
    rate = model.config.sampling_rate
    voice_description = "Jenny is A female speaker with a clear and urgent voice."

    # Both the spoken text and the voice description are tokenized and moved
    # to the model's device.
    spoken_ids = tokenizer(text_prompt, return_tensors="pt").input_ids
    spoken_ids = spoken_ids.to(device)
    voice_ids = tokenizer(voice_description, return_tensors="pt").input_ids
    voice_ids = voice_ids.to(device)

    # NOTE(review): play_steps is derived from the audio sampling rate here;
    # confirm that is the unit ParlerTTSStreamer expects.
    chunk_source = ParlerTTSStreamer(model, device, play_steps=int(rate * 2.0))

    # The model generates on a worker thread so chunks can be yielded while
    # generation is still in progress.
    worker = Thread(
        target=model.generate,
        kwargs={
            "input_ids": voice_ids,
            "prompt_input_ids": spoken_ids,
            "streamer": chunk_source,
            "do_sample": True,
            "temperature": 1.0,
            "repetition_penalty": 1.2,
        },
    )

    try:
        worker.start()
        print(f"Audio stream started for: '{text_prompt}'")
        for chunk in chunk_source:
            yield (rate, chunk)
    finally:
        print("Audio stream finished. Resetting alerter state.")
        alerter.reset_alert()
88
+
89
+ # --- Main Webcam Processing Function ---
90
def process_live_frame(frame):
    """Run drowsiness detection on one webcam frame.

    Args:
        frame: The current webcam image, or ``None`` when no frame is
            available.

    Returns:
        A ``(processed_frame, status_text, audio_output)`` tuple.
        ``audio_output`` is ``None`` unless the detector reports a level
        other than ``"Awake"``, in which case it is the generator returned
        by ``stream_alert_audio``.
    """
    if frame is None:
        # No input yet: show a black 480x640 RGB frame and an idle status.
        return np.zeros((480, 640, 3), dtype=np.uint8), "Status: Inactive", None

    processed_frame, indicators, _ = detector.process_frame(frame)
    drowsiness_level = indicators.get("drowsiness_level", "Awake")
    lighting = indicators.get("lighting", "Good")
    score = indicators.get("details", {}).get("Score", 0)

    # Status text: lighting first, then either the pause notice or the
    # current level and score.
    status_text = f"Lighting: {lighting}\n"
    if lighting == "Low":
        status_text += "Detection paused due to low light."
    else:
        status_text += f"Status: {drowsiness_level}\nScore: {score:.2f}"

    # --- Alert trigger logic ---
    audio_output = None
    if drowsiness_level != "Awake":
        # Per the original author's note, trigger_alert() returns the alert
        # text when an alert may start and None otherwise (e.g. cooldown).
        alert_text = alerter.trigger_alert(level=drowsiness_level)
        if not alert_text:
            # Fallback phrase when the alerter supplies no text.
            # NOTE(review): with this fallback an audio stream is requested
            # on every non-"Awake" frame, so the alerter's cooldown no longer
            # suppresses repeated alerts. Confirm this is intended.
            alert_text = "WAKE UP"
        audio_output = stream_alert_audio(alert_text)

    return processed_frame, status_text, audio_output
128
+
129
 
130
  # --- Gradio UI Definition ---
131
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
    gr.Markdown("# 🚗 Drive Paddy - Drowsiness Detection (Streaming)")
    gr.Markdown("Live drowsiness detection with real-time, streaming voice alerts.")

    with gr.Row():
        with gr.Column(scale=2):
            # Webcam-only image input that streams frames to the callback.
            webcam_input = gr.Image(
                sources=["webcam"],
                streaming=True,
                label="Live Camera Feed",
            )
        with gr.Column(scale=1):
            processed_output = gr.Image(label="Processed Feed")
            status_output = gr.Textbox(
                label="Live Status",
                lines=3,
                interactive=False,
            )

            # Hidden, auto-playing audio output in streaming mode; it
            # receives whatever process_live_frame returns as its third value.
            audio_alert_output = gr.Audio(
                label="Alert System",
                autoplay=True,
                visible=False,
                streaming=True,
            )

    # Every streamed webcam frame is passed to process_live_frame, whose three
    # return values feed the processed image, the status box and the audio.
    webcam_input.stream(
        fn=process_live_frame,
        inputs=[webcam_input],
        outputs=[processed_output, status_output, audio_alert_output],
    )


# --- Launch the app ---
if __name__ == "__main__":
    app.launch(debug=True)