Testys committed on
Commit
95b307f
·
verified ·
1 Parent(s): fb982e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -48
app.py CHANGED
@@ -1,14 +1,18 @@
1
  # app_gradio.py
2
  import gradio as gr
3
  import numpy as np
 
 
4
  import os
5
  import yaml
6
  from dotenv import load_dotenv
7
- import io
8
- from scipy.io.wavfile import read as read_wav
9
- import time
 
 
 
10
 
11
- # Correctly import from the drive_paddy package structure
12
  from src.detection.factory import get_detector
13
  from src.alerting.alert_system import get_alerter
14
 
@@ -28,88 +32,124 @@ detector = get_detector(config)
28
  alerter = get_alerter(config, secrets["gemini_api_key"])
29
  print("Initialization complete. Launching UI...")
30
 
 
 
 
 
 
 
 
31
 
32
- STREAM_START_TIME = None
 
 
 
 
 
33
 
34
- # --- Audio Processing for Gradio ---
35
- # Gradio's gr.Audio component needs a specific format: (sample_rate, numpy_array)
36
- def process_audio_for_gradio(audio_bytes):
37
- """Converts in-memory audio bytes to a format Gradio can play."""
38
- # gTTS creates MP3, so we read it as such
39
- byte_io = io.BytesIO(audio_bytes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
41
- from pydub import AudioSegment
42
- audio = AudioSegment.from_mp3(byte_io)
43
- wav_byte_io = io.BytesIO()
44
- audio.export(wav_byte_io, format="wav")
45
- wav_byte_io.seek(0)
 
 
 
 
 
46
 
47
- sample_rate, data = read_wav(wav_byte_io)
48
- return (sample_rate, data)
49
- except Exception as e:
50
- print(f"Could not process audio for Gradio: {e}")
51
- return None
52
-
53
-
54
  def process_live_frame(frame):
55
  """
56
- Takes a single frame from the Gradio webcam input, processes it,
57
- and returns the processed frame, status text, and any audio alerts.
58
  """
59
  if frame is None:
60
- # Return default values if frame is None
61
- blank_image = np.zeros((480, 640, 3), dtype=np.uint8)
62
- return blank_image, "Status: Inactive", None
63
 
64
- # Process the frame using our existing detector
65
  processed_frame, indicators, _ = detector.process_frame(frame)
66
  drowsiness_level = indicators.get("drowsiness_level", "Awake")
67
  lighting = indicators.get("lighting", "Good")
68
  score = indicators.get("details", {}).get("Score", 0)
69
 
70
- # Build the status text
71
  status_text = f"Lighting: {lighting}\n"
72
  if lighting == "Low":
73
  status_text += "Detection paused due to low light."
74
  else:
75
  status_text += f"Status: {drowsiness_level}\nScore: {score:.2f}"
76
 
77
- # Handle alerts
78
  audio_output = None
79
  if drowsiness_level != "Awake":
80
- audio_data = alerter.trigger_alert(level=drowsiness_level)
81
- if audio_data:
82
- audio_output = process_audio_for_gradio(audio_data)
83
- else:
84
- alerter.reset_alert()
 
 
 
 
85
 
86
- # Return all the values needed to update the UI
87
  return processed_frame, status_text, audio_output
88
 
 
89
  # --- Gradio UI Definition ---
90
- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as app:
91
- gr.Markdown("# 🚗 Drive Paddy - Drowsiness Detection (Gradio)")
92
- gr.Markdown("A live test using Gradio's webcam component. This can be more stable than WebRTC in some environments.")
93
 
94
  with gr.Row():
95
- with gr.Column():
96
- # Input: Live webcam feed
97
  webcam_input = gr.Image(sources=["webcam"], streaming=True, label="Live Camera Feed")
98
- with gr.Column():
99
- # Output 1: Processed video feed
100
  processed_output = gr.Image(label="Processed Feed")
101
- # Output 2: Live status text
102
  status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)
103
- # Output 3: Hidden audio player for alerts
104
- audio_alert_output = gr.Audio(autoplay=True, visible=False)
 
 
 
 
 
 
105
 
106
- # Link the input to the processing function and the function to the outputs
107
  webcam_input.stream(
108
  fn=process_live_frame,
109
  inputs=[webcam_input],
110
  outputs=[processed_output, status_output, audio_alert_output]
111
  )
112
 
 
113
  # --- Launch the App ---
114
  if __name__ == "__main__":
115
  app.launch(debug=True)
 
1
  # app_gradio.py
2
  import gradio as gr
3
  import numpy as np
4
+ import torch
5
+ import soundfile as sf
6
  import os
7
  import yaml
8
  from dotenv import load_dotenv
9
+ from threading import Thread
10
+
11
+ # --- TTS & AI Imports ---
12
+ from parler_tts import ParlerTTSForConditionalGeneration
13
+ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
14
+ from streamer import ParlerTTSStreamer # Make sure streamer.py is available
15
 
 
16
  from src.detection.factory import get_detector
17
  from src.alerting.alert_system import get_alerter
18
 
 
32
  alerter = get_alerter(config, secrets["gemini_api_key"])
33
  print("Initialization complete. Launching UI...")
34
 
35
+ # --- Parler-TTS Model Setup (GPU strongly recommended; falls back to CPU) ---
36
+ print("Loading Parler-TTS model. This may take a moment...")
37
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
38
+ if device == "cpu":
39
+ print("\nWARNING: Running Parler-TTS on a CPU will be extremely slow. A GPU is highly recommended.\n")
40
+ torch_dtype = torch.float16 if device != "cpu" else torch.float32
41
+
42
 
43
+ # Using a smaller, faster model suitable for real-time alerts
44
+ repo_id = "parler-tts/parler_tts_mini_v0.1"
45
+ model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
46
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
47
+ feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
48
+ print("Parler-TTS model loaded.")
49
 
50
+ # --- Audio Streaming Generator Function ---
51
+ def stream_alert_audio(text_prompt):
52
+ """
53
+ A generator function that yields audio chunks for a given text prompt.
54
+ This is the core of the streaming implementation.
55
+ """
56
+ sampling_rate = model.config.sampling_rate
57
+ description = "Jenny is A female speaker with a clear and urgent voice." # Voice prompt for TTS
58
+
59
+ prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
60
+ description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
61
+
62
+ # Setup the streamer
63
+ streamer = ParlerTTSStreamer(model, device, play_steps=int(sampling_rate * 2.0))
64
+
65
+ generation_kwargs = dict(
66
+ input_ids=description_ids,
67
+ prompt_input_ids=prompt_ids,
68
+ streamer=streamer,
69
+ do_sample=True,
70
+ temperature=1.0, # Increase for more vocal variety
71
+ repetition_penalty=1.2,
72
+ )
73
+
74
+ # Run generation in a separate thread to not block the UI
75
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
76
+
77
  try:
78
+ thread.start()
79
+ print(f"Audio stream started for: '{text_prompt}'")
80
+ # Yield audio chunks as they become available
81
+ for new_audio_chunk in streamer:
82
+ yield (sampling_rate, new_audio_chunk)
83
+ finally:
84
+ # CRITICAL: This block runs when the generator is exhausted (audio finishes), closed early, or interrupted by an error.
85
+ # We reset the alerter state so that a new alert can be triggered later.
86
+ print("Audio stream finished. Resetting alerter state.")
87
+ alerter.reset_alert()
88
 
89
+ # --- Main Webcam Processing Function ---
 
 
 
 
 
 
90
  def process_live_frame(frame):
91
  """
92
+ Processes each webcam frame, performs drowsiness detection, and
93
+ returns a generator for audio streaming when an alert is triggered.
94
  """
95
  if frame is None:
96
+ return np.zeros((480, 640, 3), dtype=np.uint8), "Status: Inactive", None
 
 
97
 
 
98
  processed_frame, indicators, _ = detector.process_frame(frame)
99
  drowsiness_level = indicators.get("drowsiness_level", "Awake")
100
  lighting = indicators.get("lighting", "Good")
101
  score = indicators.get("details", {}).get("Score", 0)
102
 
103
+ # Build status text
104
  status_text = f"Lighting: {lighting}\n"
105
  if lighting == "Low":
106
  status_text += "Detection paused due to low light."
107
  else:
108
  status_text += f"Status: {drowsiness_level}\nScore: {score:.2f}"
109
 
110
+ # --- Alert Trigger Logic ---
111
  audio_output = None
112
  if drowsiness_level != "Awake":
113
+ # alerter.trigger_alert() returns the alert TEXT if not on cooldown, otherwise None.
114
+ alert_text = alerter.trigger_alert(level=drowsiness_level)
115
+ if alert_text:
116
+ # If we got text, it means we can start an alert.
117
+ # We return the generator object created by calling stream_alert_audio(). Gradio will handle it.
118
+ audio_output = stream_alert_audio(alert_text)
119
+
120
+ # On subsequent frames where the user is drowsy, trigger_alert() will return None
121
+ # due to the cooldown, preventing a new stream from starting, which is what we want.
122
 
 
123
  return processed_frame, status_text, audio_output
124
 
125
+
126
  # --- Gradio UI Definition ---
127
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
128
+ gr.Markdown("# 🚗 Drive Paddy - Drowsiness Detection (Streaming)")
129
+ gr.Markdown("Live drowsiness detection with real-time, streaming voice alerts.")
130
 
131
  with gr.Row():
132
+ with gr.Column(scale=2):
 
133
  webcam_input = gr.Image(sources=["webcam"], streaming=True, label="Live Camera Feed")
134
+ with gr.Column(scale=1):
 
135
  processed_output = gr.Image(label="Processed Feed")
 
136
  status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)
137
+
138
+ # --- KEY CHANGE: The Audio component now uses streaming=True ---
139
+ audio_alert_output = gr.Audio(
140
+ label="Alert System",
141
+ autoplay=True,
142
+ visible=False, # Hide the player controls
143
+ streaming=True
144
+ )
145
 
 
146
  webcam_input.stream(
147
  fn=process_live_frame,
148
  inputs=[webcam_input],
149
  outputs=[processed_output, status_output, audio_alert_output]
150
  )
151
 
152
+
153
  # --- Launch the App ---
154
  if __name__ == "__main__":
155
  app.launch(debug=True)