kikikita commited on
Commit
21eb680
·
1 Parent(s): 0a18f7d

feat: implement Google API key management and refactor client usage in audio and image generation

Browse files
src/agent/llm.py CHANGED
@@ -4,31 +4,17 @@ import logging
4
  from langchain_google_genai import ChatGoogleGenerativeAI
5
 
6
  from config import settings
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
- _API_KEYS: list[str] = []
11
- _current_key_idx = 0
12
  MODEL_NAME = "gemini-2.5-flash-preview-05-20"
13
 
14
 
15
  def _get_api_key() -> str:
16
- """Return an API key using round-robin selection."""
17
- global _API_KEYS, _current_key_idx
18
-
19
- if not _API_KEYS:
20
- keys_str = settings.gemini_api_key.get_secret_value()
21
- if keys_str:
22
- _API_KEYS = [k.strip() for k in keys_str.split(",") if k.strip()]
23
- if not _API_KEYS:
24
- msg = "Google API keys are not configured or invalid"
25
- logger.error(msg)
26
- raise ValueError(msg)
27
-
28
- key = _API_KEYS[_current_key_idx]
29
- _current_key_idx = (_current_key_idx + 1) % len(_API_KEYS)
30
- logger.debug("Using Google API key index %s", _current_key_idx)
31
- return key
32
 
33
 
34
  def create_llm(
 
4
  from langchain_google_genai import ChatGoogleGenerativeAI
5
 
6
  from config import settings
7
+ from services.google import ApiKeyPool
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
+ _pool = ApiKeyPool()
 
12
  MODEL_NAME = "gemini-2.5-flash-preview-05-20"
13
 
14
 
15
  def _get_api_key() -> str:
16
+ """Return an API key using round-robin selection in a thread-safe way."""
17
+ return _pool.get_key_sync()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def create_llm(
src/agent/redis_state.py CHANGED
@@ -5,9 +5,12 @@ from __future__ import annotations
5
  import json
6
  import msgpack
7
  import redis.asyncio as redis
 
8
 
9
  from agent.models import UserState
10
 
 
 
11
 
12
  class UserRepository:
13
  """Repository for storing UserState objects in Redis."""
@@ -18,6 +21,7 @@ class UserRepository:
18
  async def get(self, user_id: str) -> UserState:
19
  """Return user state for the given id, creating it if absent."""
20
  key = f"llmgamehub:{user_id}"
 
21
  data = await self.redis.hget(key, "data")
22
  if data is None:
23
  return UserState()
@@ -27,12 +31,14 @@ class UserRepository:
27
  async def set(self, user_id: str, state: UserState) -> None:
28
  """Persist updated user state."""
29
  key = f"llmgamehub:{user_id}"
 
30
  packed = msgpack.packb(json.loads(state.json()))
31
  await self.redis.hset(key, mapping={"data": packed})
32
 
33
  async def reset(self, user_id: str) -> None:
34
  """Remove stored state for a user."""
35
  key = f"llmgamehub:{user_id}"
 
36
  await self.redis.delete(key)
37
 
38
 
@@ -40,12 +46,15 @@ _repo = UserRepository()
40
 
41
 
42
  async def get_user_state(user_hash: str) -> UserState:
 
43
  return await _repo.get(user_hash)
44
 
45
 
46
  async def set_user_state(user_hash: str, state: UserState) -> None:
 
47
  await _repo.set(user_hash, state)
48
 
49
 
50
  async def reset_user_state(user_hash: str) -> None:
 
51
  await _repo.reset(user_hash)
 
5
  import json
6
  import msgpack
7
  import redis.asyncio as redis
8
+ import logging
9
 
10
  from agent.models import UserState
11
 
12
+ logger = logging.getLogger(__name__)
13
+
14
 
15
  class UserRepository:
16
  """Repository for storing UserState objects in Redis."""
 
21
  async def get(self, user_id: str) -> UserState:
22
  """Return user state for the given id, creating it if absent."""
23
  key = f"llmgamehub:{user_id}"
24
+ logger.debug("Fetching state for %s", user_id)
25
  data = await self.redis.hget(key, "data")
26
  if data is None:
27
  return UserState()
 
31
  async def set(self, user_id: str, state: UserState) -> None:
32
  """Persist updated user state."""
33
  key = f"llmgamehub:{user_id}"
34
+ logger.debug("Saving state for %s", user_id)
35
  packed = msgpack.packb(json.loads(state.json()))
36
  await self.redis.hset(key, mapping={"data": packed})
37
 
38
  async def reset(self, user_id: str) -> None:
39
  """Remove stored state for a user."""
40
  key = f"llmgamehub:{user_id}"
41
+ logger.debug("Resetting state for %s", user_id)
42
  await self.redis.delete(key)
43
 
44
 
 
46
 
47
 
48
  async def get_user_state(user_hash: str) -> UserState:
49
+ logger.debug("get_user_state for %s", user_hash)
50
  return await _repo.get(user_hash)
51
 
52
 
53
  async def set_user_state(user_hash: str, state: UserState) -> None:
54
+ logger.debug("set_user_state for %s", user_hash)
55
  await _repo.set(user_hash, state)
56
 
57
 
58
  async def reset_user_state(user_hash: str) -> None:
59
+ logger.debug("reset_user_state for %s", user_hash)
60
  await _repo.reset(user_hash)
src/audio/audio_generator.py CHANGED
@@ -1,7 +1,5 @@
1
  import asyncio
2
- from google import genai
3
  from google.genai import types
4
- from config import settings
5
  import wave
6
  import queue
7
  import logging
@@ -10,34 +8,41 @@ import time
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
- client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
14
 
15
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
16
  if user_hash in sessions:
17
  logger.info(f"Music generation already started for user hash {user_hash}, skipping new generation")
18
  return
19
- async with (
20
- client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
21
- asyncio.TaskGroup() as tg,
22
- ):
23
- # Set up task to receive server messages.
24
- tg.create_task(receive_audio(session, user_hash))
 
25
 
26
- # Send initial prompts and config
27
- await session.set_weighted_prompts(
28
- prompts=[
29
- types.WeightedPrompt(text=music_tone, weight=1.0),
30
- ]
31
- )
32
- await session.set_music_generation_config(
33
- config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
34
- )
35
- await session.play()
36
- logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
37
- sessions[user_hash] = {
38
- 'session': session,
39
- 'queue': queue.Queue()
40
- }
 
 
 
 
 
 
41
 
42
  async def change_music_tone(user_hash: str, new_tone):
43
  logger.info(f"Changing music tone to {new_tone}")
@@ -45,8 +50,11 @@ async def change_music_tone(user_hash: str, new_tone):
45
  if not session:
46
  logger.error(f"No session found for user hash {user_hash}")
47
  return
48
- await session.set_weighted_prompts(
49
- prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
 
 
 
50
  )
51
 
52
 
@@ -79,8 +87,8 @@ async def cleanup_music_session(user_hash: str):
79
  if user_hash in sessions:
80
  logger.info(f"Cleaning up music session for user hash {user_hash}")
81
  session = sessions[user_hash]['session']
82
- await session.stop()
83
- await session.close()
84
  del sessions[user_hash]
85
 
86
 
@@ -122,4 +130,4 @@ def update_audio(user_hash):
122
  wf.setframerate(SAMPLE_RATE)
123
  wf.writeframes(pcm_data)
124
  wav_bytes = wav_buffer.getvalue()
125
- yield wav_bytes
 
1
  import asyncio
 
2
  from google.genai import types
 
3
  import wave
4
  import queue
5
  import logging
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
+ from services.google import GoogleClientFactory
12
 
13
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
14
  if user_hash in sessions:
15
  logger.info(f"Music generation already started for user hash {user_hash}, skipping new generation")
16
  return
17
+ async with GoogleClientFactory.audio() as client:
18
+ async with (
19
+ client.live.music.connect(model='models/lyria-realtime-exp') as session,
20
+ asyncio.TaskGroup() as tg,
21
+ ):
22
+ # Set up task to receive server messages.
23
+ tg.create_task(receive_audio(session, user_hash))
24
 
25
+ # Send initial prompts and config
26
+ await asyncio.wait_for(
27
+ session.set_weighted_prompts(
28
+ prompts=[types.WeightedPrompt(text=music_tone, weight=1.0)]
29
+ ),
30
+ 40,
31
+ )
32
+ await asyncio.wait_for(
33
+ session.set_music_generation_config(
34
+ config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
35
+ ),
36
+ 40,
37
+ )
38
+ await asyncio.wait_for(session.play(), 40)
39
+ logger.info(
40
+ f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
41
+ )
42
+ sessions[user_hash] = {
43
+ 'session': session,
44
+ 'queue': queue.Queue()
45
+ }
46
 
47
  async def change_music_tone(user_hash: str, new_tone):
48
  logger.info(f"Changing music tone to {new_tone}")
 
50
  if not session:
51
  logger.error(f"No session found for user hash {user_hash}")
52
  return
53
+ await asyncio.wait_for(
54
+ session.set_weighted_prompts(
55
+ prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
56
+ ),
57
+ 40,
58
  )
59
 
60
 
 
87
  if user_hash in sessions:
88
  logger.info(f"Cleaning up music session for user hash {user_hash}")
89
  session = sessions[user_hash]['session']
90
+ await asyncio.wait_for(session.stop(), 40)
91
+ await asyncio.wait_for(session.close(), 40)
92
  del sessions[user_hash]
93
 
94
 
 
130
  wf.setframerate(SAMPLE_RATE)
131
  wf.writeframes(pcm_data)
132
  wav_bytes = wav_buffer.getvalue()
133
+ yield wav_bytes
src/images/image_generator.py CHANGED
@@ -1,17 +1,15 @@
1
- from google import genai
2
  from google.genai import types
3
  import os
4
  from PIL import Image
5
  from io import BytesIO
6
  from datetime import datetime
7
- from config import settings
8
  import logging
9
  import asyncio
10
  import gradio as gr
11
 
12
- logger = logging.getLogger(__name__)
13
 
14
- client = genai.Client(api_key=settings.gemini_api_key.get_secret_value()).aio
15
 
16
  safety_settings = [
17
  types.SafetySetting(
@@ -50,14 +48,18 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
50
  logger.info(f"Generating image with prompt: {prompt}")
51
 
52
  try:
53
- response = await client.models.generate_content(
54
- model="gemini-2.0-flash-preview-image-generation",
55
- contents=prompt,
56
- config=types.GenerateContentConfig(
57
- response_modalities=["TEXT", "IMAGE"],
58
- safety_settings=safety_settings,
59
- ),
60
- )
 
 
 
 
61
 
62
  # Process the response parts
63
  image_saved = False
@@ -108,23 +110,23 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
108
  logger.error(f"Error: Image file not found at {image_path}")
109
  return None
110
 
111
- key = settings.gemini_api_key.get_secret_value()
112
-
113
- client = genai.Client(api_key=key).aio
114
-
115
  try:
116
- # Load the input image
117
- input_image = Image.open(image_path)
118
-
119
- # Make the API call with both text and image
120
- response = await client.models.generate_content(
121
- model="gemini-2.0-flash-preview-image-generation",
122
- contents=[modification_prompt, input_image],
123
- config=types.GenerateContentConfig(
124
- response_modalities=["TEXT", "IMAGE"],
125
- safety_settings=safety_settings,
126
- ),
127
- )
 
 
 
 
128
 
129
  # Process the response parts
130
  image_saved = False
 
 
1
  from google.genai import types
2
  import os
3
  from PIL import Image
4
  from io import BytesIO
5
  from datetime import datetime
 
6
  import logging
7
  import asyncio
8
  import gradio as gr
9
 
10
+ from services.google import GoogleClientFactory
11
 
12
+ logger = logging.getLogger(__name__)
13
 
14
  safety_settings = [
15
  types.SafetySetting(
 
48
  logger.info(f"Generating image with prompt: {prompt}")
49
 
50
  try:
51
+ async with GoogleClientFactory.image() as client:
52
+ response = await asyncio.wait_for(
53
+ client.models.generate_content(
54
+ model="gemini-2.0-flash-preview-image-generation",
55
+ contents=prompt,
56
+ config=types.GenerateContentConfig(
57
+ response_modalities=["TEXT", "IMAGE"],
58
+ safety_settings=safety_settings,
59
+ ),
60
+ ),
61
+ 40,
62
+ )
63
 
64
  # Process the response parts
65
  image_saved = False
 
110
  logger.error(f"Error: Image file not found at {image_path}")
111
  return None
112
 
 
 
 
 
113
  try:
114
+ async with GoogleClientFactory.image() as client:
115
+ # Load the input image
116
+ input_image = Image.open(image_path)
117
+
118
+ # Make the API call with both text and image
119
+ response = await asyncio.wait_for(
120
+ client.models.generate_content(
121
+ model="gemini-2.0-flash-preview-image-generation",
122
+ contents=[modification_prompt, input_image],
123
+ config=types.GenerateContentConfig(
124
+ response_modalities=["TEXT", "IMAGE"],
125
+ safety_settings=safety_settings,
126
+ ),
127
+ ),
128
+ 40,
129
+ )
130
 
131
  # Process the response parts
132
  image_saved = False
src/main.py CHANGED
@@ -345,13 +345,6 @@ with gr.Blocks(
345
  ],
346
  )
347
 
348
- game_choices.change(
349
- fn=update_scene,
350
- inputs=[local_storage, game_choices],
351
- outputs=[game_text, game_image, game_choices, custom_choice],
352
- concurrency_limit=CONCURRENCY_LIMIT,
353
- )
354
-
355
  custom_choice.submit(
356
  fn=update_scene,
357
  inputs=[local_storage, custom_choice],
@@ -367,9 +360,10 @@ with gr.Blocks(
367
  )
368
  local_storage.change(
369
  fn=update_audio,
370
- inputs=[local_storage],
371
  outputs=[audio_out],
372
  concurrency_limit=CONCURRENCY_LIMIT,
373
  )
374
 
 
375
  demo.launch(ssr_mode=False)
 
345
  ],
346
  )
347
 
 
 
 
 
 
 
 
348
  custom_choice.submit(
349
  fn=update_scene,
350
  inputs=[local_storage, custom_choice],
 
360
  )
361
  local_storage.change(
362
  fn=update_audio,
363
+ inputs=[],
364
  outputs=[audio_out],
365
  concurrency_limit=CONCURRENCY_LIMIT,
366
  )
367
 
368
+ demo.queue()
369
  demo.launch(ssr_mode=False)
src/services/google.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from contextlib import asynccontextmanager
4
+ from google import genai
5
+ import threading
6
+
7
+ from config import settings
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class ApiKeyPool:
13
+ """Manage Google API keys with round-robin selection."""
14
+
15
+ def __init__(self) -> None:
16
+ self._keys: list[str] | None = None
17
+ self._index = 0
18
+ self._lock = asyncio.Lock()
19
+ self._sync_lock = threading.Lock()
20
+
21
+ def _load_keys(self) -> None:
22
+ keys_raw = (
23
+ getattr(settings, "gemini_api_keys", None) or settings.gemini_api_key
24
+ )
25
+ keys_str = keys_raw.get_secret_value()
26
+ keys = [k.strip() for k in keys_str.split(',') if k.strip()] if keys_str else []
27
+ if not keys:
28
+ msg = "Google API keys are not configured or invalid"
29
+ logger.error(msg)
30
+ raise ValueError(msg)
31
+ self._keys = keys
32
+
33
+ async def get_key(self) -> str:
34
+ async with self._lock:
35
+ if self._keys is None:
36
+ self._load_keys()
37
+ key = self._keys[self._index]
38
+ self._index = (self._index + 1) % len(self._keys)
39
+ logger.debug("Using Google API key index %s", self._index)
40
+ return key
41
+
42
+ def get_key_sync(self) -> str:
43
+ """Synchronous helper for environments without an event loop."""
44
+ with self._sync_lock:
45
+ if self._keys is None:
46
+ self._load_keys()
47
+ key = self._keys[self._index]
48
+ self._index = (self._index + 1) % len(self._keys)
49
+ logger.debug("Using Google API key index %s", self._index)
50
+ return key
51
+
52
+
53
+ class GoogleClientFactory:
54
+ """Factory for thread-safe creation of Google GenAI clients."""
55
+
56
+ _pool = ApiKeyPool()
57
+
58
+ @classmethod
59
+ @asynccontextmanager
60
+ async def image(cls):
61
+ key = await cls._pool.get_key()
62
+ client = genai.Client(api_key=key)
63
+ try:
64
+ yield client.aio
65
+ finally:
66
+ pass
67
+
68
+ @classmethod
69
+ @asynccontextmanager
70
+ async def audio(cls):
71
+ key = await cls._pool.get_key()
72
+ client = genai.Client(api_key=key, http_options={"api_version": "v1alpha"})
73
+ try:
74
+ yield client.aio
75
+ finally:
76
+ pass