gsavin committed
Commit 85d7f84 · 1 Parent(s): f42cab1

feat: add util for LLM retries

src/agent/image_agent.py CHANGED
@@ -2,9 +2,10 @@ from pydantic import BaseModel, Field
 from typing import Literal, Optional
 from agent.llm import create_light_llm
 from langchain_core.messages import SystemMessage, HumanMessage
-from agent.redis_state import get_user_state, set_user_state
+from agent.redis_state import get_user_state
 from agent.prompts import GAME_STATE_PROMPT
 import logging
+from agent.utils import with_retries
 
 logger = logging.getLogger(__name__)
 
@@ -86,11 +87,11 @@ async def generate_image_prompt(user_hash: str, scene_description: str, last_cho
         scene_description=scene_description
     )
 
-    response = await image_prompt_generator_llm.ainvoke(
+    response = await with_retries(lambda: image_prompt_generator_llm.ainvoke(
         [
             SystemMessage(content=IMAGE_GENERATION_SYSTEM_PROMPT),
            HumanMessage(content=scene),
         ]
-    )
+    ))
     logger.info(f"Image prompt generated")
     return response
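
Note that each call site wraps the LLM call in a lambda instead of passing the coroutine itself: a coroutine object can only be awaited once, so every retry attempt needs a fresh awaitable from the factory. A minimal illustrative sketch (work and main are hypothetical, not repository code; with_retries is the helper added in src/agent/utils.py below); the same pattern recurs in the other call sites in this commit.

# Illustrative only: why a factory (lambda) is passed instead of the coroutine itself.
import asyncio

from agent.utils import with_retries


async def work() -> str:
    return "done"


async def main() -> None:
    coro = work()
    await coro  # a coroutine object can be awaited exactly once
    try:
        await coro  # awaiting it again raises RuntimeError
    except RuntimeError as e:
        print(f"cannot retry a bare coroutine: {e}")

    # A factory hands with_retries a fresh awaitable for every attempt.
    print(await with_retries(lambda: work()))


if __name__ == "__main__":
    asyncio.run(main())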
src/agent/music_agent.py CHANGED
@@ -3,7 +3,8 @@ from agent.llm import create_light_llm
 from agent.prompts import GAME_STATE_PROMPT
 from langchain_core.messages import SystemMessage, HumanMessage
 import logging
-from agent.redis_state import get_user_state, set_user_state
+from agent.redis_state import get_user_state
+from agent.utils import with_retries
 
 logger = logging.getLogger(__name__)
 
@@ -54,8 +55,8 @@ async def generate_music_prompt(user_hash: str, scene_description: str, last_cho
         scene_description=scene_description
     )
 
-    response = await llm.ainvoke(
+    response = await with_retries(lambda: llm.ainvoke(
         [SystemMessage(content=system_prompt), HumanMessage(content=scene)]
-    )
-    logger.info(f"Music prompt generated")
+    ))
+    logger.info("Music prompt generated")
     return response.prompt
src/agent/tools.py CHANGED
@@ -18,6 +18,7 @@ from agent.models import (
 )
 from agent.prompts import ENDING_CHECK_PROMPT, SCENE_PROMPT, STORY_FRAME_PROMPT
 from agent.redis_state import get_user_state, set_user_state
+from agent.utils import with_retries
 from images.image_generator import modify_image, generate_image
 from agent.image_agent import ChangeScene
 
@@ -43,7 +44,7 @@ async def generate_story_frame(
         character=character,
         genre=genre,
     )
-    resp: StoryFrameLLM = await llm.ainvoke(prompt)
+    resp: StoryFrameLLM = await with_retries(lambda: llm.ainvoke(prompt))
     story_frame = StoryFrame(
         lore=resp.lore,
         goal=resp.goal,
@@ -77,10 +78,10 @@ async def generate_scene(
         history="; ".join(f"{c.scene_id}:{c.choice_text}" for c in state.user_choices),
         last_choice=last_choice,
     )
-    resp: SceneLLM = await llm.ainvoke(prompt)
+    resp: SceneLLM = await with_retries(lambda: llm.ainvoke(prompt))
     if len(resp.choices) < 2:
-        resp = await llm.ainvoke(
-            prompt + "\nThe scene must contain exactly two choices."
+        resp = await with_retries(
+            lambda: llm.ainvoke(prompt + "\nThe scene must contain exactly two choices.")
         )
     scene_id = str(uuid.uuid4())
     choices = [
@@ -163,7 +164,7 @@ async def check_ending(
         history=history,
         endings=",".join(f"{e.id}:{e.condition}" for e in state.story_frame.endings),
     )
-    resp: EndingCheckResult = await llm.ainvoke(prompt)
+    resp: EndingCheckResult = await with_retries(lambda: llm.ainvoke(prompt))
     if resp.ending_reached and resp.ending:
         state.ending = resp.ending
         await set_user_state(user_hash, state)
src/agent/utils.py ADDED
@@ -0,0 +1,31 @@
+"""Agent-related utilities."""
+
+import asyncio
+import logging
+from typing import Awaitable, Callable, TypeVar
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+async def with_retries(
+    awaitable_factory: Callable[[], Awaitable[T]],
+    retries: int = 3,
+    timeout: int = 15,
+) -> T:
+    """Execute an awaitable with retries and timeout.
+
+    :param awaitable_factory: A function that returns an awaitable.
+    :param retries: Maximum number of retries.
+    :param timeout: Timeout in seconds for each attempt.
+    :return: The result of the awaitable.
+    """
+    last_exception = None
+    for attempt in range(retries):
+        try:
+            return await asyncio.wait_for(awaitable_factory(), timeout=timeout)
+        except Exception as e:
+            logger.warning(f"Attempt {attempt + 1}/{retries} failed with error: {e}")
+            last_exception = e
+    raise last_exception from last_exception
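
A minimal usage sketch of the new helper (flaky_call and main are hypothetical, not repository code): the factory is called once per attempt, each attempt is bounded by asyncio.wait_for, and the last exception is re-raised once the attempts are exhausted.

# Hypothetical demo of agent.utils.with_retries; fails twice, then succeeds.
import asyncio

from agent.utils import with_retries

attempts = 0


async def flaky_call() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise RuntimeError(f"transient failure on attempt {attempts}")
    return "ok"


async def main() -> None:
    # The lambda is invoked once per attempt; each attempt is capped at 5 seconds.
    result = await with_retries(lambda: flaky_call(), retries=3, timeout=5)
    print(result)  # "ok" on the third attempt


if __name__ == "__main__":
    asyncio.run(main())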
src/images/image_generator.py CHANGED
@@ -8,6 +8,7 @@ import asyncio
 import gradio as gr
 from config import settings
 from services.google import GoogleClientFactory
+from agent.utils import with_retries
 
 logger = logging.getLogger(__name__)
 
@@ -49,16 +50,15 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
 
     try:
         async with GoogleClientFactory.image() as client:
-            response = await asyncio.wait_for(
-                client.models.generate_content(
+            response = await with_retries(
+                lambda: client.models.generate_content(
                     model="gemini-2.0-flash-preview-image-generation",
                     contents=prompt,
                     config=types.GenerateContentConfig(
                         response_modalities=["TEXT", "IMAGE"],
                         safety_settings=safety_settings,
                     ),
-                ),
-                settings.request_timeout,
+                )
             )
 
             # Process the response parts
@@ -116,8 +116,8 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
             input_image = Image.open(image_path)
 
             # Make the API call with both text and image
-            response = await asyncio.wait_for(
-                client.models.generate_content(
+            response = await with_retries(
+                lambda: client.models.generate_content(
                     model="gemini-2.0-flash-preview-image-generation",
                     contents=[modification_prompt, input_image],
                     config=types.GenerateContentConfig(
@@ -125,7 +125,6 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
                         safety_settings=safety_settings,
                     ),
                 ),
-                settings.request_timeout,
             )
 
             # Process the response parts
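
In these two functions the per-attempt bound changes from settings.request_timeout to the helper's 15-second default. If the configured value should still apply, with_retries accepts it through its timeout parameter; a hypothetical sketch (slow_call and main are not repository code):

# Hypothetical: keep the configured per-attempt timeout instead of the 15 s default.
import asyncio

from agent.utils import with_retries
from config import settings


async def slow_call() -> str:
    await asyncio.sleep(0.1)
    return "done"


async def main() -> None:
    # Per-attempt timeout taken from config rather than with_retries' default.
    print(await with_retries(lambda: slow_call(), timeout=settings.request_timeout))


if __name__ == "__main__":
    asyncio.run(main())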