gsavin committed on
Commit
a277e33
·
1 Parent(s): 939ce2b

fix: partial audio fix

Browse files
src/audio/audio_generator.py CHANGED
@@ -2,26 +2,24 @@ import asyncio
2
  from google import genai
3
  from google.genai import types
4
  from config import settings
5
- import os
6
- import tempfile
7
  import wave
8
- import numpy as np
9
  import queue
10
  import logging
11
  import gradio as gr
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
16
- audio_queue = queue.Queue(maxsize=1)
17
 
18
- async def generate_music(request: gr.Request, music_tone: str, receive_audio):
19
  async with (
20
  client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
21
  asyncio.TaskGroup() as tg,
22
  ):
23
  # Set up task to receive server messages.
24
- tg.create_task(receive_audio(session))
25
 
26
  # Send initial prompts and config
27
  await session.set_weighted_prompts(
@@ -33,14 +31,18 @@ async def generate_music(request: gr.Request, music_tone: str, receive_audio):
33
  config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
34
  )
35
  await session.play()
36
- logger.info(f"Started music generation for session {request.session_hash}, music tone: {music_tone}")
37
- sessions[request.session_hash] = session
 
 
 
 
38
 
39
- async def change_music_tone(request: gr.Request, new_tone):
40
  logger.info(f"Changing music tone to {new_tone}")
41
- session = sessions.get(request.session_hash)
42
  if not session:
43
- logger.error(f"No session found for request {request.session_hash}")
44
  return
45
  await session.reset_context()
46
  await session.set_weighted_prompts(
@@ -49,55 +51,71 @@ async def change_music_tone(request: gr.Request, new_tone):
49
 
50
 
51
  SAMPLE_RATE = 48000
 
 
52
 
53
- async def receive_audio(session):
54
  """Process incoming audio from the music generation."""
55
  while True:
56
  try:
57
  async for message in session.receive():
58
  if message.server_content and message.server_content.audio_chunks:
59
  audio_data = message.server_content.audio_chunks[0].data
60
- await asyncio.to_thread(audio_queue.put, audio_data)
 
 
61
  await asyncio.sleep(10**-12)
62
  except Exception as e:
63
  logger.error(f"Error in receive_audio: {e}")
64
- await asyncio.sleep(1)
65
 
66
  sessions = {}
67
 
68
- async def start_music_generation(request: gr.Request, music_tone: str):
69
  """Start the music generation in a separate thread."""
70
- await generate_music(request, music_tone, receive_audio)
71
 
72
- async def cleanup_music_session(request: gr.Request):
73
- if request.session_hash in sessions:
74
- logger.info(f"Cleaning up music session for session {request.session_hash}")
75
- await sessions[request.session_hash].stop()
76
- del sessions[request.session_hash]
 
 
77
 
78
- current_audio_file = None
79
 
80
- def update_audio():
81
- """Continuously stream audio from the queue."""
82
- global current_audio_file
83
  while True:
84
- audio_data = audio_queue.get()
85
- if isinstance(audio_data, bytes):
86
- audio_array = np.frombuffer(audio_data, dtype=np.int16)
87
- else:
88
- audio_array = np.array(audio_data, dtype=np.int16)
89
-
90
- temp_fd, temp_path = tempfile.mkstemp(suffix='.wav')
91
- os.close(temp_fd)
92
- # Write to WAV file
93
- with wave.open(temp_path, 'wb') as wav_file:
94
- wav_file.setnchannels(2) # Stereo
95
- wav_file.setsampwidth(2) # 16-bit
96
- wav_file.setframerate(SAMPLE_RATE)
97
- wav_file.writeframes(audio_array.tobytes())
98
-
99
- if current_audio_file:
100
- os.remove(current_audio_file)
101
-
102
- current_audio_file = temp_path
103
- yield temp_path
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from google import genai
3
  from google.genai import types
4
  from config import settings
 
 
5
  import wave
 
6
  import queue
7
  import logging
8
  import gradio as gr
9
+ import io
10
+ import time
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
  client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
 
15
 
16
+ async def generate_music(user_hash: str, music_tone: str, receive_audio):
17
  async with (
18
  client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
19
  asyncio.TaskGroup() as tg,
20
  ):
21
  # Set up task to receive server messages.
22
+ tg.create_task(receive_audio(session, user_hash))
23
 
24
  # Send initial prompts and config
25
  await session.set_weighted_prompts(
 
31
  config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
32
  )
33
  await session.play()
34
+ logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
35
+ await cleanup_music_session(user_hash)
36
+ sessions[user_hash] = {
37
+ 'session': session,
38
+ 'queue': queue.Queue(maxsize=3)
39
+ }
40
 
41
+ async def change_music_tone(user_hash: str, new_tone):
42
  logger.info(f"Changing music tone to {new_tone}")
43
+ session = sessions.get(user_hash, {}).get('session')
44
  if not session:
45
+ logger.error(f"No session found for user hash {user_hash}")
46
  return
47
  await session.reset_context()
48
  await session.set_weighted_prompts(
 
51
 
52
 
53
  SAMPLE_RATE = 48000
54
+ NUM_CHANNELS = 2 # Stereo
55
+ SAMPLE_WIDTH = 2 # 16-bit audio -> 2 bytes per sample
56
 
57
async def receive_audio(session, user_hash):
    """Forward audio chunks from a live music session to the user's queue.

    Iterates over ``session.receive()`` and, for each message that carries
    audio chunks, puts the first chunk's raw bytes on the queue stored in
    ``sessions[user_hash]``. The loop stops on the first unexpected error.

    Args:
        session: The live music session whose messages are consumed.
        user_hash: Key of this user's entry in the module-level ``sessions``.
    """
    while True:
        try:
            async for message in session.receive():
                content = message.server_content
                if content and content.audio_chunks:
                    # Look the entry up on every chunk: it may not be
                    # registered yet, may have been cleaned up, or may still
                    # belong to a previous session of the same user.
                    entry = sessions.get(user_hash)
                    if entry is not None and entry.get('session') is session:
                        # audio_data is already bytes (raw PCM)
                        audio_data = content.audio_chunks[0].data
                        # Queue.put blocks when the queue is full, so run it
                        # in a worker thread to keep the event loop free.
                        await asyncio.to_thread(entry['queue'].put, audio_data)
                    # Otherwise drop the chunk instead of raising KeyError,
                    # which used to terminate the receive loop for good.
                # Yield control to the event loop between messages.
                await asyncio.sleep(10**-12)
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"Error in receive_audio: {e}")
            break
71
 
72
  sessions = {}
73
 
74
async def start_music_generation(user_hash: str, music_tone: str):
    """Run music generation for ``user_hash`` starting with ``music_tone``.

    This is a coroutine, not a thread: it awaits ``generate_music`` and only
    returns once that call finishes. Callers that must not wait should
    schedule it as a task (e.g. with ``asyncio.create_task``).

    Args:
        user_hash: Key identifying the user's music session.
        music_tone: Description of the initial music tone.
    """
    await generate_music(user_hash, music_tone, receive_audio)
77
 
78
async def cleanup_music_session(user_hash: str):
    """Stop and close the music session registered for ``user_hash``, if any.

    The entry is removed from ``sessions`` before any awaiting happens, so
    other coroutines and the audio consumer stop seeing it immediately, and a
    failure in ``stop()`` cannot leave a stale entry behind.

    Args:
        user_hash: Key of the user's entry in the module-level ``sessions``.

    Raises:
        Whatever ``session.stop()`` or ``session.close()`` raises; ``close()``
        is still attempted when ``stop()`` fails.
    """
    entry = sessions.pop(user_hash, None)
    if entry is None:
        return
    logger.info(f"Cleaning up music session for user hash {user_hash}")
    session = entry['session']
    try:
        await session.stop()
    finally:
        # Always release the connection, even if stopping playback failed.
        await session.close()
85
 
 
86
 
87
def _pcm_to_wav(pcm_data):
    """Wrap raw PCM bytes in an in-memory WAV container and return its bytes."""
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, 'wb') as wf:
        wf.setnchannels(NUM_CHANNELS)
        wf.setsampwidth(SAMPLE_WIDTH)  # Corresponds to 16-bit audio
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(pcm_data)
    return wav_buffer.getvalue()


def update_audio(user_hash):
    """Continuously stream audio from the user's queue as WAV bytes.

    Generator that never returns. While no session is registered for
    ``user_hash`` it polls every 0.5 s. Otherwise it takes raw PCM chunks
    from the session's queue and yields each one wrapped as a WAV file.

    Args:
        user_hash: Key of the user's entry in the module-level ``sessions``.

    Yields:
        bytes: One complete WAV file per PCM chunk received.
    """
    # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
    bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
    while True:
        # Single lookup avoids a KeyError if the session is removed between
        # a membership check and the indexing.
        entry = sessions.get(user_hash)
        if entry is None:
            time.sleep(0.5)
            continue

        try:
            # Time out so the session entry is re-read regularly: after a
            # cleanup or restart the old queue is orphaned, and a plain
            # get() would block on it forever.
            pcm_data = entry['queue'].get(timeout=0.5)  # raw PCM audio bytes
        except queue.Empty:
            continue

        if not isinstance(pcm_data, bytes):
            logger.warning(f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping.")
            continue

        # The constants describe stereo, 16-bit PCM at 48kHz. A length that is
        # not a whole number of frames might indicate an incomplete chunk; it
        # is logged but still passed on.
        if len(pcm_data) % bytes_per_frame != 0:
            logger.warning(
                f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
                f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
            )

        yield _pcm_to_wav(pcm_data)
src/game_constructor.py CHANGED
@@ -108,7 +108,7 @@ def save_game_config(
108
  return f"โŒ Error saving configuration: {str(e)}"
109
 
110
  async def start_game_with_settings(
111
- request: gr.Request,
112
  setting_desc: str,
113
  char_name: str,
114
  char_age: str,
@@ -160,10 +160,9 @@ NOTE FOR THE ASSISTANT: YOU HAVE TO GENERATE THE IMAGE FOR THE START SCENE.
160
 
161
  response = await process_user_input(initial_story)
162
 
163
- music_tone = response.change_music.music_description
164
 
165
- if music_tone:
166
- asyncio.create_task(start_music_generation(request, music_tone))
167
 
168
  img = "forest.jpg"
169
 
 
108
  return f"โŒ Error saving configuration: {str(e)}"
109
 
110
  async def start_game_with_settings(
111
+ user_hash: str,
112
  setting_desc: str,
113
  char_name: str,
114
  char_age: str,
 
160
 
161
  response = await process_user_input(initial_story)
162
 
163
+ music_tone = response.change_music.music_description or "neutral"
164
 
165
+ asyncio.create_task(start_music_generation(user_hash, music_tone))
 
166
 
167
  img = "forest.jpg"
168
 
src/main.py CHANGED
@@ -32,7 +32,7 @@ def return_to_constructor():
32
  )
33
 
34
 
35
- async def update_scene(request: gr.Request, choice):
36
  logger.info(f"Updating scene with choice: {choice}")
37
  if isinstance(choice, str):
38
  old_scene = state["scene"]
@@ -61,7 +61,7 @@ async def update_scene(request: gr.Request, choice):
61
  story[new_scene]["image"] = img_path
62
 
63
  if response.change_music.change_music:
64
- await change_music_tone(request, response.change_music.music_description)
65
 
66
  scene = story[state["scene"]]
67
  return (
@@ -92,7 +92,7 @@ def update_preview(setting, name, age, background, personality, genre):
92
 
93
 
94
  async def start_game_with_music(
95
- request: gr.Request,
96
  setting_desc: str,
97
  char_name: str,
98
  char_age: str,
@@ -113,7 +113,7 @@ async def start_game_with_music(
113
 
114
  # First, get the game interface updates
115
  result = await start_game_with_settings(
116
- request,
117
  setting_desc,
118
  char_name,
119
  char_age,
@@ -132,6 +132,8 @@ with gr.Blocks(
132
  # Fullscreen Loading Indicator (hidden by default)
133
  with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
134
  gr.HTML("<div class='loading-text'>๐Ÿš€ Starting your adventure...</div>")
 
 
135
 
136
  # Constructor Interface (visible by default)
137
  with gr.Column(
@@ -296,6 +298,7 @@ with gr.Blocks(
296
  start_btn.click(
297
  fn=start_game_with_music,
298
  inputs=[
 
299
  setting_description,
300
  char_name,
301
  char_age,
@@ -327,14 +330,14 @@ with gr.Blocks(
327
 
328
  game_choices.change(
329
  fn=update_scene,
330
- inputs=[game_choices],
331
  outputs=[game_text, game_image, game_choices],
332
  )
333
 
334
  demo.unload(cleanup_music_session)
335
  demo.load(
336
  fn=update_audio,
337
- inputs=None,
338
  outputs=[audio_out],
339
  )
340
 
 
32
  )
33
 
34
 
35
+ async def update_scene(user_hash: str, choice):
36
  logger.info(f"Updating scene with choice: {choice}")
37
  if isinstance(choice, str):
38
  old_scene = state["scene"]
 
61
  story[new_scene]["image"] = img_path
62
 
63
  if response.change_music.change_music:
64
+ await change_music_tone(user_hash, response.change_music.music_description)
65
 
66
  scene = story[state["scene"]]
67
  return (
 
92
 
93
 
94
  async def start_game_with_music(
95
+ user_hash: str,
96
  setting_desc: str,
97
  char_name: str,
98
  char_age: str,
 
113
 
114
  # First, get the game interface updates
115
  result = await start_game_with_settings(
116
+ user_hash,
117
  setting_desc,
118
  char_name,
119
  char_age,
 
132
  # Fullscreen Loading Indicator (hidden by default)
133
  with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
134
  gr.HTML("<div class='loading-text'>๐Ÿš€ Starting your adventure...</div>")
135
+
136
+ local_storage = gr.BrowserState(str(uuid.uuid4()), "user_hash")
137
 
138
  # Constructor Interface (visible by default)
139
  with gr.Column(
 
298
  start_btn.click(
299
  fn=start_game_with_music,
300
  inputs=[
301
+ local_storage,
302
  setting_description,
303
  char_name,
304
  char_age,
 
330
 
331
  game_choices.change(
332
  fn=update_scene,
333
+ inputs=[local_storage, game_choices],
334
  outputs=[game_text, game_image, game_choices],
335
  )
336
 
337
  demo.unload(cleanup_music_session)
338
  demo.load(
339
  fn=update_audio,
340
+ inputs=[local_storage],
341
  outputs=[audio_out],
342
  )
343