gsavin commited on
Commit
4310b90
·
1 Parent(s): d9ec72e

feat: improve image generation

Browse files
src/agent/llm.py CHANGED
@@ -12,7 +12,7 @@ def create_llm(temperature: float = settings.temperature, top_p: float = setting
12
  global _google_api_keys_list, _current_google_key_idx
13
 
14
  if not _google_api_keys_list:
15
- api_keys_str = settings.gemini_api_key.get_secret_value()
16
  if api_keys_str:
17
  _google_api_keys_list = [key.strip() for key in api_keys_str.split(',') if key.strip()]
18
 
@@ -38,6 +38,37 @@ def create_llm(temperature: float = settings.temperature, top_p: float = setting
38
  top_p=top_p,
39
  thinking_budget=1024
40
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def create_precise_llm():
43
  return create_llm(temperature=0, top_p=1)
 
12
  global _google_api_keys_list, _current_google_key_idx
13
 
14
  if not _google_api_keys_list:
15
+ api_keys_str = settings.gemini_api_keys.get_secret_value()
16
  if api_keys_str:
17
  _google_api_keys_list = [key.strip() for key in api_keys_str.split(',') if key.strip()]
18
 
 
38
  top_p=top_p,
39
  thinking_budget=1024
40
  )
41
+
42
+
43
+ def create_light_llm(temperature: float = settings.temperature, top_p: float = settings.top_p):
44
+ global _google_api_keys_list, _current_google_key_idx
45
+
46
+ if not _google_api_keys_list:
47
+ api_keys_str = settings.gemini_api_keys.get_secret_value()
48
+ if api_keys_str:
49
+ _google_api_keys_list = [key.strip() for key in api_keys_str.split(',') if key.strip()]
50
+
51
+ if not _google_api_keys_list:
52
+ logger.error("Google API keys are not configured or are empty in settings.")
53
+ raise ValueError("Google API keys are not configured or are invalid for round-robin.")
54
+
55
+ if not _google_api_keys_list: # Safeguard, though previous block should handle it.
56
+ logger.error("No Google API keys available for round-robin.")
57
+ raise ValueError("No Google API keys available for round-robin.")
58
+
59
+ key_index_to_use = _current_google_key_idx
60
+ selected_api_key = _google_api_keys_list[key_index_to_use]
61
+
62
+ _current_google_key_idx = (key_index_to_use + 1) % len(_google_api_keys_list)
63
+
64
+ logger.debug(f"Using Google API key at index {key_index_to_use} (ending with ...{selected_api_key[-4:] if len(selected_api_key) > 4 else selected_api_key}) for round-robin.")
65
+
66
+ return ChatGoogleGenerativeAI(
67
+ model="gemini-2.0-flash",
68
+ google_api_key=selected_api_key,
69
+ temperature=temperature,
70
+ top_p=top_p
71
+ )
72
 
73
  def create_precise_llm():
74
  return create_llm(temperature=0, top_p=1)
src/agent/llm_agent.py CHANGED
@@ -1,38 +1,73 @@
1
  from agent.llm import create_llm
2
  from pydantic import BaseModel, Field
3
- from typing import Optional, List
4
  import logging
 
 
 
 
 
5
 
6
  logger = logging.getLogger(__name__)
7
 
8
- class ChangeScene(BaseModel):
9
- change_scene: bool = Field(description="Whether the scene should be changed")
10
- scene_description: Optional[str] = None
11
-
12
- class ChangeMusic(BaseModel):
13
- change_music: bool = Field(description="Whether the music should be changed")
14
- music_description: Optional[str] = None
15
-
16
  class PlayerOption(BaseModel):
17
- option_description: str = Field(description="The description of the option, Examples: [Change location] Go to the forest; [Say] Hello!")
18
-
 
 
 
19
  class LLMOutput(BaseModel):
20
- change_scene: ChangeScene
21
- change_music: ChangeMusic
22
- game_message: str = Field(description="The message to the player, Example: You entered the forest, and you see unknown scary creatures. What do you do?")
23
- player_options: List[PlayerOption] = Field(description="The list of up to 3 options for the player to choose from.")
24
-
25
- llm = create_llm().with_structured_output(LLMOutput)
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- async def process_user_input(input: str) -> LLMOutput:
 
28
  """
29
  Process user input and update the state.
30
  """
31
- logger.info(f"User's choice: {input}")
32
-
 
33
  response: LLMOutput = await llm.ainvoke(input)
 
 
 
 
 
 
 
 
 
34
 
35
- logger.info(f"LLM response: {response}")
36
 
37
- return response
38
-
 
 
 
 
 
 
 
 
 
 
 
1
  from agent.llm import create_llm
2
  from pydantic import BaseModel, Field
3
+ from typing import List
4
  import logging
5
+ from agent.image_agent import ChangeScene
6
+ import asyncio
7
+ from agent.music_agent import generate_music_prompt
8
+ from agent.image_agent import generate_scene_image
9
+ import uuid
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
+
 
 
 
 
 
 
 
14
  class PlayerOption(BaseModel):
15
+ option_description: str = Field(
16
+ description="The description of the option, Examples: [Change location] Go to the forest; [Say] Hello!"
17
+ )
18
+
19
+
20
  class LLMOutput(BaseModel):
21
+ game_message: str = Field(
22
+ description="The message to the player, Example: You entered the forest, and you see unknown scary creatures. What do you do?"
23
+ )
24
+ player_options: List[PlayerOption] = Field(
25
+ description="The list of up to 3 options for the player to choose from."
26
+ )
27
+
28
+
29
+ class MultiAgentResponse(BaseModel):
30
+ game_message: str = Field(
31
+ description="The message to the player, Example: You entered the forest, and you see unknown scary creatures. What do you do?"
32
+ )
33
+ player_options: List[PlayerOption] = Field(
34
+ description="The list of up to 3 options for the player to choose from."
35
+ )
36
+ music_prompt: str = Field(description="The prompt for the music generation model.")
37
+ change_scene: ChangeScene = Field(description="The change to the scene.")
38
+
39
+ llm = create_llm().with_structured_output(MultiAgentResponse)
40
 
41
+
42
+ async def process_user_input(input: str) -> MultiAgentResponse:
43
  """
44
  Process user input and update the state.
45
  """
46
+ request_id = str(uuid.uuid4())
47
+ logger.info(f"LLM input received: {request_id}")
48
+
49
  response: LLMOutput = await llm.ainvoke(input)
50
+
51
+ # return response
52
+ current_state = f"""{input}
53
+
54
+ Game reaction: {response.game_message}
55
+ Player options: {response.player_options}
56
+ """
57
+
58
+ music_prompt_task = generate_music_prompt(current_state, request_id)
59
 
60
+ change_scene_task = generate_scene_image(current_state, request_id)
61
 
62
+ music_prompt, change_scene = await asyncio.gather(music_prompt_task, change_scene_task)
63
+
64
+ multi_agent_response = MultiAgentResponse(
65
+ game_message=response.game_message,
66
+ player_options=response.player_options,
67
+ music_prompt=music_prompt,
68
+ change_scene=change_scene,
69
+ )
70
+
71
+ logger.info(f"LLM responded: {request_id}")
72
+
73
+ return multi_agent_response
src/audio/audio_generator.py CHANGED
@@ -13,10 +13,12 @@ logger = logging.getLogger(__name__)
13
  client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
14
 
15
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
16
- async with (
 
 
17
  client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
18
  asyncio.TaskGroup() as tg,
19
- ):
20
  # Set up task to receive server messages.
21
  tg.create_task(receive_audio(session, user_hash))
22
 
@@ -31,10 +33,9 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
31
  )
32
  await session.play()
33
  logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
34
- await cleanup_music_session(user_hash)
35
  sessions[user_hash] = {
36
  'session': session,
37
- 'queue': queue.Queue(maxsize=3)
38
  }
39
 
40
  async def change_music_tone(user_hash: str, new_tone):
@@ -43,7 +44,6 @@ async def change_music_tone(user_hash: str, new_tone):
43
  if not session:
44
  logger.error(f"No session found for user hash {user_hash}")
45
  return
46
- await session.reset_context()
47
  await session.set_weighted_prompts(
48
  prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
49
  )
 
13
  client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
14
 
15
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
16
+ if user_hash in sessions:
17
+ return
18
+ async with (
19
  client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
20
  asyncio.TaskGroup() as tg,
21
+ ):
22
  # Set up task to receive server messages.
23
  tg.create_task(receive_audio(session, user_hash))
24
 
 
33
  )
34
  await session.play()
35
  logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
 
36
  sessions[user_hash] = {
37
  'session': session,
38
+ 'queue': queue.Queue()
39
  }
40
 
41
  async def change_music_tone(user_hash: str, new_tone):
 
44
  if not session:
45
  logger.error(f"No session found for user hash {user_hash}")
46
  return
 
47
  await session.set_weighted_prompts(
48
  prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
49
  )
src/config.py CHANGED
@@ -21,8 +21,11 @@ class BaseAppSettings(BaseSettings):
21
 
22
  class AppSettings(BaseAppSettings):
23
  gemini_api_key: SecretStr
 
 
24
  top_p: float = 0.95
25
  temperature: float = 0.5
 
26
 
27
 
28
  settings = AppSettings()
 
21
 
22
  class AppSettings(BaseAppSettings):
23
  gemini_api_key: SecretStr
24
+ gemini_api_keys: SecretStr
25
+ # assistant_api_key: SecretStr
26
  top_p: float = 0.95
27
  temperature: float = 0.5
28
+ pregenerate_next_scene: bool = True
29
 
30
 
31
  settings = AppSettings()
src/css.py CHANGED
@@ -33,11 +33,11 @@ custom_css = """
33
  background: rgba(0,0,0,0.7) !important;
34
  border: none !important;
35
  color: white !important;
36
- font-size: 18px !important;
37
  line-height: 1.5 !important;
38
- padding: 20px !important;
39
  border-radius: 10px !important;
40
- margin-bottom: 20px !important;
41
  }
42
 
43
  img {
@@ -49,7 +49,7 @@ img {
49
  border: none !important;
50
  color: white !important;
51
  -webkit-text-fill-color: white !important;
52
- font-size: 18px !important;
53
  resize: none !important;
54
  }
55
 
@@ -57,13 +57,12 @@ img {
57
  .choice-buttons {
58
  background: rgba(0,0,0,0.7) !important;
59
  border-radius: 10px !important;
60
- padding: 15px !important;
61
  }
62
 
63
  .choice-buttons label {
64
  color: white !important;
65
- font-size: 16px !important;
66
- margin-bottom: 10px !important;
67
  }
68
 
69
  /* Fix radio button backgrounds */
 
33
  background: rgba(0,0,0,0.7) !important;
34
  border: none !important;
35
  color: white !important;
36
+ font-size: 15px !important;
37
  line-height: 1.5 !important;
38
+ padding: 10px !important;
39
  border-radius: 10px !important;
40
+ margin-bottom: 10px !important;
41
  }
42
 
43
  img {
 
49
  border: none !important;
50
  color: white !important;
51
  -webkit-text-fill-color: white !important;
52
+ font-size: 15px !important;
53
  resize: none !important;
54
  }
55
 
 
57
  .choice-buttons {
58
  background: rgba(0,0,0,0.7) !important;
59
  border-radius: 10px !important;
60
+ padding: 10px !important;
61
  }
62
 
63
  .choice-buttons label {
64
  color: white !important;
65
+ font-size: 14px !important;
 
66
  }
67
 
68
  /* Fix radio button backgrounds */
src/game_constructor.py CHANGED
@@ -1,12 +1,14 @@
1
  import gradio as gr
2
  import json
3
  import uuid
4
- from game_setting import Character, GameSetting
5
  from game_state import story, state, get_current_scene
6
  from agent.llm_agent import process_user_input
7
  from images.image_generator import generate_image
8
  from audio.audio_generator import start_music_generation
9
  import asyncio
 
 
10
 
11
  # Predefined suggestions for demo
12
  SETTING_SUGGESTIONS = [
@@ -107,6 +109,7 @@ def save_game_config(
107
  except Exception as e:
108
  return f"❌ Error saving configuration: {str(e)}"
109
 
 
110
  async def start_game_with_settings(
111
  user_hash: str,
112
  setting_desc: str,
@@ -155,27 +158,41 @@ Genre: {game_setting.genre}
155
 
156
  You find yourself at the beginning of your adventure. The world around you feels alive with possibilities. What do you choose to do first?
157
 
158
- NOTE FOR THE ASSISTANT: YOU HAVE TO GENERATE THE IMAGE FOR THE START SCENE.
159
  """
160
 
161
  response = await process_user_input(initial_story)
162
-
163
- music_tone = response.change_music.music_description or "neutral"
164
-
165
  asyncio.create_task(start_music_generation(user_hash, music_tone))
166
 
167
  img = "forest.jpg"
168
-
169
- if response.change_scene.change_scene:
170
- img_path, _ = await generate_image(response.change_scene.scene_description)
171
- if img_path:
172
- img = img_path
 
 
173
 
174
  story["start"] = {
175
  "text": response.game_message,
176
  "image": img,
177
- "choices": [option.option_description for option in response.player_options],
178
- "music_tone": response.change_music.music_description,
 
 
 
 
 
 
 
 
 
 
 
 
179
  }
180
  state["scene"] = "start"
181
 
 
1
  import gradio as gr
2
  import json
3
  import uuid
4
+ from game_setting import Character, GameSetting, get_user_story
5
  from game_state import story, state, get_current_scene
6
  from agent.llm_agent import process_user_input
7
  from images.image_generator import generate_image
8
  from audio.audio_generator import start_music_generation
9
  import asyncio
10
+ from config import settings
11
+
12
 
13
  # Predefined suggestions for demo
14
  SETTING_SUGGESTIONS = [
 
109
  except Exception as e:
110
  return f"❌ Error saving configuration: {str(e)}"
111
 
112
+
113
  async def start_game_with_settings(
114
  user_hash: str,
115
  setting_desc: str,
 
158
 
159
  You find yourself at the beginning of your adventure. The world around you feels alive with possibilities. What do you choose to do first?
160
 
161
+ NOTE FOR THE ASSISTANT: YOU HAVE TO GENERATE A NEW IMAGE FOR THE START SCENE.
162
  """
163
 
164
  response = await process_user_input(initial_story)
165
+
166
+ music_tone = response.music_prompt
167
+
168
  asyncio.create_task(start_music_generation(user_hash, music_tone))
169
 
170
  img = "forest.jpg"
171
+ img_description = ""
172
+
173
+ img_path, img_description = await generate_image(
174
+ response.change_scene.scene_description
175
+ )
176
+ if img_path:
177
+ img = img_path
178
 
179
  story["start"] = {
180
  "text": response.game_message,
181
  "image": img,
182
+ "choices": {
183
+ option.option_description: asyncio.create_task(
184
+ process_user_input(
185
+ get_user_story(
186
+ response.game_message,
187
+ response.change_scene.scene_description,
188
+ option.option_description,
189
+ )
190
+ )
191
+ ) if settings.pregenerate_next_scene else None
192
+ for option in response.player_options
193
+ },
194
+ "music_tone": response.music_prompt,
195
+ "img_description": img_description,
196
  }
197
  state["scene"] = "start"
198
 
src/game_setting.py CHANGED
@@ -1,12 +1,25 @@
1
  from pydantic import BaseModel
2
 
 
3
  class Character(BaseModel):
4
  name: str
5
  age: str
6
  background: str
7
  personality: str
8
 
 
9
  class GameSetting(BaseModel):
10
  character: Character
11
  setting: str
12
  genre: str
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pydantic import BaseModel
2
 
3
+
4
  class Character(BaseModel):
5
  name: str
6
  age: str
7
  background: str
8
  personality: str
9
 
10
+
11
  class GameSetting(BaseModel):
12
  character: Character
13
  setting: str
14
  genre: str
15
+
16
+
17
+ def get_user_story(
18
+ scene_description: str, scene_image_description: str, user_choice: str
19
+ ) -> str:
20
+ return f"""Current scene description:
21
+ {scene_description}
22
+ Current scene image description: {scene_image_description}
23
+
24
+ User's choice: {user_choice}
25
+ """
src/game_state.py CHANGED
@@ -1,10 +1,10 @@
1
-
2
  story = {
3
  "start": {
4
  "text": "You wake up in a mysterious forest. What do you do?",
5
  "image": "forest.jpg",
6
- "choices": ["Explore", "Wait"],
7
  "music_tone": "neutral",
 
8
  },
9
  }
10
 
@@ -12,4 +12,4 @@ state = {"scene": "start"}
12
 
13
  def get_current_scene():
14
  scene = story[state["scene"]]
15
- return scene["text"], scene["image"], scene["choices"]
 
 
1
  story = {
2
  "start": {
3
  "text": "You wake up in a mysterious forest. What do you do?",
4
  "image": "forest.jpg",
5
+ "choices": {"Explore": None, "Wait": None},
6
  "music_tone": "neutral",
7
+ "img_description": "forest in the fog",
8
  },
9
  }
10
 
 
12
 
13
  def get_current_scene():
14
  scene = story[state["scene"]]
15
+ return scene["text"], scene["image"], scene["choices"].keys()
src/images/image_generator.py CHANGED
@@ -6,25 +6,47 @@ from io import BytesIO
6
  from datetime import datetime
7
  from config import settings
8
  import logging
 
 
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
  client = genai.Client(api_key=settings.gemini_api_key.get_secret_value()).aio
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  async def generate_image(prompt: str) -> tuple[str, str] | None:
15
  """
16
  Generate an image using Google's Gemini model and save it to generated/images directory.
17
-
18
  Args:
19
  prompt (str): The text prompt to generate the image from
20
-
21
  Returns:
22
  str: Path to the generated image file, or None if generation failed
23
  """
24
  # Ensure the generated/images directory exists
25
  output_dir = "generated/images"
26
  os.makedirs(output_dir, exist_ok=True)
27
-
28
  logger.info(f"Generating image with prompt: {prompt}")
29
 
30
  try:
@@ -32,8 +54,9 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
32
  model="gemini-2.0-flash-preview-image-generation",
33
  contents=prompt,
34
  config=types.GenerateContentConfig(
35
- response_modalities=['TEXT', 'IMAGE'],
36
- )
 
37
  )
38
 
39
  # Process the response parts
@@ -44,19 +67,20 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
44
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
45
  filename = f"gemini_{timestamp}.png"
46
  filepath = os.path.join(output_dir, filename)
47
-
48
  # Save the image
49
  image = Image.open(BytesIO(part.inline_data.data))
50
- image.save(filepath, "PNG")
51
  logger.info(f"Image saved to: {filepath}")
52
  image_saved = True
53
-
54
- return filepath, part.text
55
-
56
  if not image_saved:
 
57
  logger.error("No image was generated in the response.")
58
  return None, None
59
-
60
  except Exception as e:
61
  logger.error(f"Error generating image: {e}")
62
  return None, None
@@ -65,38 +89,41 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
65
  async def modify_image(image_path: str, modification_prompt: str) -> str | None:
66
  """
67
  Modify an existing image using Google's Gemini model based on a text prompt.
68
-
69
  Args:
70
  image_path (str): Path to the existing image file
71
  modification_prompt (str): The text prompt describing how to modify the image
72
-
73
  Returns:
74
  str: Path to the modified image file, or None if modification failed
75
  """
76
  # Ensure the generated/images directory exists
77
  output_dir = "generated/images"
78
  os.makedirs(output_dir, exist_ok=True)
79
-
 
 
80
  # Check if the input image exists
81
  if not os.path.exists(image_path):
82
  logger.error(f"Error: Image file not found at {image_path}")
83
  return None
84
-
85
  key = settings.gemini_api_key.get_secret_value()
86
-
87
  client = genai.Client(api_key=key).aio
88
 
89
  try:
90
  # Load the input image
91
  input_image = Image.open(image_path)
92
-
93
  # Make the API call with both text and image
94
  response = await client.models.generate_content(
95
  model="gemini-2.0-flash-preview-image-generation",
96
  contents=[modification_prompt, input_image],
97
  config=types.GenerateContentConfig(
98
- response_modalities=['TEXT', 'IMAGE']
99
- )
 
100
  )
101
 
102
  # Process the response parts
@@ -107,19 +134,20 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
107
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
108
  filename = f"gemini_modified_{timestamp}.png"
109
  filepath = os.path.join(output_dir, filename)
110
-
111
  # Save the modified image
112
  modified_image = Image.open(BytesIO(part.inline_data.data))
113
- modified_image.save(filepath, "PNG")
114
  logger.info(f"Modified image saved to: {filepath}")
115
  image_saved = True
116
-
117
- return filepath, part.text
118
-
119
  if not image_saved:
 
120
  logger.error("No modified image was generated in the response.")
121
  return None, None
122
-
123
  except Exception as e:
124
  logger.error(f"Error modifying image: {e}")
125
  return None, None
@@ -129,10 +157,10 @@ if __name__ == "__main__":
129
  # Example usage
130
  sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
131
  generated_image_path = generate_image(sample_prompt)
132
-
133
  # if generated_image_path:
134
  # # Example modification
135
  # modification_prompt = "Now the house is destroyed, and the jawas are running away"
136
  # modified_image_path = modify_image(generated_image_path, modification_prompt)
137
  # if modified_image_path:
138
- # print(f"Successfully modified image: {modified_image_path}")
 
6
  from datetime import datetime
7
  from config import settings
8
  import logging
9
+ import asyncio
10
+ import gradio as gr
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
  client = genai.Client(api_key=settings.gemini_api_key.get_secret_value()).aio
15
 
16
+ safety_settings = [
17
+ types.SafetySetting(
18
+ category="HARM_CATEGORY_HARASSMENT",
19
+ threshold="BLOCK_NONE", # Block none
20
+ ),
21
+ types.SafetySetting(
22
+ category="HARM_CATEGORY_HATE_SPEECH",
23
+ threshold="BLOCK_NONE", # Block none
24
+ ),
25
+ types.SafetySetting(
26
+ category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
27
+ threshold="BLOCK_NONE", # Block none
28
+ ),
29
+ types.SafetySetting(
30
+ category="HARM_CATEGORY_DANGEROUS_CONTENT",
31
+ threshold="BLOCK_NONE", # Block none
32
+ ),
33
+ ]
34
+
35
+
36
  async def generate_image(prompt: str) -> tuple[str, str] | None:
37
  """
38
  Generate an image using Google's Gemini model and save it to generated/images directory.
39
+
40
  Args:
41
  prompt (str): The text prompt to generate the image from
42
+
43
  Returns:
44
  str: Path to the generated image file, or None if generation failed
45
  """
46
  # Ensure the generated/images directory exists
47
  output_dir = "generated/images"
48
  os.makedirs(output_dir, exist_ok=True)
49
+
50
  logger.info(f"Generating image with prompt: {prompt}")
51
 
52
  try:
 
54
  model="gemini-2.0-flash-preview-image-generation",
55
  contents=prompt,
56
  config=types.GenerateContentConfig(
57
+ response_modalities=["TEXT", "IMAGE"],
58
+ safety_settings=safety_settings,
59
+ ),
60
  )
61
 
62
  # Process the response parts
 
67
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
68
  filename = f"gemini_{timestamp}.png"
69
  filepath = os.path.join(output_dir, filename)
70
+
71
  # Save the image
72
  image = Image.open(BytesIO(part.inline_data.data))
73
+ await asyncio.to_thread(image.save, filepath, "PNG")
74
  logger.info(f"Image saved to: {filepath}")
75
  image_saved = True
76
+
77
+ return filepath, prompt
78
+
79
  if not image_saved:
80
+ gr.Warning("Image was censored by Google!")
81
  logger.error("No image was generated in the response.")
82
  return None, None
83
+
84
  except Exception as e:
85
  logger.error(f"Error generating image: {e}")
86
  return None, None
 
89
  async def modify_image(image_path: str, modification_prompt: str) -> str | None:
90
  """
91
  Modify an existing image using Google's Gemini model based on a text prompt.
92
+
93
  Args:
94
  image_path (str): Path to the existing image file
95
  modification_prompt (str): The text prompt describing how to modify the image
96
+
97
  Returns:
98
  str: Path to the modified image file, or None if modification failed
99
  """
100
  # Ensure the generated/images directory exists
101
  output_dir = "generated/images"
102
  os.makedirs(output_dir, exist_ok=True)
103
+
104
+ logger.info(f"Modifying current scene image with prompt: {modification_prompt}")
105
+
106
  # Check if the input image exists
107
  if not os.path.exists(image_path):
108
  logger.error(f"Error: Image file not found at {image_path}")
109
  return None
110
+
111
  key = settings.gemini_api_key.get_secret_value()
112
+
113
  client = genai.Client(api_key=key).aio
114
 
115
  try:
116
  # Load the input image
117
  input_image = Image.open(image_path)
118
+
119
  # Make the API call with both text and image
120
  response = await client.models.generate_content(
121
  model="gemini-2.0-flash-preview-image-generation",
122
  contents=[modification_prompt, input_image],
123
  config=types.GenerateContentConfig(
124
+ response_modalities=["TEXT", "IMAGE"],
125
+ safety_settings=safety_settings,
126
+ ),
127
  )
128
 
129
  # Process the response parts
 
134
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
135
  filename = f"gemini_modified_{timestamp}.png"
136
  filepath = os.path.join(output_dir, filename)
137
+
138
  # Save the modified image
139
  modified_image = Image.open(BytesIO(part.inline_data.data))
140
+ await asyncio.to_thread(modified_image.save, filepath, "PNG")
141
  logger.info(f"Modified image saved to: {filepath}")
142
  image_saved = True
143
+
144
+ return filepath, modification_prompt
145
+
146
  if not image_saved:
147
+ gr.Warning("Updated image was censored by Google!")
148
  logger.error("No modified image was generated in the response.")
149
  return None, None
150
+
151
  except Exception as e:
152
  logger.error(f"Error modifying image: {e}")
153
  return None, None
 
157
  # Example usage
158
  sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
159
  generated_image_path = generate_image(sample_prompt)
160
+
161
  # if generated_image_path:
162
  # # Example modification
163
  # modification_prompt = "Now the house is destroyed, and the jawas are running away"
164
  # modified_image_path = modify_image(generated_image_path, modification_prompt)
165
  # if modified_image_path:
166
+ # print(f"Successfully modified image: {modified_image_path}")
src/main.py CHANGED
@@ -7,7 +7,7 @@ from audio.audio_generator import (
7
  )
8
  import logging
9
  from agent.llm_agent import process_user_input
10
- from images.image_generator import generate_image
11
  import uuid
12
  from game_state import story, state
13
  from game_constructor import (
@@ -19,6 +19,8 @@ from game_constructor import (
19
  start_game_with_settings,
20
  )
21
  import asyncio
 
 
22
 
23
  logger = logging.getLogger(__name__)
24
 
@@ -43,29 +45,53 @@ async def update_scene(user_hash: str, choice):
43
  }
44
  state["scene"] = new_scene
45
 
46
- user_story = f"""Current scene description:
47
- {story[old_scene]["text"]}
48
- User's choice: {choice}
49
- """
50
 
51
- response = await process_user_input(user_story)
 
 
52
 
53
  story[new_scene]["text"] = response.game_message
54
 
55
- story[new_scene]["choices"] = [
56
- option.option_description for option in response.player_options
57
- ]
58
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # run both tasks in parallel
60
  img_res, _ = await asyncio.gather(
61
- generate_image(response.change_scene.scene_description) if response.change_scene.change_scene else asyncio.sleep(0),
62
- change_music_tone(user_hash, response.change_music.music_description) if response.change_music.change_music else asyncio.sleep(0)
63
  )
64
-
65
  if img_res and response.change_scene.change_scene:
66
- img_path, _ = img_res
67
  if img_path:
68
  story[new_scene]["image"] = img_path
 
69
 
70
  scene = story[state["scene"]]
71
  return (
@@ -136,7 +162,7 @@ with gr.Blocks(
136
  # Fullscreen Loading Indicator (hidden by default)
137
  with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
138
  gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
139
-
140
  local_storage = gr.BrowserState(str(uuid.uuid4()), "user_hash")
141
 
142
  # Constructor Interface (visible by default)
 
7
  )
8
  import logging
9
  from agent.llm_agent import process_user_input
10
+ from images.image_generator import modify_image
11
  import uuid
12
  from game_state import story, state
13
  from game_constructor import (
 
19
  start_game_with_settings,
20
  )
21
  import asyncio
22
+ from game_setting import get_user_story
23
+ from config import settings
24
 
25
  logger = logging.getLogger(__name__)
26
 
 
45
  }
46
  state["scene"] = new_scene
47
 
48
+ user_story = get_user_story(
49
+ story[old_scene]["text"], story[old_scene]["img_description"], choice
50
+ )
 
51
 
52
+ response = await (
53
+ story[old_scene]["choices"][choice] or process_user_input(user_story)
54
+ )
55
 
56
  story[new_scene]["text"] = response.game_message
57
 
58
+ story[new_scene]["choices"] = {
59
+ option.option_description: asyncio.create_task(
60
+ process_user_input(
61
+ get_user_story(
62
+ response.game_message,
63
+ response.change_scene.scene_description,
64
+ option.option_description,
65
+ )
66
+ )
67
+ )
68
+ if settings.pregenerate_next_scene
69
+ else None
70
+ for option in response.player_options
71
+ }
72
+
73
+ img_task = None
74
+ # always modify the image to avoid hallucinations in which image is being generated in entirely different style
75
+ if (
76
+ response.change_scene.change_scene == "change_completely"
77
+ or response.change_scene.change_scene == "modify"
78
+ ):
79
+ img_task = modify_image(
80
+ story[old_scene]["image"], response.change_scene.scene_description
81
+ )
82
+ else:
83
+ img_task = asyncio.sleep(0)
84
+
85
  # run both tasks in parallel
86
  img_res, _ = await asyncio.gather(
87
+ img_task, change_music_tone(user_hash, response.music_prompt)
 
88
  )
89
+
90
  if img_res and response.change_scene.change_scene:
91
+ img_path, img_description = img_res
92
  if img_path:
93
  story[new_scene]["image"] = img_path
94
+ story[new_scene]["img_description"] = img_description
95
 
96
  scene = story[state["scene"]]
97
  return (
 
162
  # Fullscreen Loading Indicator (hidden by default)
163
  with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
164
  gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
165
+
166
  local_storage = gr.BrowserState(str(uuid.uuid4()), "user_hash")
167
 
168
  # Constructor Interface (visible by default)