Spaces:
Running
Running
feat: improve image generation
Browse files- src/agent/llm.py +32 -1
- src/agent/llm_agent.py +58 -23
- src/audio/audio_generator.py +5 -5
- src/config.py +3 -0
- src/css.py +6 -7
- src/game_constructor.py +29 -12
- src/game_setting.py +13 -0
- src/game_state.py +3 -3
- src/images/image_generator.py +55 -27
- src/main.py +41 -15
src/agent/llm.py
CHANGED
@@ -12,7 +12,7 @@ def create_llm(temperature: float = settings.temperature, top_p: float = setting
|
|
12 |
global _google_api_keys_list, _current_google_key_idx
|
13 |
|
14 |
if not _google_api_keys_list:
|
15 |
-
api_keys_str = settings.
|
16 |
if api_keys_str:
|
17 |
_google_api_keys_list = [key.strip() for key in api_keys_str.split(',') if key.strip()]
|
18 |
|
@@ -38,6 +38,37 @@ def create_llm(temperature: float = settings.temperature, top_p: float = setting
|
|
38 |
top_p=top_p,
|
39 |
thinking_budget=1024
|
40 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
def create_precise_llm():
|
43 |
return create_llm(temperature=0, top_p=1)
|
|
|
12 |
global _google_api_keys_list, _current_google_key_idx
|
13 |
|
14 |
if not _google_api_keys_list:
|
15 |
+
api_keys_str = settings.gemini_api_keys.get_secret_value()
|
16 |
if api_keys_str:
|
17 |
_google_api_keys_list = [key.strip() for key in api_keys_str.split(',') if key.strip()]
|
18 |
|
|
|
38 |
top_p=top_p,
|
39 |
thinking_budget=1024
|
40 |
)
|
41 |
+
|
42 |
+
|
43 |
+
def create_light_llm(temperature: float = settings.temperature, top_p: float = settings.top_p):
|
44 |
+
global _google_api_keys_list, _current_google_key_idx
|
45 |
+
|
46 |
+
if not _google_api_keys_list:
|
47 |
+
api_keys_str = settings.gemini_api_keys.get_secret_value()
|
48 |
+
if api_keys_str:
|
49 |
+
_google_api_keys_list = [key.strip() for key in api_keys_str.split(',') if key.strip()]
|
50 |
+
|
51 |
+
if not _google_api_keys_list:
|
52 |
+
logger.error("Google API keys are not configured or are empty in settings.")
|
53 |
+
raise ValueError("Google API keys are not configured or are invalid for round-robin.")
|
54 |
+
|
55 |
+
if not _google_api_keys_list: # Safeguard, though previous block should handle it.
|
56 |
+
logger.error("No Google API keys available for round-robin.")
|
57 |
+
raise ValueError("No Google API keys available for round-robin.")
|
58 |
+
|
59 |
+
key_index_to_use = _current_google_key_idx
|
60 |
+
selected_api_key = _google_api_keys_list[key_index_to_use]
|
61 |
+
|
62 |
+
_current_google_key_idx = (key_index_to_use + 1) % len(_google_api_keys_list)
|
63 |
+
|
64 |
+
logger.debug(f"Using Google API key at index {key_index_to_use} (ending with ...{selected_api_key[-4:] if len(selected_api_key) > 4 else selected_api_key}) for round-robin.")
|
65 |
+
|
66 |
+
return ChatGoogleGenerativeAI(
|
67 |
+
model="gemini-2.0-flash",
|
68 |
+
google_api_key=selected_api_key,
|
69 |
+
temperature=temperature,
|
70 |
+
top_p=top_p
|
71 |
+
)
|
72 |
|
73 |
def create_precise_llm():
|
74 |
return create_llm(temperature=0, top_p=1)
|
src/agent/llm_agent.py
CHANGED
@@ -1,38 +1,73 @@
|
|
1 |
from agent.llm import create_llm
|
2 |
from pydantic import BaseModel, Field
|
3 |
-
from typing import
|
4 |
import logging
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
logger = logging.getLogger(__name__)
|
7 |
|
8 |
-
|
9 |
-
change_scene: bool = Field(description="Whether the scene should be changed")
|
10 |
-
scene_description: Optional[str] = None
|
11 |
-
|
12 |
-
class ChangeMusic(BaseModel):
|
13 |
-
change_music: bool = Field(description="Whether the music should be changed")
|
14 |
-
music_description: Optional[str] = None
|
15 |
-
|
16 |
class PlayerOption(BaseModel):
|
17 |
-
option_description: str = Field(
|
18 |
-
|
|
|
|
|
|
|
19 |
class LLMOutput(BaseModel):
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
player_options: List[PlayerOption] = Field(
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
|
|
|
28 |
"""
|
29 |
Process user input and update the state.
|
30 |
"""
|
31 |
-
|
32 |
-
|
|
|
33 |
response: LLMOutput = await llm.ainvoke(input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
|
36 |
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from agent.llm import create_llm
|
2 |
from pydantic import BaseModel, Field
|
3 |
+
from typing import List
|
4 |
import logging
|
5 |
+
from agent.image_agent import ChangeScene
|
6 |
+
import asyncio
|
7 |
+
from agent.music_agent import generate_music_prompt
|
8 |
+
from agent.image_agent import generate_scene_image
|
9 |
+
import uuid
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
class PlayerOption(BaseModel):
|
15 |
+
option_description: str = Field(
|
16 |
+
description="The description of the option, Examples: [Change location] Go to the forest; [Say] Hello!"
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
class LLMOutput(BaseModel):
|
21 |
+
game_message: str = Field(
|
22 |
+
description="The message to the player, Example: You entered the forest, and you see unknown scary creatures. What do you do?"
|
23 |
+
)
|
24 |
+
player_options: List[PlayerOption] = Field(
|
25 |
+
description="The list of up to 3 options for the player to choose from."
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class MultiAgentResponse(BaseModel):
|
30 |
+
game_message: str = Field(
|
31 |
+
description="The message to the player, Example: You entered the forest, and you see unknown scary creatures. What do you do?"
|
32 |
+
)
|
33 |
+
player_options: List[PlayerOption] = Field(
|
34 |
+
description="The list of up to 3 options for the player to choose from."
|
35 |
+
)
|
36 |
+
music_prompt: str = Field(description="The prompt for the music generation model.")
|
37 |
+
change_scene: ChangeScene = Field(description="The change to the scene.")
|
38 |
+
|
39 |
+
llm = create_llm().with_structured_output(MultiAgentResponse)
|
40 |
|
41 |
+
|
42 |
+
async def process_user_input(input: str) -> MultiAgentResponse:
|
43 |
"""
|
44 |
Process user input and update the state.
|
45 |
"""
|
46 |
+
request_id = str(uuid.uuid4())
|
47 |
+
logger.info(f"LLM input received: {request_id}")
|
48 |
+
|
49 |
response: LLMOutput = await llm.ainvoke(input)
|
50 |
+
|
51 |
+
# return response
|
52 |
+
current_state = f"""{input}
|
53 |
+
|
54 |
+
Game reaction: {response.game_message}
|
55 |
+
Player options: {response.player_options}
|
56 |
+
"""
|
57 |
+
|
58 |
+
music_prompt_task = generate_music_prompt(current_state, request_id)
|
59 |
|
60 |
+
change_scene_task = generate_scene_image(current_state, request_id)
|
61 |
|
62 |
+
music_prompt, change_scene = await asyncio.gather(music_prompt_task, change_scene_task)
|
63 |
+
|
64 |
+
multi_agent_response = MultiAgentResponse(
|
65 |
+
game_message=response.game_message,
|
66 |
+
player_options=response.player_options,
|
67 |
+
music_prompt=music_prompt,
|
68 |
+
change_scene=change_scene,
|
69 |
+
)
|
70 |
+
|
71 |
+
logger.info(f"LLM responded: {request_id}")
|
72 |
+
|
73 |
+
return multi_agent_response
|
src/audio/audio_generator.py
CHANGED
@@ -13,10 +13,12 @@ logger = logging.getLogger(__name__)
|
|
13 |
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
|
14 |
|
15 |
async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
16 |
-
|
|
|
|
|
17 |
client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
|
18 |
asyncio.TaskGroup() as tg,
|
19 |
-
|
20 |
# Set up task to receive server messages.
|
21 |
tg.create_task(receive_audio(session, user_hash))
|
22 |
|
@@ -31,10 +33,9 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
|
31 |
)
|
32 |
await session.play()
|
33 |
logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
|
34 |
-
await cleanup_music_session(user_hash)
|
35 |
sessions[user_hash] = {
|
36 |
'session': session,
|
37 |
-
'queue': queue.Queue(
|
38 |
}
|
39 |
|
40 |
async def change_music_tone(user_hash: str, new_tone):
|
@@ -43,7 +44,6 @@ async def change_music_tone(user_hash: str, new_tone):
|
|
43 |
if not session:
|
44 |
logger.error(f"No session found for user hash {user_hash}")
|
45 |
return
|
46 |
-
await session.reset_context()
|
47 |
await session.set_weighted_prompts(
|
48 |
prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
|
49 |
)
|
|
|
13 |
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
|
14 |
|
15 |
async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
16 |
+
if user_hash in sessions:
|
17 |
+
return
|
18 |
+
async with (
|
19 |
client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
|
20 |
asyncio.TaskGroup() as tg,
|
21 |
+
):
|
22 |
# Set up task to receive server messages.
|
23 |
tg.create_task(receive_audio(session, user_hash))
|
24 |
|
|
|
33 |
)
|
34 |
await session.play()
|
35 |
logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
|
|
|
36 |
sessions[user_hash] = {
|
37 |
'session': session,
|
38 |
+
'queue': queue.Queue()
|
39 |
}
|
40 |
|
41 |
async def change_music_tone(user_hash: str, new_tone):
|
|
|
44 |
if not session:
|
45 |
logger.error(f"No session found for user hash {user_hash}")
|
46 |
return
|
|
|
47 |
await session.set_weighted_prompts(
|
48 |
prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
|
49 |
)
|
src/config.py
CHANGED
@@ -21,8 +21,11 @@ class BaseAppSettings(BaseSettings):
|
|
21 |
|
22 |
class AppSettings(BaseAppSettings):
|
23 |
gemini_api_key: SecretStr
|
|
|
|
|
24 |
top_p: float = 0.95
|
25 |
temperature: float = 0.5
|
|
|
26 |
|
27 |
|
28 |
settings = AppSettings()
|
|
|
21 |
|
22 |
class AppSettings(BaseAppSettings):
|
23 |
gemini_api_key: SecretStr
|
24 |
+
gemini_api_keys: SecretStr
|
25 |
+
# assistant_api_key: SecretStr
|
26 |
top_p: float = 0.95
|
27 |
temperature: float = 0.5
|
28 |
+
pregenerate_next_scene: bool = True
|
29 |
|
30 |
|
31 |
settings = AppSettings()
|
src/css.py
CHANGED
@@ -33,11 +33,11 @@ custom_css = """
|
|
33 |
background: rgba(0,0,0,0.7) !important;
|
34 |
border: none !important;
|
35 |
color: white !important;
|
36 |
-
font-size:
|
37 |
line-height: 1.5 !important;
|
38 |
-
padding:
|
39 |
border-radius: 10px !important;
|
40 |
-
margin-bottom:
|
41 |
}
|
42 |
|
43 |
img {
|
@@ -49,7 +49,7 @@ img {
|
|
49 |
border: none !important;
|
50 |
color: white !important;
|
51 |
-webkit-text-fill-color: white !important;
|
52 |
-
font-size:
|
53 |
resize: none !important;
|
54 |
}
|
55 |
|
@@ -57,13 +57,12 @@ img {
|
|
57 |
.choice-buttons {
|
58 |
background: rgba(0,0,0,0.7) !important;
|
59 |
border-radius: 10px !important;
|
60 |
-
padding:
|
61 |
}
|
62 |
|
63 |
.choice-buttons label {
|
64 |
color: white !important;
|
65 |
-
font-size:
|
66 |
-
margin-bottom: 10px !important;
|
67 |
}
|
68 |
|
69 |
/* Fix radio button backgrounds */
|
|
|
33 |
background: rgba(0,0,0,0.7) !important;
|
34 |
border: none !important;
|
35 |
color: white !important;
|
36 |
+
font-size: 15px !important;
|
37 |
line-height: 1.5 !important;
|
38 |
+
padding: 10px !important;
|
39 |
border-radius: 10px !important;
|
40 |
+
margin-bottom: 10px !important;
|
41 |
}
|
42 |
|
43 |
img {
|
|
|
49 |
border: none !important;
|
50 |
color: white !important;
|
51 |
-webkit-text-fill-color: white !important;
|
52 |
+
font-size: 15px !important;
|
53 |
resize: none !important;
|
54 |
}
|
55 |
|
|
|
57 |
.choice-buttons {
|
58 |
background: rgba(0,0,0,0.7) !important;
|
59 |
border-radius: 10px !important;
|
60 |
+
padding: 10px !important;
|
61 |
}
|
62 |
|
63 |
.choice-buttons label {
|
64 |
color: white !important;
|
65 |
+
font-size: 14px !important;
|
|
|
66 |
}
|
67 |
|
68 |
/* Fix radio button backgrounds */
|
src/game_constructor.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
import gradio as gr
|
2 |
import json
|
3 |
import uuid
|
4 |
-
from game_setting import Character, GameSetting
|
5 |
from game_state import story, state, get_current_scene
|
6 |
from agent.llm_agent import process_user_input
|
7 |
from images.image_generator import generate_image
|
8 |
from audio.audio_generator import start_music_generation
|
9 |
import asyncio
|
|
|
|
|
10 |
|
11 |
# Predefined suggestions for demo
|
12 |
SETTING_SUGGESTIONS = [
|
@@ -107,6 +109,7 @@ def save_game_config(
|
|
107 |
except Exception as e:
|
108 |
return f"❌ Error saving configuration: {str(e)}"
|
109 |
|
|
|
110 |
async def start_game_with_settings(
|
111 |
user_hash: str,
|
112 |
setting_desc: str,
|
@@ -155,27 +158,41 @@ Genre: {game_setting.genre}
|
|
155 |
|
156 |
You find yourself at the beginning of your adventure. The world around you feels alive with possibilities. What do you choose to do first?
|
157 |
|
158 |
-
NOTE FOR THE ASSISTANT: YOU HAVE TO GENERATE
|
159 |
"""
|
160 |
|
161 |
response = await process_user_input(initial_story)
|
162 |
-
|
163 |
-
music_tone = response.
|
164 |
-
|
165 |
asyncio.create_task(start_music_generation(user_hash, music_tone))
|
166 |
|
167 |
img = "forest.jpg"
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
173 |
|
174 |
story["start"] = {
|
175 |
"text": response.game_message,
|
176 |
"image": img,
|
177 |
-
"choices":
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
}
|
180 |
state["scene"] = "start"
|
181 |
|
|
|
1 |
import gradio as gr
|
2 |
import json
|
3 |
import uuid
|
4 |
+
from game_setting import Character, GameSetting, get_user_story
|
5 |
from game_state import story, state, get_current_scene
|
6 |
from agent.llm_agent import process_user_input
|
7 |
from images.image_generator import generate_image
|
8 |
from audio.audio_generator import start_music_generation
|
9 |
import asyncio
|
10 |
+
from config import settings
|
11 |
+
|
12 |
|
13 |
# Predefined suggestions for demo
|
14 |
SETTING_SUGGESTIONS = [
|
|
|
109 |
except Exception as e:
|
110 |
return f"❌ Error saving configuration: {str(e)}"
|
111 |
|
112 |
+
|
113 |
async def start_game_with_settings(
|
114 |
user_hash: str,
|
115 |
setting_desc: str,
|
|
|
158 |
|
159 |
You find yourself at the beginning of your adventure. The world around you feels alive with possibilities. What do you choose to do first?
|
160 |
|
161 |
+
NOTE FOR THE ASSISTANT: YOU HAVE TO GENERATE A NEW IMAGE FOR THE START SCENE.
|
162 |
"""
|
163 |
|
164 |
response = await process_user_input(initial_story)
|
165 |
+
|
166 |
+
music_tone = response.music_prompt
|
167 |
+
|
168 |
asyncio.create_task(start_music_generation(user_hash, music_tone))
|
169 |
|
170 |
img = "forest.jpg"
|
171 |
+
img_description = ""
|
172 |
+
|
173 |
+
img_path, img_description = await generate_image(
|
174 |
+
response.change_scene.scene_description
|
175 |
+
)
|
176 |
+
if img_path:
|
177 |
+
img = img_path
|
178 |
|
179 |
story["start"] = {
|
180 |
"text": response.game_message,
|
181 |
"image": img,
|
182 |
+
"choices": {
|
183 |
+
option.option_description: asyncio.create_task(
|
184 |
+
process_user_input(
|
185 |
+
get_user_story(
|
186 |
+
response.game_message,
|
187 |
+
response.change_scene.scene_description,
|
188 |
+
option.option_description,
|
189 |
+
)
|
190 |
+
)
|
191 |
+
) if settings.pregenerate_next_scene else None
|
192 |
+
for option in response.player_options
|
193 |
+
},
|
194 |
+
"music_tone": response.music_prompt,
|
195 |
+
"img_description": img_description,
|
196 |
}
|
197 |
state["scene"] = "start"
|
198 |
|
src/game_setting.py
CHANGED
@@ -1,12 +1,25 @@
|
|
1 |
from pydantic import BaseModel
|
2 |
|
|
|
3 |
class Character(BaseModel):
|
4 |
name: str
|
5 |
age: str
|
6 |
background: str
|
7 |
personality: str
|
8 |
|
|
|
9 |
class GameSetting(BaseModel):
|
10 |
character: Character
|
11 |
setting: str
|
12 |
genre: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from pydantic import BaseModel
|
2 |
|
3 |
+
|
4 |
class Character(BaseModel):
|
5 |
name: str
|
6 |
age: str
|
7 |
background: str
|
8 |
personality: str
|
9 |
|
10 |
+
|
11 |
class GameSetting(BaseModel):
|
12 |
character: Character
|
13 |
setting: str
|
14 |
genre: str
|
15 |
+
|
16 |
+
|
17 |
+
def get_user_story(
|
18 |
+
scene_description: str, scene_image_description: str, user_choice: str
|
19 |
+
) -> str:
|
20 |
+
return f"""Current scene description:
|
21 |
+
{scene_description}
|
22 |
+
Current scene image description: {scene_image_description}
|
23 |
+
|
24 |
+
User's choice: {user_choice}
|
25 |
+
"""
|
src/game_state.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
-
|
2 |
story = {
|
3 |
"start": {
|
4 |
"text": "You wake up in a mysterious forest. What do you do?",
|
5 |
"image": "forest.jpg",
|
6 |
-
"choices":
|
7 |
"music_tone": "neutral",
|
|
|
8 |
},
|
9 |
}
|
10 |
|
@@ -12,4 +12,4 @@ state = {"scene": "start"}
|
|
12 |
|
13 |
def get_current_scene():
|
14 |
scene = story[state["scene"]]
|
15 |
-
return scene["text"], scene["image"], scene["choices"]
|
|
|
|
|
1 |
story = {
|
2 |
"start": {
|
3 |
"text": "You wake up in a mysterious forest. What do you do?",
|
4 |
"image": "forest.jpg",
|
5 |
+
"choices": {"Explore": None, "Wait": None},
|
6 |
"music_tone": "neutral",
|
7 |
+
"img_description": "forest in the fog",
|
8 |
},
|
9 |
}
|
10 |
|
|
|
12 |
|
13 |
def get_current_scene():
|
14 |
scene = story[state["scene"]]
|
15 |
+
return scene["text"], scene["image"], scene["choices"].keys()
|
src/images/image_generator.py
CHANGED
@@ -6,25 +6,47 @@ from io import BytesIO
|
|
6 |
from datetime import datetime
|
7 |
from config import settings
|
8 |
import logging
|
|
|
|
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
11 |
|
12 |
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value()).aio
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
async def generate_image(prompt: str) -> tuple[str, str] | None:
|
15 |
"""
|
16 |
Generate an image using Google's Gemini model and save it to generated/images directory.
|
17 |
-
|
18 |
Args:
|
19 |
prompt (str): The text prompt to generate the image from
|
20 |
-
|
21 |
Returns:
|
22 |
str: Path to the generated image file, or None if generation failed
|
23 |
"""
|
24 |
# Ensure the generated/images directory exists
|
25 |
output_dir = "generated/images"
|
26 |
os.makedirs(output_dir, exist_ok=True)
|
27 |
-
|
28 |
logger.info(f"Generating image with prompt: {prompt}")
|
29 |
|
30 |
try:
|
@@ -32,8 +54,9 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
|
|
32 |
model="gemini-2.0-flash-preview-image-generation",
|
33 |
contents=prompt,
|
34 |
config=types.GenerateContentConfig(
|
35 |
-
response_modalities=[
|
36 |
-
|
|
|
37 |
)
|
38 |
|
39 |
# Process the response parts
|
@@ -44,19 +67,20 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
|
|
44 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
45 |
filename = f"gemini_{timestamp}.png"
|
46 |
filepath = os.path.join(output_dir, filename)
|
47 |
-
|
48 |
# Save the image
|
49 |
image = Image.open(BytesIO(part.inline_data.data))
|
50 |
-
image.save
|
51 |
logger.info(f"Image saved to: {filepath}")
|
52 |
image_saved = True
|
53 |
-
|
54 |
-
return filepath,
|
55 |
-
|
56 |
if not image_saved:
|
|
|
57 |
logger.error("No image was generated in the response.")
|
58 |
return None, None
|
59 |
-
|
60 |
except Exception as e:
|
61 |
logger.error(f"Error generating image: {e}")
|
62 |
return None, None
|
@@ -65,38 +89,41 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
|
|
65 |
async def modify_image(image_path: str, modification_prompt: str) -> str | None:
|
66 |
"""
|
67 |
Modify an existing image using Google's Gemini model based on a text prompt.
|
68 |
-
|
69 |
Args:
|
70 |
image_path (str): Path to the existing image file
|
71 |
modification_prompt (str): The text prompt describing how to modify the image
|
72 |
-
|
73 |
Returns:
|
74 |
str: Path to the modified image file, or None if modification failed
|
75 |
"""
|
76 |
# Ensure the generated/images directory exists
|
77 |
output_dir = "generated/images"
|
78 |
os.makedirs(output_dir, exist_ok=True)
|
79 |
-
|
|
|
|
|
80 |
# Check if the input image exists
|
81 |
if not os.path.exists(image_path):
|
82 |
logger.error(f"Error: Image file not found at {image_path}")
|
83 |
return None
|
84 |
-
|
85 |
key = settings.gemini_api_key.get_secret_value()
|
86 |
-
|
87 |
client = genai.Client(api_key=key).aio
|
88 |
|
89 |
try:
|
90 |
# Load the input image
|
91 |
input_image = Image.open(image_path)
|
92 |
-
|
93 |
# Make the API call with both text and image
|
94 |
response = await client.models.generate_content(
|
95 |
model="gemini-2.0-flash-preview-image-generation",
|
96 |
contents=[modification_prompt, input_image],
|
97 |
config=types.GenerateContentConfig(
|
98 |
-
response_modalities=[
|
99 |
-
|
|
|
100 |
)
|
101 |
|
102 |
# Process the response parts
|
@@ -107,19 +134,20 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
|
|
107 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
108 |
filename = f"gemini_modified_{timestamp}.png"
|
109 |
filepath = os.path.join(output_dir, filename)
|
110 |
-
|
111 |
# Save the modified image
|
112 |
modified_image = Image.open(BytesIO(part.inline_data.data))
|
113 |
-
modified_image.save
|
114 |
logger.info(f"Modified image saved to: {filepath}")
|
115 |
image_saved = True
|
116 |
-
|
117 |
-
return filepath,
|
118 |
-
|
119 |
if not image_saved:
|
|
|
120 |
logger.error("No modified image was generated in the response.")
|
121 |
return None, None
|
122 |
-
|
123 |
except Exception as e:
|
124 |
logger.error(f"Error modifying image: {e}")
|
125 |
return None, None
|
@@ -129,10 +157,10 @@ if __name__ == "__main__":
|
|
129 |
# Example usage
|
130 |
sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
|
131 |
generated_image_path = generate_image(sample_prompt)
|
132 |
-
|
133 |
# if generated_image_path:
|
134 |
# # Example modification
|
135 |
# modification_prompt = "Now the house is destroyed, and the jawas are running away"
|
136 |
# modified_image_path = modify_image(generated_image_path, modification_prompt)
|
137 |
# if modified_image_path:
|
138 |
-
# print(f"Successfully modified image: {modified_image_path}")
|
|
|
6 |
from datetime import datetime
|
7 |
from config import settings
|
8 |
import logging
|
9 |
+
import asyncio
|
10 |
+
import gradio as gr
|
11 |
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
14 |
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value()).aio
|
15 |
|
16 |
+
safety_settings = [
|
17 |
+
types.SafetySetting(
|
18 |
+
category="HARM_CATEGORY_HARASSMENT",
|
19 |
+
threshold="BLOCK_NONE", # Block none
|
20 |
+
),
|
21 |
+
types.SafetySetting(
|
22 |
+
category="HARM_CATEGORY_HATE_SPEECH",
|
23 |
+
threshold="BLOCK_NONE", # Block none
|
24 |
+
),
|
25 |
+
types.SafetySetting(
|
26 |
+
category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
27 |
+
threshold="BLOCK_NONE", # Block none
|
28 |
+
),
|
29 |
+
types.SafetySetting(
|
30 |
+
category="HARM_CATEGORY_DANGEROUS_CONTENT",
|
31 |
+
threshold="BLOCK_NONE", # Block none
|
32 |
+
),
|
33 |
+
]
|
34 |
+
|
35 |
+
|
36 |
async def generate_image(prompt: str) -> tuple[str, str] | None:
|
37 |
"""
|
38 |
Generate an image using Google's Gemini model and save it to generated/images directory.
|
39 |
+
|
40 |
Args:
|
41 |
prompt (str): The text prompt to generate the image from
|
42 |
+
|
43 |
Returns:
|
44 |
str: Path to the generated image file, or None if generation failed
|
45 |
"""
|
46 |
# Ensure the generated/images directory exists
|
47 |
output_dir = "generated/images"
|
48 |
os.makedirs(output_dir, exist_ok=True)
|
49 |
+
|
50 |
logger.info(f"Generating image with prompt: {prompt}")
|
51 |
|
52 |
try:
|
|
|
54 |
model="gemini-2.0-flash-preview-image-generation",
|
55 |
contents=prompt,
|
56 |
config=types.GenerateContentConfig(
|
57 |
+
response_modalities=["TEXT", "IMAGE"],
|
58 |
+
safety_settings=safety_settings,
|
59 |
+
),
|
60 |
)
|
61 |
|
62 |
# Process the response parts
|
|
|
67 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
68 |
filename = f"gemini_{timestamp}.png"
|
69 |
filepath = os.path.join(output_dir, filename)
|
70 |
+
|
71 |
# Save the image
|
72 |
image = Image.open(BytesIO(part.inline_data.data))
|
73 |
+
await asyncio.to_thread(image.save, filepath, "PNG")
|
74 |
logger.info(f"Image saved to: {filepath}")
|
75 |
image_saved = True
|
76 |
+
|
77 |
+
return filepath, prompt
|
78 |
+
|
79 |
if not image_saved:
|
80 |
+
gr.Warning("Image was censored by Google!")
|
81 |
logger.error("No image was generated in the response.")
|
82 |
return None, None
|
83 |
+
|
84 |
except Exception as e:
|
85 |
logger.error(f"Error generating image: {e}")
|
86 |
return None, None
|
|
|
89 |
async def modify_image(image_path: str, modification_prompt: str) -> str | None:
|
90 |
"""
|
91 |
Modify an existing image using Google's Gemini model based on a text prompt.
|
92 |
+
|
93 |
Args:
|
94 |
image_path (str): Path to the existing image file
|
95 |
modification_prompt (str): The text prompt describing how to modify the image
|
96 |
+
|
97 |
Returns:
|
98 |
str: Path to the modified image file, or None if modification failed
|
99 |
"""
|
100 |
# Ensure the generated/images directory exists
|
101 |
output_dir = "generated/images"
|
102 |
os.makedirs(output_dir, exist_ok=True)
|
103 |
+
|
104 |
+
logger.info(f"Modifying current scene image with prompt: {modification_prompt}")
|
105 |
+
|
106 |
# Check if the input image exists
|
107 |
if not os.path.exists(image_path):
|
108 |
logger.error(f"Error: Image file not found at {image_path}")
|
109 |
return None
|
110 |
+
|
111 |
key = settings.gemini_api_key.get_secret_value()
|
112 |
+
|
113 |
client = genai.Client(api_key=key).aio
|
114 |
|
115 |
try:
|
116 |
# Load the input image
|
117 |
input_image = Image.open(image_path)
|
118 |
+
|
119 |
# Make the API call with both text and image
|
120 |
response = await client.models.generate_content(
|
121 |
model="gemini-2.0-flash-preview-image-generation",
|
122 |
contents=[modification_prompt, input_image],
|
123 |
config=types.GenerateContentConfig(
|
124 |
+
response_modalities=["TEXT", "IMAGE"],
|
125 |
+
safety_settings=safety_settings,
|
126 |
+
),
|
127 |
)
|
128 |
|
129 |
# Process the response parts
|
|
|
134 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
135 |
filename = f"gemini_modified_{timestamp}.png"
|
136 |
filepath = os.path.join(output_dir, filename)
|
137 |
+
|
138 |
# Save the modified image
|
139 |
modified_image = Image.open(BytesIO(part.inline_data.data))
|
140 |
+
await asyncio.to_thread(modified_image.save, filepath, "PNG")
|
141 |
logger.info(f"Modified image saved to: {filepath}")
|
142 |
image_saved = True
|
143 |
+
|
144 |
+
return filepath, modification_prompt
|
145 |
+
|
146 |
if not image_saved:
|
147 |
+
gr.Warning("Updated image was censored by Google!")
|
148 |
logger.error("No modified image was generated in the response.")
|
149 |
return None, None
|
150 |
+
|
151 |
except Exception as e:
|
152 |
logger.error(f"Error modifying image: {e}")
|
153 |
return None, None
|
|
|
157 |
# Example usage
|
158 |
sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
|
159 |
generated_image_path = generate_image(sample_prompt)
|
160 |
+
|
161 |
# if generated_image_path:
|
162 |
# # Example modification
|
163 |
# modification_prompt = "Now the house is destroyed, and the jawas are running away"
|
164 |
# modified_image_path = modify_image(generated_image_path, modification_prompt)
|
165 |
# if modified_image_path:
|
166 |
+
# print(f"Successfully modified image: {modified_image_path}")
|
src/main.py
CHANGED
@@ -7,7 +7,7 @@ from audio.audio_generator import (
|
|
7 |
)
|
8 |
import logging
|
9 |
from agent.llm_agent import process_user_input
|
10 |
-
from images.image_generator import
|
11 |
import uuid
|
12 |
from game_state import story, state
|
13 |
from game_constructor import (
|
@@ -19,6 +19,8 @@ from game_constructor import (
|
|
19 |
start_game_with_settings,
|
20 |
)
|
21 |
import asyncio
|
|
|
|
|
22 |
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
@@ -43,29 +45,53 @@ async def update_scene(user_hash: str, choice):
|
|
43 |
}
|
44 |
state["scene"] = new_scene
|
45 |
|
46 |
-
user_story =
|
47 |
-
|
48 |
-
|
49 |
-
"""
|
50 |
|
51 |
-
response = await
|
|
|
|
|
52 |
|
53 |
story[new_scene]["text"] = response.game_message
|
54 |
|
55 |
-
story[new_scene]["choices"] =
|
56 |
-
option.option_description
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
# run both tasks in parallel
|
60 |
img_res, _ = await asyncio.gather(
|
61 |
-
|
62 |
-
change_music_tone(user_hash, response.change_music.music_description) if response.change_music.change_music else asyncio.sleep(0)
|
63 |
)
|
64 |
-
|
65 |
if img_res and response.change_scene.change_scene:
|
66 |
-
img_path,
|
67 |
if img_path:
|
68 |
story[new_scene]["image"] = img_path
|
|
|
69 |
|
70 |
scene = story[state["scene"]]
|
71 |
return (
|
@@ -136,7 +162,7 @@ with gr.Blocks(
|
|
136 |
# Fullscreen Loading Indicator (hidden by default)
|
137 |
with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
|
138 |
gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
|
139 |
-
|
140 |
local_storage = gr.BrowserState(str(uuid.uuid4()), "user_hash")
|
141 |
|
142 |
# Constructor Interface (visible by default)
|
|
|
7 |
)
|
8 |
import logging
|
9 |
from agent.llm_agent import process_user_input
|
10 |
+
from images.image_generator import modify_image
|
11 |
import uuid
|
12 |
from game_state import story, state
|
13 |
from game_constructor import (
|
|
|
19 |
start_game_with_settings,
|
20 |
)
|
21 |
import asyncio
|
22 |
+
from game_setting import get_user_story
|
23 |
+
from config import settings
|
24 |
|
25 |
logger = logging.getLogger(__name__)
|
26 |
|
|
|
45 |
}
|
46 |
state["scene"] = new_scene
|
47 |
|
48 |
+
user_story = get_user_story(
|
49 |
+
story[old_scene]["text"], story[old_scene]["img_description"], choice
|
50 |
+
)
|
|
|
51 |
|
52 |
+
response = await (
|
53 |
+
story[old_scene]["choices"][choice] or process_user_input(user_story)
|
54 |
+
)
|
55 |
|
56 |
story[new_scene]["text"] = response.game_message
|
57 |
|
58 |
+
story[new_scene]["choices"] = {
|
59 |
+
option.option_description: asyncio.create_task(
|
60 |
+
process_user_input(
|
61 |
+
get_user_story(
|
62 |
+
response.game_message,
|
63 |
+
response.change_scene.scene_description,
|
64 |
+
option.option_description,
|
65 |
+
)
|
66 |
+
)
|
67 |
+
)
|
68 |
+
if settings.pregenerate_next_scene
|
69 |
+
else None
|
70 |
+
for option in response.player_options
|
71 |
+
}
|
72 |
+
|
73 |
+
img_task = None
|
74 |
+
# always modify the image to avoid hallucinations in which image is being generated in entirely different style
|
75 |
+
if (
|
76 |
+
response.change_scene.change_scene == "change_completely"
|
77 |
+
or response.change_scene.change_scene == "modify"
|
78 |
+
):
|
79 |
+
img_task = modify_image(
|
80 |
+
story[old_scene]["image"], response.change_scene.scene_description
|
81 |
+
)
|
82 |
+
else:
|
83 |
+
img_task = asyncio.sleep(0)
|
84 |
+
|
85 |
# run both tasks in parallel
|
86 |
img_res, _ = await asyncio.gather(
|
87 |
+
img_task, change_music_tone(user_hash, response.music_prompt)
|
|
|
88 |
)
|
89 |
+
|
90 |
if img_res and response.change_scene.change_scene:
|
91 |
+
img_path, img_description = img_res
|
92 |
if img_path:
|
93 |
story[new_scene]["image"] = img_path
|
94 |
+
story[new_scene]["img_description"] = img_description
|
95 |
|
96 |
scene = story[state["scene"]]
|
97 |
return (
|
|
|
162 |
# Fullscreen Loading Indicator (hidden by default)
|
163 |
with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
|
164 |
gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
|
165 |
+
|
166 |
local_storage = gr.BrowserState(str(uuid.uuid4()), "user_hash")
|
167 |
|
168 |
# Constructor Interface (visible by default)
|