Docfile committed on
Commit f710c8e · verified · 1 Parent(s): 9edaa6c

Update app.py

Files changed (1)
app.py +28 -36
app.py CHANGED
@@ -6,18 +6,16 @@ import io
 import PIL.Image
 import asyncio
 import os
+import sounddevice as sd
 from google import genai
 from streamlit_webrtc import webrtc_streamer
 import av
-import pyaudio
 from mediapipe.tasks import python
 from mediapipe.tasks.python import vision
 
 # Configuration
-FORMAT = pyaudio.paInt16
 CHANNELS = 1
-SEND_SAMPLE_RATE = 16000
-RECEIVE_SAMPLE_RATE = 24000
+SAMPLE_RATE = 16000
 CHUNK_SIZE = 1024
 
 # Initialize Genai client
@@ -28,26 +26,32 @@ CONFIG = {"generation_config": {"response_modalities": ["AUDIO"]}}
 
 class AudioProcessor:
     def __init__(self):
-        self.audio = pyaudio.PyAudio()
         self.stream = None
         self.audio_queue = asyncio.Queue()
+
+    def audio_callback(self, indata, frames, time, status):
+        """This is called (from a separate thread) for each audio block."""
+        if status:
+            print(status)
+        self.audio_queue.put_nowait(indata.copy())
 
     def start_stream(self):
-        mic_info = self.audio.get_default_input_device_info()
-        self.stream = self.audio.open(
-            format=FORMAT,
-            channels=CHANNELS,
-            rate=SEND_SAMPLE_RATE,
-            input=True,
-            input_device_index=mic_info["index"],
-            frames_per_buffer=CHUNK_SIZE,
-        )
+        try:
+            self.stream = sd.InputStream(
+                channels=CHANNELS,
+                samplerate=SAMPLE_RATE,
+                callback=self.audio_callback,
+                blocksize=CHUNK_SIZE
+            )
+            self.stream.start()
+        except Exception as e:
+            st.error(f"Error starting audio stream: {str(e)}")
 
     def stop_stream(self):
-        if self.stream:
-            self.stream.stop_stream()
+        if self.stream is not None:
+            self.stream.stop()
             self.stream.close()
-        self.stream = None
+            self.stream = None
 
 class VideoProcessor:
     def __init__(self):
@@ -58,22 +62,17 @@ class VideoProcessor:
             min_detection_confidence=0.5)
 
     def video_frame_callback(self, frame):
-        # Convert the frame to RGB
         img = frame.to_ndarray(format="rgb24")
 
-        # Process the frame with MediaPipe
         results = self.face_detection.process(img)
 
-        # Draw face detection annotations if faces are detected
         if results.detections:
             for detection in results.detections:
                 self.mp_draw.draw_detection(img, detection)
 
-        # Convert to PIL Image
         pil_img = PIL.Image.fromarray(img)
         pil_img.thumbnail([1024, 1024])
 
-        # Prepare frame data for Gemini
         image_io = io.BytesIO()
         pil_img.save(image_io, format="jpeg")
         image_io.seek(0)
@@ -91,7 +90,6 @@ class VideoProcessor:
         return av.VideoFrame.from_ndarray(img, format="rgb24")
 
     def __del__(self):
-        # Cleanup MediaPipe resources
        if hasattr(self, 'face_detection'):
            self.face_detection.close()
 
@@ -113,17 +111,14 @@ def display_chat_messages():
 def main():
     st.title("Gemini Interactive Assistant")
 
-    # Initialize session state
     initialize_session_state()
 
-    # Sidebar configuration
     st.sidebar.title("Settings")
     input_mode = st.sidebar.radio(
         "Input Mode",
         ["Text Only", "Audio + Video", "Audio Only"]
     )
 
-    # Enable face detection option
     enable_face_detection = st.sidebar.checkbox("Enable Face Detection", value=True)
 
     if enable_face_detection:
@@ -140,14 +135,11 @@
         )
     )
 
-    # Display chat history
     display_chat_messages()
 
-    # Main interaction area
     if input_mode == "Text Only":
         user_input = st.chat_input("Your message")
         if user_input:
-            # Add user message to chat
             st.session_state.messages.append({"role": "user", "content": user_input})
             with st.chat_message("user"):
                 st.markdown(user_input)
@@ -158,7 +150,6 @@
                 turn = session.receive()
                 async for response in turn:
                     if text := response.text:
-                        # Add assistant response to chat
                         st.session_state.messages.append(
                             {"role": "assistant", "content": text}
                         )
@@ -168,7 +159,6 @@
         asyncio.run(send_message())
 
     else:
-        # Video stream setup
         if input_mode == "Audio + Video":
             ctx = webrtc_streamer(
                 key="gemini-stream",
@@ -177,7 +167,6 @@
                 media_stream_constraints={"video": True, "audio": True},
             )
 
-        # Audio controls
         col1, col2 = st.columns(2)
         with col1:
             if st.button("Start Recording", type="primary"):
@@ -191,12 +180,15 @@
 
         async def process_audio_stream():
             while st.session_state.get('recording', False):
-                if st.session_state.audio_processor.stream:
-                    data = st.session_state.audio_processor.stream.read(CHUNK_SIZE)
+                try:
+                    audio_data = await st.session_state.audio_processor.audio_queue.get()
                     await st.session_state.audio_processor.audio_queue.put({
-                        "data": data,
-                        "mime_type": "audio/pcm"
+                        "data": audio_data.tobytes(),
+                        "mime_type": "audio/pcm",
+                        "sample_rate": SAMPLE_RATE
                    })
+                except asyncio.QueueEmpty:
+                    pass
                 await asyncio.sleep(0.1)
 
 if __name__ == "__main__":
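
A note on the new capture path, with a hedged sketch. sounddevice invokes audio_callback on a PortAudio worker thread, and asyncio.Queue is not thread-safe, so calling put_nowait there can race with the event loop; the conventional fix is to schedule the put via loop.call_soon_threadsafe. Two further points worth flagging: sd.InputStream yields float32 blocks unless dtype="int16" is requested (relevant if the consumer expects raw 16-bit samples for the "audio/pcm" mime type), and `await queue.get()` never raises asyncio.QueueEmpty (only get_nowait() does). The capture() helper below is illustrative only, not part of app.py:

import asyncio
import sounddevice as sd

CHANNELS = 1          # same constants as the committed app.py
SAMPLE_RATE = 16000
CHUNK_SIZE = 1024

async def capture(seconds=2.0):
    """Record a few seconds of 16-bit PCM via a thread-safe queue handoff."""
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue()

    def callback(indata, frames, time, status):
        # Runs on PortAudio's thread: asyncio.Queue is not thread-safe,
        # so schedule the put on the event loop instead of calling it here.
        if status:
            print(status)
        loop.call_soon_threadsafe(queue.put_nowait, indata.copy())

    # dtype="int16" gives raw 16-bit PCM; the sounddevice default is float32.
    # Entering the context manager starts the stream; exiting stops and closes it.
    with sd.InputStream(channels=CHANNELS, samplerate=SAMPLE_RATE,
                        dtype="int16", callback=callback,
                        blocksize=CHUNK_SIZE):
        blocks = []
        for _ in range(int(seconds * SAMPLE_RATE / CHUNK_SIZE)):
            blocks.append(await queue.get())
    return b"".join(b.tobytes() for b in blocks)

if __name__ == "__main__":
    pcm = asyncio.run(capture())
    print(f"captured {len(pcm)} bytes of audio/pcm")

Separately, process_audio_stream in this revision reads a block from audio_queue and then puts the wrapped dict back onto the same audio_queue, so the next iteration would re-consume its own output; presumably the payload is meant for a separate outbound queue or a session.send() call.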