Georgii Savin committed on
Commit
60e195a
·
unverified ·
1 Parent(s): 988c006

feat: use more precise prompts for music and image generation

Browse files
src/agent/image_agent.py CHANGED
@@ -2,6 +2,8 @@ from pydantic import BaseModel, Field
2
  from typing import Literal, Optional
3
  from agent.llm import create_light_llm
4
  from langchain_core.messages import SystemMessage, HumanMessage
 
 
5
  import logging
6
 
7
  logger = logging.getLogger(__name__)
@@ -66,12 +68,24 @@ class ChangeScene(BaseModel):
66
 
67
  image_prompt_generator_llm = create_light_llm(0.1).with_structured_output(ChangeScene)
68
 
69
- async def generate_image_prompt(scene_description: str, request_id: str) -> ChangeScene:
70
  """
71
  Generates a detailed image prompt string based on a scene description.
72
  This prompt is intended for use with an AI image generation model.
73
  """
74
  logger.info(f"Generating image prompt for the current scene: {request_id}")
 
 
 
 
 
 
 
 
 
 
 
 
75
  response = await image_prompt_generator_llm.ainvoke(
76
  [
77
  SystemMessage(content=IMAGE_GENERATION_SYSTEM_PROMPT),
 
2
  from typing import Literal, Optional
3
  from agent.llm import create_light_llm
4
  from langchain_core.messages import SystemMessage, HumanMessage
5
+ from agent.state import get_user_state, set_user_state
6
+ from agent.prompts import GAME_STATE_PROMPT
7
  import logging
8
 
9
  logger = logging.getLogger(__name__)
 
68
 
69
  image_prompt_generator_llm = create_light_llm(0.1).with_structured_output(ChangeScene)
70
 
71
+ async def generate_image_prompt(user_hash: str, scene_description: str) -> ChangeScene:
72
  """
73
  Generates a detailed image prompt string based on a scene description.
74
  This prompt is intended for use with an AI image generation model.
75
  """
76
  logger.info(f"Generating image prompt for the current scene: {request_id}")
77
+
78
+ state = get_user_state(user_hash)
79
+ scene = GAME_STATE_PROMPT.format(
80
+ lore=state.story_frame.lore,
81
+ goal=state.story_frame.goal,
82
+ milestones=",".join(m.id for m in state.story_frame.milestones),
83
+ endings=",".join(e.id for e in state.story_frame.endings),
84
+ history="; ".join(f"{c.scene_id}:{c.choice_text}" for c in state.user_choices),
85
+ last_choice=last_choice,
86
+ scene_description=scene_description
87
+ )
88
+
89
  response = await image_prompt_generator_llm.ainvoke(
90
  [
91
  SystemMessage(content=IMAGE_GENERATION_SYSTEM_PROMPT),
src/agent/llm_graph.py CHANGED
@@ -60,7 +60,7 @@ async def node_init_game(state: GraphState) -> GraphState:
60
  first_scene = await generate_scene.ainvoke(
61
  {"user_hash": state.user_hash, "last_choice": "start"}
62
  )
63
- change_scene = await generate_image_prompt(first_scene["description"], state.user_hash)
64
  logger.info(f"Change scene: {change_scene}")
65
  await generate_scene_image.ainvoke(
66
  {
@@ -94,7 +94,7 @@ async def node_player_step(state: GraphState) -> GraphState:
94
  "last_choice": state.choice_text,
95
  }
96
  )
97
- change_scene = await generate_image_prompt(next_scene["description"], state.user_hash)
98
  current_image = None
99
  if scene_id and scene_id in user_state.scenes:
100
  current_image = user_state.scenes[scene_id].image
@@ -107,8 +107,9 @@ async def node_player_step(state: GraphState) -> GraphState:
107
  "change_scene": change_scene,
108
  }
109
  )
110
- music_task = change_music_tone(state.user_hash, next_scene["music"])
111
- await asyncio.gather(image_task, music_task)
 
112
  state.scene = next_scene
113
  return state
114
 
 
60
  first_scene = await generate_scene.ainvoke(
61
  {"user_hash": state.user_hash, "last_choice": "start"}
62
  )
63
+ change_scene = await generate_image_prompt(state.user_hash, first_scene["description"])
64
  logger.info(f"Change scene: {change_scene}")
65
  await generate_scene_image.ainvoke(
66
  {
 
94
  "last_choice": state.choice_text,
95
  }
96
  )
97
+ change_scene = await generate_image_prompt(state.user_hash, next_scene["description"])
98
  current_image = None
99
  if scene_id and scene_id in user_state.scenes:
100
  current_image = user_state.scenes[scene_id].image
 
107
  "change_scene": change_scene,
108
  }
109
  )
110
+ music_task = generate_music_prompt(next_scene["description"])
111
+ _, music_prompt = await asyncio.gather(image_task, music_task)
112
+ asyncio.create_task(change_music_tone(state.user_hash, music_prompt))
113
  state.scene = next_scene
114
  return state
115
 
src/agent/music_agent.py CHANGED
@@ -1,7 +1,9 @@
1
  from pydantic import BaseModel
2
  from agent.llm import create_light_llm
 
3
  from langchain_core.messages import SystemMessage, HumanMessage
4
  import logging
 
5
 
6
  logger = logging.getLogger(__name__)
7
 
@@ -38,10 +40,22 @@ class MusicPrompt(BaseModel):
38
  llm = create_light_llm(0.1).with_structured_output(MusicPrompt)
39
 
40
 
41
- async def generate_music_prompt(scene_description: str, request_id: str) -> str:
42
- logger.info(f"Generating music prompt for the current scene: {request_id}")
 
 
 
 
 
 
 
 
 
 
 
 
43
  response = await llm.ainvoke(
44
- [SystemMessage(content=system_prompt), HumanMessage(content=scene_description)]
45
  )
46
- logger.info(f"Music prompt generated: {request_id}")
47
  return response.prompt
 
1
  from pydantic import BaseModel
2
  from agent.llm import create_light_llm
3
+ from agent.prompts import GAME_STATE_PROMPT
4
  from langchain_core.messages import SystemMessage, HumanMessage
5
  import logging
6
+ from agent.state import get_user_state, set_user_state
7
 
8
  logger = logging.getLogger(__name__)
9
 
 
40
  llm = create_light_llm(0.1).with_structured_output(MusicPrompt)
41
 
42
 
43
async def generate_music_prompt(
    user_hash: str, scene_description: str, last_choice: str = ""
) -> str:
    """Generate a music prompt for the current scene using the game context.

    The user's stored state (lore, goal, milestone ids, ending ids and the
    history of choices) is rendered through ``GAME_STATE_PROMPT`` together
    with ``scene_description`` and sent to the structured-output LLM.

    Args:
        user_hash: Key passed to ``get_user_state`` to load the player's state.
        scene_description: Text of the scene the music should accompany.
        last_choice: Text of the player's latest choice. When left empty, the
            ``choice_text`` of the last entry in ``state.user_choices`` is
            used; if there is no history either, it stays an empty string.

    Returns:
        The ``prompt`` field of the structured LLM response.
    """
    logger.info(
        "Generating music prompt for the current scene: %s", scene_description
    )

    state = get_user_state(user_hash)

    # Materialize once so the history can be both joined and inspected for
    # its last element, whatever iterable type `user_choices` is.
    choices = list(state.user_choices)

    # The previous revision formatted the template with an undefined
    # `last_choice` name, raising NameError on every call. Fall back to the
    # most recent recorded choice when the caller does not supply one.
    # NOTE(review): assumes the latest choice has already been appended to
    # the state before this function runs; confirm against the caller.
    if not last_choice and choices:
        last_choice = choices[-1].choice_text

    scene = GAME_STATE_PROMPT.format(
        lore=state.story_frame.lore,
        goal=state.story_frame.goal,
        milestones=",".join(m.id for m in state.story_frame.milestones),
        endings=",".join(e.id for e in state.story_frame.endings),
        history="; ".join(f"{c.scene_id}:{c.choice_text}" for c in choices),
        last_choice=last_choice,
        scene_description=scene_description,
    )

    response = await llm.ainvoke(
        [SystemMessage(content=system_prompt), HumanMessage(content=scene)]
    )
    logger.info("Music prompt generated")
    return response.prompt
src/agent/prompts.py CHANGED
@@ -13,6 +13,22 @@ Return ONLY a JSON object with:
13
  Translate the lore, goal, milestones and endings to the language which is used in the game and setting description.
14
  """
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  SCENE_PROMPT = """You are an AI agent for a visual novel game.
17
  Your role is to process incoming data and generate the next scene description and choices.
18
  Translate the scene description and choices into a language which is used in the Game Settings.
 
13
  Translate the lore, goal, milestones and endings to the language which is used in the game and setting description.
14
  """
15
 
16
+ GAME_STATE_PROMPT = """
17
+ ---Game Settings START---
18
+ Lore: {lore}
19
+ Goal: {goal}
20
+ Milestones: {milestones}
21
+ Endings: {endings}
22
+ ---Game Settings END---
23
+
24
+ ---User's actions START---
25
+ History: {history}
26
+ Last choice: {last_choice}
27
+ ---User's actions END---
28
+
29
+ Game response to user's action: {scene_description}
30
+ """
31
+
32
  SCENE_PROMPT = """You are an AI agent for a visual novel game.
33
  Your role is to process incoming data and generate the next scene description and choices.
34
  Translate the scene description and choices into a language which is used in the Game Settings.
src/agent/runner.py CHANGED
@@ -56,7 +56,7 @@ async def process_step(
56
  ending_desc = ending_info.get("description") or ending_info.get(
57
  "condition", ""
58
  )
59
- change_scene = await generate_image_prompt(ending_desc, user_hash)
60
  if change_scene.change_scene == "no_change":
61
  change_scene.change_scene = "change_completely"
62
  if not change_scene.scene_description:
 
56
  ending_desc = ending_info.get("description") or ending_info.get(
57
  "condition", ""
58
  )
59
+ change_scene = await generate_image_prompt(user_hash, ending_desc)
60
  if change_scene.change_scene == "no_change":
61
  change_scene.change_scene = "change_completely"
62
  if not change_scene.scene_description: