gsavin commited on
Commit
45fabe9
·
1 Parent(s): 21eb680

Merge branch 'feature/unique-session-ids' of https://github.com/DeltaZN/gradio-mcp-hackaton into feature/unique-session-ids

Browse files
src/audio/audio_generator.py CHANGED
@@ -5,18 +5,23 @@ import queue
5
  import logging
6
  import io
7
  import time
 
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
- from services.google import GoogleClientFactory
 
12
 
13
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
14
  if user_hash in sessions:
15
- logger.info(f"Music generation already started for user hash {user_hash}, skipping new generation")
 
 
16
  return
17
  async with GoogleClientFactory.audio() as client:
18
  async with (
19
- client.live.music.connect(model='models/lyria-realtime-exp') as session,
20
  asyncio.TaskGroup() as tg,
21
  ):
22
  # Set up task to receive server messages.
@@ -27,26 +32,24 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
27
  session.set_weighted_prompts(
28
  prompts=[types.WeightedPrompt(text=music_tone, weight=1.0)]
29
  ),
30
- 40,
31
  )
32
  await asyncio.wait_for(
33
  session.set_music_generation_config(
34
  config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
35
- ),
36
- 40,
37
  )
38
- await asyncio.wait_for(session.play(), 40)
39
  logger.info(
40
  f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
41
  )
42
- sessions[user_hash] = {
43
- 'session': session,
44
- 'queue': queue.Queue()
45
- }
46
-
47
  async def change_music_tone(user_hash: str, new_tone):
48
  logger.info(f"Changing music tone to {new_tone}")
49
- session = sessions.get(user_hash, {}).get('session')
50
  if not session:
51
  logger.error(f"No session found for user hash {user_hash}")
52
  return
@@ -54,14 +57,15 @@ async def change_music_tone(user_hash: str, new_tone):
54
  session.set_weighted_prompts(
55
  prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
56
  ),
57
- 40,
58
  )
59
-
60
 
61
  SAMPLE_RATE = 48000
62
  NUM_CHANNELS = 2 # Stereo
63
  SAMPLE_WIDTH = 2 # 16-bit audio -> 2 bytes per sample
64
 
 
65
  async def receive_audio(session, user_hash):
66
  """Process incoming audio from the music generation."""
67
  while True:
@@ -69,7 +73,7 @@ async def receive_audio(session, user_hash):
69
  async for message in session.receive():
70
  if message.server_content and message.server_content.audio_chunks:
71
  audio_data = message.server_content.audio_chunks[0].data
72
- queue = sessions[user_hash]['queue']
73
  # audio_data is already bytes (raw PCM)
74
  await asyncio.to_thread(queue.put, audio_data)
75
  await asyncio.sleep(10**-12)
@@ -77,42 +81,47 @@ async def receive_audio(session, user_hash):
77
  logger.error(f"Error in receive_audio: {e}")
78
  break
79
 
 
80
  sessions = {}
81
 
 
82
  async def start_music_generation(user_hash: str, music_tone: str):
83
  """Start the music generation in a separate thread."""
84
  await generate_music(user_hash, music_tone, receive_audio)
85
-
 
86
  async def cleanup_music_session(user_hash: str):
87
  if user_hash in sessions:
88
  logger.info(f"Cleaning up music session for user hash {user_hash}")
89
- session = sessions[user_hash]['session']
90
- await asyncio.wait_for(session.stop(), 40)
91
- await asyncio.wait_for(session.close(), 40)
92
  del sessions[user_hash]
93
-
94
 
95
  def update_audio(user_hash):
96
  """Continuously stream audio from the queue as WAV bytes."""
97
  if user_hash == "":
98
  return
99
-
100
  logger.info(f"Starting audio update loop for user hash: {user_hash}")
101
  while True:
102
  if user_hash not in sessions:
103
  time.sleep(0.5)
104
  continue
105
- queue = sessions[user_hash]['queue']
106
- pcm_data = queue.get() # This is raw PCM audio bytes
107
-
108
  if not isinstance(pcm_data, bytes):
109
- logger.warning(f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping.")
 
 
110
  continue
111
 
112
  # Lyria provides stereo, 16-bit PCM at 48kHz.
113
  # Ensure the number of bytes is consistent with stereo 16-bit audio.
114
  # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
115
- # If len(pcm_data) is not a multiple of (NUM_CHANNELS * SAMPLE_WIDTH),
116
  # it might indicate an incomplete chunk or an issue.
117
  bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
118
  if len(pcm_data) % bytes_per_frame != 0:
@@ -121,12 +130,12 @@ def update_audio(user_hash):
121
  f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
122
  )
123
  # Depending on strictness, you might want to skip this chunk:
124
- # continue
125
 
126
  wav_buffer = io.BytesIO()
127
- with wave.open(wav_buffer, 'wb') as wf:
128
  wf.setnchannels(NUM_CHANNELS)
129
- wf.setsampwidth(SAMPLE_WIDTH) # Corresponds to 16-bit audio
130
  wf.setframerate(SAMPLE_RATE)
131
  wf.writeframes(pcm_data)
132
  wav_bytes = wav_buffer.getvalue()
 
5
  import logging
6
  import io
7
  import time
8
+ from config import settings
9
+ from services.google import GoogleClientFactory
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
+
14
+
15
 
16
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
17
  if user_hash in sessions:
18
+ logger.info(
19
+ f"Music generation already started for user hash {user_hash}, skipping new generation"
20
+ )
21
  return
22
  async with GoogleClientFactory.audio() as client:
23
  async with (
24
+ client.live.music.connect(model="models/lyria-realtime-exp") as session,
25
  asyncio.TaskGroup() as tg,
26
  ):
27
  # Set up task to receive server messages.
 
32
  session.set_weighted_prompts(
33
  prompts=[types.WeightedPrompt(text=music_tone, weight=1.0)]
34
  ),
35
+ settings.request_timeout,
36
  )
37
  await asyncio.wait_for(
38
  session.set_music_generation_config(
39
  config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
40
+ ),
41
+ settings.request_timeout,
42
  )
43
+ await asyncio.wait_for(session.play(), settings.request_timeout)
44
  logger.info(
45
  f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
46
  )
47
+ sessions[user_hash] = {"session": session, "queue": queue.Queue()}
48
+
49
+
 
 
50
  async def change_music_tone(user_hash: str, new_tone):
51
  logger.info(f"Changing music tone to {new_tone}")
52
+ session = sessions.get(user_hash, {}).get("session")
53
  if not session:
54
  logger.error(f"No session found for user hash {user_hash}")
55
  return
 
57
  session.set_weighted_prompts(
58
  prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
59
  ),
60
+ settings.request_timeout,
61
  )
62
+
63
 
64
  SAMPLE_RATE = 48000
65
  NUM_CHANNELS = 2 # Stereo
66
  SAMPLE_WIDTH = 2 # 16-bit audio -> 2 bytes per sample
67
 
68
+
69
  async def receive_audio(session, user_hash):
70
  """Process incoming audio from the music generation."""
71
  while True:
 
73
  async for message in session.receive():
74
  if message.server_content and message.server_content.audio_chunks:
75
  audio_data = message.server_content.audio_chunks[0].data
76
+ queue = sessions[user_hash]["queue"]
77
  # audio_data is already bytes (raw PCM)
78
  await asyncio.to_thread(queue.put, audio_data)
79
  await asyncio.sleep(10**-12)
 
81
  logger.error(f"Error in receive_audio: {e}")
82
  break
83
 
84
+
85
  sessions = {}
86
 
87
+
88
  async def start_music_generation(user_hash: str, music_tone: str):
89
  """Start the music generation in a separate thread."""
90
  await generate_music(user_hash, music_tone, receive_audio)
91
+
92
+
93
  async def cleanup_music_session(user_hash: str):
94
  if user_hash in sessions:
95
  logger.info(f"Cleaning up music session for user hash {user_hash}")
96
+ session = sessions[user_hash]["session"]
97
+ await asyncio.wait_for(session.stop(), settings.request_timeout)
98
+ await asyncio.wait_for(session.close(), settings.request_timeout)
99
  del sessions[user_hash]
100
+
101
 
102
  def update_audio(user_hash):
103
  """Continuously stream audio from the queue as WAV bytes."""
104
  if user_hash == "":
105
  return
106
+
107
  logger.info(f"Starting audio update loop for user hash: {user_hash}")
108
  while True:
109
  if user_hash not in sessions:
110
  time.sleep(0.5)
111
  continue
112
+ queue = sessions[user_hash]["queue"]
113
+ pcm_data = queue.get() # This is raw PCM audio bytes
114
+
115
  if not isinstance(pcm_data, bytes):
116
+ logger.warning(
117
+ f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
118
+ )
119
  continue
120
 
121
  # Lyria provides stereo, 16-bit PCM at 48kHz.
122
  # Ensure the number of bytes is consistent with stereo 16-bit audio.
123
  # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
124
+ # If len(pcm_data) is not a multiple of (NUM_CHANNELS * SAMPLE_WIDTH),
125
  # it might indicate an incomplete chunk or an issue.
126
  bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
127
  if len(pcm_data) % bytes_per_frame != 0:
 
130
  f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
131
  )
132
  # Depending on strictness, you might want to skip this chunk:
133
+ # continue
134
 
135
  wav_buffer = io.BytesIO()
136
+ with wave.open(wav_buffer, "wb") as wf:
137
  wf.setnchannels(NUM_CHANNELS)
138
+ wf.setsampwidth(SAMPLE_WIDTH) # Corresponds to 16-bit audio
139
  wf.setframerate(SAMPLE_RATE)
140
  wf.writeframes(pcm_data)
141
  wav_bytes = wav_buffer.getvalue()
src/config.py CHANGED
@@ -29,6 +29,6 @@ class AppSettings(BaseAppSettings):
29
  top_p: float = 0.95
30
  temperature: float = 0.5
31
  pregenerate_next_scene: bool = True
32
-
33
 
34
  settings = AppSettings()
 
29
  top_p: float = 0.95
30
  temperature: float = 0.5
31
  pregenerate_next_scene: bool = True
32
+ request_timeout: int = 20
33
 
34
  settings = AppSettings()
src/images/image_generator.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
  import logging
7
  import asyncio
8
  import gradio as gr
9
-
10
  from services.google import GoogleClientFactory
11
 
12
  logger = logging.getLogger(__name__)
@@ -58,7 +58,7 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
58
  safety_settings=safety_settings,
59
  ),
60
  ),
61
- 40,
62
  )
63
 
64
  # Process the response parts
@@ -125,7 +125,7 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
125
  safety_settings=safety_settings,
126
  ),
127
  ),
128
- 40,
129
  )
130
 
131
  # Process the response parts
 
6
  import logging
7
  import asyncio
8
  import gradio as gr
9
+ from config import settings
10
  from services.google import GoogleClientFactory
11
 
12
  logger = logging.getLogger(__name__)
 
58
  safety_settings=safety_settings,
59
  ),
60
  ),
61
+ settings.request_timeout,
62
  )
63
 
64
  # Process the response parts
 
125
  safety_settings=safety_settings,
126
  ),
127
  ),
128
+ settings.request_timeout,
129
  )
130
 
131
  # Process the response parts
src/main.py CHANGED
@@ -136,7 +136,7 @@ with gr.Blocks(
136
  with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
137
  gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
138
 
139
- local_storage = gr.BrowserState("", "user_hash")
140
 
141
  # Constructor Interface (visible by default)
142
  with gr.Column(
@@ -313,7 +313,7 @@ with gr.Blocks(
313
  start_btn.click(
314
  fn=start_game_with_music,
315
  inputs=[
316
- local_storage,
317
  setting_description,
318
  char_name,
319
  char_age,
@@ -330,13 +330,14 @@ with gr.Blocks(
330
  game_image,
331
  game_choices,
332
  custom_choice,
 
333
  ],
334
  concurrency_limit=CONCURRENCY_LIMIT,
335
  )
336
 
337
  back_btn.click(
338
  fn=return_to_constructor,
339
- inputs=[local_storage],
340
  outputs=[
341
  loading_indicator,
342
  constructor_interface,
@@ -347,7 +348,7 @@ with gr.Blocks(
347
 
348
  custom_choice.submit(
349
  fn=update_scene,
350
- inputs=[local_storage, custom_choice],
351
  outputs=[game_text, game_image, game_choices, custom_choice],
352
  concurrency_limit=CONCURRENCY_LIMIT,
353
  )
@@ -356,14 +357,15 @@ with gr.Blocks(
356
  demo.load(
357
  fn=generate_user_hash,
358
  inputs=[],
359
- outputs=[local_storage],
360
  )
361
- local_storage.change(
362
  fn=update_audio,
363
- inputs=[],
364
  outputs=[audio_out],
365
  concurrency_limit=CONCURRENCY_LIMIT,
366
  )
367
 
 
368
  demo.queue()
369
  demo.launch(ssr_mode=False)
 
136
  with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
137
  gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
138
 
139
+ ls_user_hash = gr.BrowserState("", "user_hash")
140
 
141
  # Constructor Interface (visible by default)
142
  with gr.Column(
 
313
  start_btn.click(
314
  fn=start_game_with_music,
315
  inputs=[
316
+ ls_user_hash,
317
  setting_description,
318
  char_name,
319
  char_age,
 
330
  game_image,
331
  game_choices,
332
  custom_choice,
333
+ ls_user_hash,
334
  ],
335
  concurrency_limit=CONCURRENCY_LIMIT,
336
  )
337
 
338
  back_btn.click(
339
  fn=return_to_constructor,
340
+ inputs=[ls_user_hash],
341
  outputs=[
342
  loading_indicator,
343
  constructor_interface,
 
348
 
349
  custom_choice.submit(
350
  fn=update_scene,
351
+ inputs=[ls_user_hash, custom_choice],
352
  outputs=[game_text, game_image, game_choices, custom_choice],
353
  concurrency_limit=CONCURRENCY_LIMIT,
354
  )
 
357
  demo.load(
358
  fn=generate_user_hash,
359
  inputs=[],
360
+ outputs=[ls_user_hash],
361
  )
362
+ ls_user_hash.change(
363
  fn=update_audio,
364
+ inputs=[ls_user_hash],
365
  outputs=[audio_out],
366
  concurrency_limit=CONCURRENCY_LIMIT,
367
  )
368
 
369
+
370
  demo.queue()
371
  demo.launch(ssr_mode=False)