# (Hugging Face Spaces page-header residue removed: "Spaces / Sleeping / Sleeping")
import asyncio
import os
import re
import textwrap
from io import BytesIO

import gradio as gr
import httpx
from PIL import Image

from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from llama_index.llms.groq import Groq
# --- Secret Management ---
# Credentials come from environment variables (Hugging Face Space Secrets).
# The Space secret is historically named "Groq_Token", but the startup error
# message refers to GROQ_API_KEY -- accept either name so both setups work.
GROQ_API_KEY = os.environ.get("Groq_Token") or os.environ.get("GROQ_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")  # Hugging Face Inference API token
# --- Event and Workflow Definitions (Copied from our notebook) ---
class StoryContext(Event):
    # One renderable turn of the story, emitted by generate_story_part.
    story_part: str  # narrative text plus the formatted "Choices:" section
    inventory: list[str]  # items collected so far
    is_new_scene: bool  # True when a fresh illustration should be generated
class SceneReadyEvent(Event):
    # Marker event with no payload; not referenced elsewhere in this file
    # (leftover from the notebook's display steps).
    pass
class UserChoice(Event):
    # The player's selected action, fed back into generate_story_part.
    choice: str  # raw text the player typed
class StoryEnd(Event):
    # Signals game over; final_message is shown to the player.
    final_message: str
# Helper function to generate an image and return its path
async def generate_image(prompt: str, hf_token: str) -> str | None:
    """Render *prompt* with SDXL via the Hugging Face Inference API.

    Returns the path of the saved PNG ("scene_image.png") on success, or
    None if the request fails, times out, returns an error status, or the
    response body is not a decodable image.
    """
    API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
    headers = {"Authorization": f"Bearer {hf_token}"}
    # Style prefix steers the model toward consistent fantasy artwork.
    full_prompt = f"epic fantasy art, digital painting, cinematic lighting, masterpiece, {prompt}"
    payload = {"inputs": full_prompt}
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(API_URL, headers=headers, json=payload, timeout=180.0)
            response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Save to a temporary file that Gradio can serve
        image.save("scene_image.png")
        return "scene_image.png"
    # FIX: raise_for_status() raises httpx.HTTPStatusError, which is NOT a
    # httpx.RequestError, so the original except clause let 4xx/5xx responses
    # escape uncaught. httpx.HTTPError is the common parent of both
    # RequestError (incl. TimeoutException) and HTTPStatusError.
    except (httpx.HTTPError, IOError) as e:
        print(f"Image generation failed: {e}")
        return None
# The full workflow class, now with the required end_story step
class StorytellerWorkflow(Workflow):
    """One-LLM-call-per-turn text-adventure workflow.

    The story step consumes a StartEvent (new game) or a UserChoice
    (continuation) and emits either a StoryContext (next scene for the UI)
    or a StoryEnd (game over), which end_story converts into a StopEvent.
    """

    def __init__(self, **kwargs):
        # Generous timeout: each turn includes a remote LLM call.
        super().__init__(timeout=300, **kwargs)

    # FIX: the @step decorator was missing on both methods below; without it
    # the llama_index workflow validator does not register them as steps
    # (compare the commented-out "@step async def ..." examples after the class).
    @step
    async def generate_story_part(self, ev: StartEvent | UserChoice, ctx: Context) -> StoryContext | StoryEnd:
        """Advance the story one turn and persist state in the context store.

        Parses the LLM response for the control tags [NEW_SCENE],
        [ADD_ITEM: ...] and [END] before formatting the displayable text.
        """
        inventory = await ctx.store.get("inventory", [])
        prompt = ""
        is_new_scene_flag = False
        if isinstance(ev, StartEvent):
            # Opening turn always gets an illustration.
            is_new_scene_flag = True
            prompt = """
            You are a creative text adventure game master. Your output is for a console game.
            Start a new story about a curious explorer entering a recently discovered, glowing cave.
            Keep the tone mysterious and exciting. After the story part, provide two distinct choices for the player to make.
            Format your response exactly like this: STORY: [The story text goes here] CHOICES: 1. [First choice] 2. [Second choice]
            """
        elif isinstance(ev, UserChoice):
            last_story_part = await ctx.store.get("last_story_part")
            prompt = f"""
            You are a creative text adventure game master.
            The story so far: "{last_story_part}"
            The player chose: "{ev.choice}"
            The player's inventory: {inventory}
            Continue the story. IMPORTANT: If the story describes moving to a new location, add the tag `[NEW_SCENE]` to the beginning of the STORY section. Otherwise, do not add the tag.
            If a choice results in an item, use `[ADD_ITEM: item name]`. If the story should end, write "[END]".
            Format your response exactly like this: STORY: [The story text goes here] CHOICES: 1. [First choice] 2. [Second choice]
            """
        llm = Groq(model="llama3-8b-8192", api_key=GROQ_API_KEY)
        response = await llm.acomplete(prompt)
        response_text = str(response)
        # [NEW_SCENE] => regenerate the illustration for this turn.
        if "[NEW_SCENE]" in response_text:
            is_new_scene_flag = True
            response_text = response_text.replace("[NEW_SCENE]", "").strip()
        # Collect new items (de-duplicated), then strip the tags from the text.
        items_found = re.findall(r"\[ADD_ITEM: (.*?)\]", response_text)
        if items_found:
            for item in items_found:
                if item not in inventory:
                    inventory.append(item)
            response_text = re.sub(r"\[ADD_ITEM: (.*?)\]", "", response_text).strip()
        # A leading [END] terminates the game.
        if response_text.strip().startswith("[END]"):
            final_message = response_text.strip().replace("[END]", "")
            return StoryEnd(final_message=f"\n--- THE END ---\n{final_message}")
        try:
            story_section = response_text.split("STORY:")[1].split("CHOICES:")[0].strip()
            choices_section = response_text.split("CHOICES:")[1].strip()
            full_story_part = f"{story_section}\n\nChoices:\n{choices_section}"
        except IndexError:
            # The model ignored the required format; keep the game alive.
            full_story_part = "The story continues... but the path is blurry."
        await ctx.store.set("last_story_part", full_story_part)
        await ctx.store.set("inventory", inventory)
        return StoryContext(story_part=full_story_part, inventory=inventory, is_new_scene=is_new_scene_flag)

    @step
    def end_story(self, ev: StoryEnd) -> StopEvent:
        """Terminal step: satisfies the workflow validator by providing a
        path to a StopEvent. Kept synchronous because run_turn() also calls
        it directly."""
        return StopEvent(result=ev.final_message)

# These two steps are no longer needed, as Gradio's UI will handle the display logic.
# @step async def display_scene(...)
# @step async def get_user_choice(...)
# --- Gradio UI and Application Logic ---
async def run_turn(user_input, game_state):
    """Run one game turn for the Gradio UI.

    Args:
        user_input: the player's typed choice (ignored on the first turn).
        game_state: the StorytellerWorkflow instance, or None before the
            first turn. NOTE(review): create_demo() seeds this state with a
            tuple, not None, so the first call takes the UserChoice branch --
            verify the initial state value.

    Returns:
        (image_path, story_text, inventory_markdown, new_game_state);
        new_game_state is None when the story ends (resets the session).
    """
    # game_state holds our workflow instance
    workflow = game_state
    # On the first turn, initialize the workflow
    if workflow is None:
        workflow = StorytellerWorkflow()
        event = StartEvent()
    else:
        # For subsequent turns, create a UserChoice event
        event = UserChoice(choice=user_input)
    # We use stream() to run the workflow until the next input is needed.
    # NOTE(review): Workflow.stream(event) and `ev.data` do not match the
    # documented llama_index API (workflow.run(...) returns a handler with
    # .stream_events()); confirm against the installed llama_index version.
    final_event = None
    async for ev in workflow.stream(event):
        # The loop will run through generate_story_part and stop when it produces
        # either a StoryEnd event or a StoryContext event, which triggers the next step.
        # Since the next step (which we removed) would require user input, the stream pauses.
        final_event = ev.data
    # Process the result from the workflow stream
    if isinstance(final_event, StoryEnd):
        # We need to manually call the end_story step to get the final message
        stop_event = workflow.end_story(final_event)
        return None, stop_event.result, "", None  # Reset the state
    if isinstance(final_event, StoryContext):
        # Split off the narrative (before "Choices:") for wrapping and for
        # the image prompt.
        narrative, choices = final_event.story_part.split("Choices:", 1)
        story_display = f"{textwrap.fill(narrative, width=80)}\n\nChoices:{choices}"
        image_path = None
        if final_event.is_new_scene and HF_TOKEN:
            image_path = await generate_image(narrative, HF_TOKEN)
        inventory_text = f"**Inventory:** {', '.join(final_event.inventory) if final_event.inventory else 'Empty'}"
        return image_path, story_display, inventory_text, workflow  # Pass the workflow instance to the next turn
    # Graceful fallback
    return None, "An unexpected error occurred. The story cannot continue.", "", None
def create_demo():
    """Build and return the Gradio Blocks UI for the storyteller game."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # State object to hold the workflow instance between turns.
        # FIX: this was gr.State((None, None, "**Inventory:** Empty")) -- a
        # tuple, so run_turn's `if workflow is None` check failed on the very
        # first call and the tuple was treated as a workflow instance.
        # run_turn expects None before the first turn.
        game_state = gr.State(None)
        gr.Markdown("# LlamaIndex Workflow: Dynamic Storyteller")
        gr.Markdown("An AI-powered text adventure game where every scene can be illustrated by AI.")
        with gr.Row():
            with gr.Column(scale=1):
                image_display = gr.Image(label="Scene", interactive=False)
                inventory_display = gr.Markdown("**Inventory:** Empty")
            with gr.Column(scale=2):
                story_display = gr.Textbox(label="Story", lines=15, interactive=False)
                user_input = gr.Textbox(label="What do you do?", placeholder="Type your choice and press Enter...")
        # When the user submits their choice
        user_input.submit(
            fn=run_turn,
            inputs=[user_input, game_state],
            outputs=[image_display, story_display, inventory_display, game_state]
        )
        # When the app loads for the first time, run the opening turn.
        demo.load(
            fn=run_turn,
            inputs=[gr.State(None), game_state],  # Pass empty input and state
            outputs=[image_display, story_display, inventory_display, game_state]
        )
    return demo
if __name__ == "__main__":
    # Refuse to start without credentials: every turn needs the Groq LLM, and
    # scene illustration needs the HF Inference API.
    if not GROQ_API_KEY or not HF_TOKEN:
        # FIX: the old message told users to set GROQ_API_KEY, but the code
        # reads the secret named "Groq_Token" -- name the actual secrets.
        print('ERROR: API keys not found. Set the "Groq_Token" and "HF_TOKEN" secrets in your Hugging Face Space settings.')
    else:
        app = create_demo()
        app.launch()