Testys committed on
Commit
b6b2705
·
verified ·
1 Parent(s): 6c72f8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -97
app.py CHANGED
@@ -2,10 +2,8 @@
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
- import soundfile as sf
6
  import os
7
  import yaml
8
- import spaces
9
  from dotenv import load_dotenv
10
  from threading import Thread
11
  from gradio_webrtc import WebRTC
@@ -13,168 +11,145 @@ from twilio.rest import Client
13
 
14
  # --- TTS & AI Imports ---
15
  from parler_tts import ParlerTTSForConditionalGeneration
16
- from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
17
- from streamer import ParlerTTSStreamer # Make sure streamer.py is available
18
 
 
19
  from src.detection.factory import get_detector
20
  from src.alerting.alert_system import get_alerter
21
 
22
  # --- Load Configuration and Environment Variables ---
23
- # This part is the same as our Streamlit app
24
  load_dotenv()
25
  config_path = 'config.yaml'
26
  with open(config_path, 'r') as f:
27
  config = yaml.safe_load(f)
28
- secrets = {
29
- "gemini_api_key": os.getenv("GEMINI_API_KEY"),
30
- }
31
 
32
  # --- Initialize Backend Components ---
33
  print("Initializing detector and alerter...")
34
  detector = get_detector(config)
35
  alerter = get_alerter(config, secrets["gemini_api_key"])
36
- print("Initialization complete. Launching UI...")
37
 
 
38
  account_sid = os.environ.get("TURN_USERNAME")
39
  auth_token = os.environ.get("TURN_CREDENTIAL")
40
-
41
  if account_sid and auth_token:
42
- client = Client(account_sid, auth_token)
43
-
44
- token = client.tokens.create()
45
-
46
- rtc_configuration = {
47
- "iceServers": token.ice_servers,
48
- "iceTransportPolicy": "relay",
49
- }
50
- else:
51
- rtc_configuration = None
52
-
53
-
54
- # --- Parler-TTS Model Setup (Requires GPU) ---
55
- print("Loading Parler-TTS model. This may take a moment...")
 
56
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
57
  if device == "cpu":
58
- print("\nWARNING: Running Parler-TTS on a CPU will be extremely slow. A GPU is highly recommended.\n")
59
  torch_dtype = torch.float16 if device != "cpu" else torch.float32
60
 
61
-
62
- # Using a smaller, faster model suitable for real-time alerts
63
  repo_id = "parler-tts/parler_tts_mini_v0.1"
64
  model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
65
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
66
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
67
  print("Parler-TTS model loaded.")
68
 
69
- # --- Audio Streaming Generator Function ---
70
- @spaces.GPU
71
  def stream_alert_audio(text_prompt):
72
- """
73
- A generator function that yields audio chunks for a given text prompt.
74
- This is the core of the streaming implementation.
75
- """
76
  sampling_rate = model.config.sampling_rate
77
- description = "Jenny is A female speaker with a clear and urgent voice." # Voice prompt for TTS
78
-
79
  prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
80
  description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
81
-
82
- # Setup the streamer
83
  streamer = ParlerTTSStreamer(model, device, play_steps=int(sampling_rate * 2.0))
84
-
85
- generation_kwargs = dict(
86
- input_ids=description_ids,
87
- prompt_input_ids=prompt_ids,
88
- streamer=streamer,
89
- do_sample=True,
90
- temperature=1.0, # Increase for more vocal variety
91
- repetition_penalty=1.2,
92
- )
93
-
94
- # Run generation in a separate thread to not block the UI
95
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
96
-
97
  try:
98
  thread.start()
99
  print(f"Audio stream started for: '{text_prompt}'")
100
- # Yield audio chunks as they become available
101
  for new_audio_chunk in streamer:
102
  yield (sampling_rate, new_audio_chunk)
103
  finally:
104
- # CRITICAL: This block runs after the generator is exhausted (audio finishes)
105
- # We reset the alerter state so that a new alert can be triggered later.
106
  print("Audio stream finished. Resetting alerter state.")
107
  alerter.reset_alert()
108
-
109
- # --- Main Webcam Processing Function ---
110
- @spaces.GPU
111
- def process_live_frame(frame):
112
  """
113
- Processes each webcam frame, performs drowsiness detection, and
114
- returns a generator for audio streaming when an alert is triggered.
115
  """
116
- if frame is None:
117
- return np.zeros((480, 640, 3), dtype=np.uint8), "Status: Inactive", None
118
 
 
119
  processed_frame, indicators, _ = detector.process_frame(frame)
 
 
 
 
 
 
 
 
 
120
  drowsiness_level = indicators.get("drowsiness_level", "Awake")
121
  lighting = indicators.get("lighting", "Good")
122
  score = indicators.get("details", {}).get("Score", 0)
123
 
124
- # Build status text
125
- status_text = f"Lighting: {lighting}\n"
126
  if lighting == "Low":
127
- status_text += "Detection paused due to low light."
128
- else:
129
- status_text += f"Status: {drowsiness_level}\nScore: {score:.2f}"
130
 
131
- # --- Alert Trigger Logic ---
132
  audio_output = None
133
  if drowsiness_level != "Awake":
134
- # alerter.trigger_alert() returns the alert TEXT if not on cooldown, otherwise None.
135
  alert_text = alerter.trigger_alert(level=drowsiness_level)
136
  if alert_text:
137
- # If we got text, it means we can start an alert.
138
- # We return the generator function itself. Gradio will handle it.
139
- audio_output = stream_alert_audio(alert_text)
140
- else:
141
- alert_text = "WAKE UP"
142
  audio_output = stream_alert_audio(alert_text)
143
-
144
- # On subsequent frames where the user is drowsy, trigger_alert() will return None
145
- # due to the cooldown, preventing a new stream from starting, which is what we want.
146
-
147
- return processed_frame, status_text, audio_output
148
-
149
 
150
  # --- Gradio UI Definition ---
151
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
152
- gr.Markdown("# 🚗 Drive Paddy - Drowsiness Detection (Streaming)")
153
- gr.Markdown("Live drowsiness detection with real-time, streaming voice alerts.")
 
 
 
154
 
155
  with gr.Row():
156
  with gr.Column(scale=2):
157
- webcam_input = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
 
158
  with gr.Column(scale=1):
159
- processed_output = gr.Image(label="Processed Feed")
160
  status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)
 
161
 
162
- # --- KEY CHANGE: The Audio component now uses streaming=True ---
163
- audio_alert_output = gr.Audio(
164
- label="Alert System",
165
- autoplay=True,
166
- visible=True, # Hide the player controls
167
- streaming=True
168
- )
169
-
170
- webcam_input.stream(
171
- fn=process_live_frame,
172
- inputs=[webcam_input],
173
- outputs=[status_output, audio_alert_output],
174
- time_limit=10
175
  )
176
 
 
 
 
 
 
 
 
 
177
 
178
- # --- Launch the App ---
179
  if __name__ == "__main__":
180
- app.launch(debug=True)
 
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
 
5
  import os
6
  import yaml
 
7
  from dotenv import load_dotenv
8
  from threading import Thread
9
  from gradio_webrtc import WebRTC
 
11
 
12
  # --- TTS & AI Imports ---
13
  from parler_tts import ParlerTTSForConditionalGeneration
14
+ from transformers import AutoTokenizer, AutoFeatureExtractor
15
+ from streamer import ParlerTTSStreamer
16
 
17
+ # --- Local Project Imports ---
18
  from src.detection.factory import get_detector
19
  from src.alerting.alert_system import get_alerter
20
 
# --- Load Configuration and Environment Variables ---
# Load a local .env file (if any) before the os.getenv lookup below.
load_dotenv()
config_path = 'config.yaml'
# Explicit encoding: without it, open() decodes with the platform default,
# so the same config file could parse differently across machines.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# The key is None when GEMINI_API_KEY is unset; get_alerter receives it as-is.
secrets = {"gemini_api_key": os.getenv("GEMINI_API_KEY")}

# --- Initialize Backend Components ---
print("Initializing detector and alerter...")
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])
print("Initialization complete.")
33
 
# --- Twilio TURN Server Setup ---
# Credentials come from the environment; both must be present to try Twilio.
account_sid = os.getenv("TURN_USERNAME")
auth_token = os.getenv("TURN_CREDENTIAL")

# Stays None unless the Twilio token request below succeeds.
rtc_configuration = None
if account_sid and auth_token:
    try:
        client = Client(account_sid, auth_token)
        token = client.tokens.create()
        rtc_configuration = {"iceServers": token.ice_servers}
        print("Twilio TURN server configured successfully.")
    except Exception as exc:
        # Best-effort: any Twilio failure is reported and we fall through
        # to the public STUN fallback instead of aborting start-up.
        print(f"Warning: Failed to create Twilio token. Using public STUN server. Error: {exc}")

# Fallback when Twilio is not configured or the token request failed.
if rtc_configuration is None:
    print("Using public STun server.")
    rtc_configuration = {
        "iceServers": [
            {"urls": ["stun:stun.l.google.com:19302"]},
        ],
    }
50
+
51
+
# --- Parler-TTS Model Setup ---
print("Loading Parler-TTS model...")

# Half precision on the GPU, full precision on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    print("\nWARNING: Running Parler-TTS on a CPU is slow. A GPU is highly recommended.\n")
    torch_dtype = torch.float32

repo_id = "parler-tts/parler_tts_mini_v0.1"

# Model, tokenizer and feature extractor are all loaded from the same repo.
model = ParlerTTSForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch_dtype,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
print("Parler-TTS model loaded.")
64
 
65
+ # --- Audio Streaming Generator ---
 
def stream_alert_audio(text_prompt):
    """Yield ``(sampling_rate, audio_chunk)`` tuples of speech for ``text_prompt``.

    Generation runs on a background thread that feeds a ``ParlerTTSStreamer``;
    this generator relays each chunk the streamer produces.

    Args:
        text_prompt: The alert text to be spoken.

    Yields:
        Tuples of the model's sampling rate and the next audio chunk.

    Side effects:
        Calls ``alerter.reset_alert()`` once iteration ends, whether the
        stream finished, was closed early, or failed during setup.
    """
    try:
        # Setup lives inside the try so that a tokenizer/streamer failure
        # still resets the alerter instead of leaving it in the alert state.
        sampling_rate = model.config.sampling_rate
        description = "A female speaker with a clear and urgent voice."

        prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
        description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)

        # NOTE(review): play_steps is derived from the audio sampling rate here.
        # If ParlerTTSStreamer counts decoder steps (frame rate) rather than
        # samples, this threshold may only be reached at the end of generation,
        # i.e. no real streaming - confirm against streamer.py.
        streamer = ParlerTTSStreamer(model, device, play_steps=int(sampling_rate * 2.0))
        generation_kwargs = dict(
            input_ids=description_ids,
            prompt_input_ids=prompt_ids,
            streamer=streamer,
            do_sample=True,
            temperature=1.0,
            repetition_penalty=1.2,
        )

        # Generate on a separate thread so chunks can be consumed as they arrive.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        print(f"Audio stream started for: '{text_prompt}'")

        for new_audio_chunk in streamer:
            yield (sampling_rate, new_audio_chunk)
    finally:
        print("Audio stream finished. Resetting alerter state.")
        alerter.reset_alert()
83
+
84
+ # --- Decoupled Processing Functions ---
85
+
def process_video_and_update_state(frame_dict: "dict | np.ndarray | None", state: "dict | None"):
    """Run detection on one video frame and record the result in ``state``.

    This is the high-frequency loop, so its speed directly affects video latency.

    Args:
        frame_dict: Either a mapping holding the frame under the ``"video"``
            key, or the raw frame itself. ``None``, an empty mapping, or a
            missing/``None`` ``"video"`` entry all mean "no frame".
            NOTE(review): gradio_webrtc appears to pass the raw frame array to
            stream handlers - confirm which shape is actually delivered.
        state: Shared per-session dict; ``None`` is treated as an empty dict.

    Returns:
        ``(processed_frame, state)``. With no frame, ``processed_frame`` is a
        black 480x640 RGB placeholder and ``state`` is returned without a new
        ``'indicators'`` entry.
    """
    if state is None:
        state = {}

    # Dispatch on the container type with isinstance: truth-testing a
    # multi-element numpy array (as `not frame_dict` did) raises ValueError.
    if isinstance(frame_dict, dict):
        frame = frame_dict.get("video")
    else:
        frame = frame_dict

    if frame is None:
        return np.zeros((480, 640, 3), dtype=np.uint8), state

    processed_frame, indicators, _ = detector.process_frame(frame)
    # Publish the latest indicators for the low-frequency UI loop to read.
    state['indicators'] = indicators
    return processed_frame, state
98
+
def update_ui_from_state(state: dict):
    """Build the status text and, if needed, an alert audio stream from ``state``.

    This is the low-frequency loop; it only reads what the video loop stored
    under ``state['indicators']``.

    Args:
        state: Shared state dict. ``None`` or a missing/empty ``'indicators'``
            entry falls back to the defaults (Awake, Good lighting, score 0).

    Returns:
        ``(status_text, audio_output)``. ``audio_output`` is ``None`` unless
        the driver is not "Awake" and ``alerter.trigger_alert`` returns text,
        in which case it is the generator from ``stream_alert_audio``.
    """
    # Tolerate a missing state or a None 'indicators'/'details' entry.
    indicators = (state or {}).get('indicators') or {}
    drowsiness_level = indicators.get("drowsiness_level", "Awake")
    lighting = indicators.get("lighting", "Good")
    score = (indicators.get("details") or {}).get("Score", 0)

    if lighting == "Low":
        # The score is not shown here, so it is not formatted either.
        status_text = "Lighting: Low\nDetection paused due to low light."
    else:
        status_text = f"Lighting: {lighting}\nStatus: {drowsiness_level}\nScore: {score:.2f}"

    audio_output = None
    if drowsiness_level != "Awake":
        # A falsy result means no new alert should start right now.
        alert_text = alerter.trigger_alert(level=drowsiness_level)
        if alert_text:
            # NOTE(review): this hands a generator object to the Audio output -
            # confirm the component accepts that from a non-generator handler.
            audio_output = stream_alert_audio(alert_text)
    return status_text, audio_output
 
 
 
 
 
119
 
120
  # --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
    gr.Markdown("# 🚗 Drive Paddy - Drowsiness Detection (WebRTC)")
    gr.Markdown("Low-latency video processing via WebRTC, with decoupled UI updates for smooth performance.")

    # State shared between the video handler and the timed UI handler.
    shared_state = gr.State(value={'indicators': {}})

    with gr.Row():
        with gr.Column(scale=2):
            # One WebRTC component serves as both the input and the output of the video stream.
            webcam = WebRTC(label="Live Camera Feed", rtc_configuration=rtc_configuration)
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)
            audio_alert_output = gr.Audio(label="Alert System", autoplay=True, visible=False, streaming=True)

    # Loop 1: every incoming frame is processed and sent back to the same component.
    webcam.stream(
        fn=process_video_and_update_state,
        inputs=[webcam, shared_state],
        outputs=[webcam, shared_state],
    )

    # Loop 2: a 250 ms timer reads the shared state and refreshes status text and audio.
    app.load(
        fn=update_ui_from_state,
        inputs=[shared_state],
        outputs=[status_output, audio_alert_output],
        every=0.25,
    )
152
 
 
# Only start the server when this file is run as a script, not when imported.
if __name__ == "__main__":
    print("Starting Drive Paddy WebRTC Application...")
    app.launch(debug=True)