Spaces:
Sleeping
Sleeping
# Copyright 2025 Google LLC. | |
# Based on work by Yousif Ahmed. | |
# Concept: ChronoWeave — Branching Narrative Generation
# Licensed under the Apache License, Version 2.0 (the "License"). | |
# You may not use this file except in compliance with the License. | |
# Obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 | |
import os | |
import json | |
import time | |
import uuid | |
import asyncio | |
import logging | |
import shutil | |
import contextlib | |
import wave | |
from io import BytesIO | |
from typing import List, Optional, Tuple, Dict, Any | |
import streamlit as st | |
import numpy as np | |
from PIL import Image | |
# Pydantic for data validation | |
from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator | |
# Video and audio processing | |
from moviepy.editor import ImageClip, AudioFileClip, concatenate_videoclips | |
# Google Generative AI library and async patch | |
import google.generativeai as genai | |
import nest_asyncio | |
nest_asyncio.apply() # Ensure asyncio works correctly in Streamlit/Jupyter | |
# --- Logging Setup ---
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Constants & Configurations ---
# Generative model identifiers. Text and audio currently share one model.
_GEMINI_FLASH = "models/gemini-1.5-flash"
TEXT_MODEL_ID = _GEMINI_FLASH  # "gemini-1.5-pro" is a drop-in alternative
AUDIO_MODEL_ID = _GEMINI_FLASH
IMAGE_MODEL_ID = "imagen-3"  # NOTE: not wired up yet; needs the Vertex AI SDK

# Media encoding parameters.
AUDIO_SAMPLING_RATE = 24000  # Hz; used as the WAV frame rate
VIDEO_FPS = 24
VIDEO_CODEC = "libx264"
AUDIO_CODEC = "aac"
DEFAULT_ASPECT_RATIO = "1:1"

# Scratch space; each run creates its own subdirectory beneath this path.
TEMP_DIR_BASE = ".chrono_temp"
# --- Pydantic Schemas --- | |
# Words suggesting an image prompt depicts people. They are matched as whole
# words so that e.g. "many", "command" or "mango" do not trip the "man" check.
_HUMAN_WORDS = frozenset(
    {
        "person", "persons", "people",
        "human", "humans", "humanoid", "humanoids",
        "man", "men", "woman", "women",
        "boy", "boys", "girl", "girls",
        "child", "children",
    }
)


class StorySegment(BaseModel):
    """One scene of a timeline as returned by the text model.

    Length limits are enforced by pydantic. ``image_prompt`` is additionally
    screened for human-related words; a hit only logs a warning.
    """

    scene_id: int = Field(..., ge=0)
    image_prompt: str = Field(..., min_length=10, max_length=250)
    audio_text: str = Field(..., min_length=5, max_length=150)
    character_description: str = Field(..., max_length=250)
    timeline_visual_modifier: Optional[str] = Field(None, max_length=50)

    # Without these decorators pydantic never calls this method.
    @field_validator("image_prompt")
    @classmethod
    def image_prompt_no_humans(cls, v: str) -> str:
        """Warn (without rejecting) when ``image_prompt`` mentions humans.

        The story prompt asks for no human characters, so a hit usually means
        the model ignored that instruction.

        Returns:
            ``v`` unchanged.
        """
        # Replace every non-letter with a space, then split into words.
        words = "".join(ch if ch.isalpha() else " " for ch in v.lower()).split()
        if _HUMAN_WORDS.intersection(words):
            logger.warning(f"Image prompt '{v[:50]}...' may include human-related descriptions.")
        return v
class Timeline(BaseModel):
    """One branch of the story: why it diverges and its list of scenes."""

    timeline_id: int = Field(..., ge=0)
    divergence_reason: str = Field(..., min_length=5)
    # ``min_items`` is deprecated in pydantic v2; ``min_length`` is the
    # supported spelling of the list-length constraint.
    segments: List[StorySegment] = Field(..., min_length=1)
class ChronoWeaveResponse(BaseModel):
    """Top-level JSON object expected from the story-structure prompt.

    After field validation, every timeline must contain exactly
    ``total_scenes_per_timeline`` segments.
    """

    core_theme: str = Field(..., min_length=5)
    # ``min_length`` replaces the pydantic-v1 ``min_items`` spelling.
    timelines: List[Timeline] = Field(..., min_length=1)
    total_scenes_per_timeline: int = Field(..., gt=0)

    # Without this decorator pydantic never runs the cross-field check.
    @model_validator(mode="after")
    def check_timeline_segment_count(self) -> "ChronoWeaveResponse":
        """Reject responses whose timelines have the wrong number of scenes.

        Raises:
            ValueError: pydantic surfaces it to callers as ``ValidationError``.
        """
        expected = self.total_scenes_per_timeline
        for i, timeline in enumerate(self.timelines):
            if len(timeline.segments) != expected:
                raise ValueError(
                    f"Timeline {i} (ID: {timeline.timeline_id}): Expected {expected} segments, got {len(timeline.segments)}."
                )
        return self
# --- Helper Functions --- | |
@contextlib.contextmanager
def wave_file_writer(filename: str, channels: int = 1, rate: int = AUDIO_SAMPLING_RATE, sample_width: int = 2):
    """Open ``filename`` as a WAV file for writing and yield the writer.

    Use as ``with wave_file_writer(path) as wf: wf.writeframes(...)``. The
    ``contextlib.contextmanager`` decorator is what makes the ``with`` form
    work; without it the function is a bare generator.

    Args:
        filename: Destination path.
        channels: Number of audio channels (1 = mono).
        rate: Frame rate in Hz.
        sample_width: Bytes per sample (2 = 16-bit).

    Raises:
        Any exception from opening or configuring the file, or raised inside
        the ``with`` body, is logged and re-raised.
    """
    wf = None
    try:
        wf = wave.open(filename, "wb")
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(rate)
        yield wf
    except Exception as exc:
        logger.error(f"Error writing wave file {filename}: {exc}")
        raise
    finally:
        if wf is not None:
            try:
                wf.close()
            except Exception as e_close:
                # Close failures are logged rather than raised, so they never
                # mask an exception that is already propagating.
                logger.error(f"Error closing wave file {filename}: {e_close}")
# --- ChronoWeave Generator Class --- | |
class ChronoWeaveGenerator:
    """Generate branching narratives and render them as narrated videos.

    Holds ``google.generativeai`` model handles for text, audio and images.
    Image generation is currently a placeholder that always returns ``None``,
    so scene clips (and therefore timeline videos) are produced only once it
    is implemented. Scene assets are written into a caller-supplied temporary
    directory and concatenated with moviepy into one MP4 per timeline.
    """

    def __init__(self, api_key: str):
        """Configure the SDK with ``api_key`` and create the model handles.

        Raises:
            Exception: whatever the SDK raises while creating a model handle;
                it is logged and re-raised with its original traceback.
        """
        self.api_key = api_key
        genai.configure(api_key=self.api_key)
        try:
            self.client_text = genai.GenerativeModel(TEXT_MODEL_ID)
            logger.info(f"Initialized text model: {TEXT_MODEL_ID}")
            self.client_audio = genai.GenerativeModel(AUDIO_MODEL_ID)
            logger.info(f"Initialized audio model: {AUDIO_MODEL_ID}")
            self.client_image = genai.GenerativeModel(IMAGE_MODEL_ID)
            logger.info(f"Initialized image model: {IMAGE_MODEL_ID} (Placeholder: Update to Vertex AI SDK)")
        except Exception:
            logger.exception("Failed to initialize Google Clients/Models.")
            raise  # bare raise keeps the original traceback

    def generate_story_structure(
        self, theme: str, num_scenes: int, num_timelines: int, divergence_prompt: str = ""
    ) -> Optional[ChronoWeaveResponse]:
        """Ask the text model for a JSON story structure and validate it.

        Args:
            theme: Story theme entered by the user.
            num_scenes: Scenes requested per timeline.
            num_timelines: Number of timelines requested.
            divergence_prompt: Optional hint for how timelines diverge.

        Returns:
            The validated ``ChronoWeaveResponse``, or ``None`` when the model
            call, JSON parsing, or schema validation fails. Failures are
            logged and reported in the Streamlit UI.
        """
        st.info(f"Generating {num_timelines} timeline(s) with {num_scenes} scene(s) for theme: '{theme}'")
        logger.info(f"Story generation request: Theme='{theme}', Timelines={num_timelines}, Scenes={num_scenes}")

        divergence_instruction = "Introduce clear divergence after the first scene. "
        if divergence_prompt.strip():
            # Only mention a hint when the user actually provided one.
            divergence_instruction += f"Hint: '{divergence_prompt}'. "
        divergence_instruction += "For timeline_id 0, use 'Initial path' or 'Baseline scenario'."

        schema_json = json.dumps(ChronoWeaveResponse.model_json_schema(), indent=2)
        prompt = "\n".join(
            [
                f'Act as a narrative designer. Create a story for the theme: "{theme}".',
                "Instructions:",
                f"1. Exactly **{num_timelines}** timelines.",
                f"2. Each timeline must consist of exactly **{num_scenes}** scenes.",
                "3. **NO humans/humanoids**; focus on animals, fantasy creatures, animated objects, and nature.",
                f"4. {divergence_instruction}",
                "5. Style: **'Simple, friendly kids animation, bright colors, rounded shapes'** "
                "unless modified by `timeline_visual_modifier`.",
                # Character caps mirror StorySegment's max_length constraints; a
                # 30-word sentence alone can exceed audio_text's 150-char limit.
                "6. `audio_text`: One concise sentence (max 30 words and under 150 characters).",
                "7. `image_prompt`: Descriptive prompt (15–35 words, under 250 characters) "
                "emphasizing scene elements. **Avoid repeating general style.**",
                "8. `character_description`: Very brief (name and features; < 20 words).",
                f"9. Set `total_scenes_per_timeline` to {num_scenes}.",
                "Output only a valid JSON object conforming exactly to this schema:",
                "JSON Schema: ```json",
                schema_json,
                "```",
            ]
        )

        response = None
        raw_data: Any = None
        try:
            response = self.client_text.generate_content(
                contents=prompt,
                generation_config=genai.types.GenerationConfig(
                    response_mime_type="application/json", temperature=0.7
                ),
            )
            raw_data = json.loads(response.text)
            validated_data = ChronoWeaveResponse.model_validate(raw_data)
        except json.JSONDecodeError as json_err:
            logger.error(f"JSON decode failed: {json_err}\nResponse: {response.text}")
            st.error(f"🚨 JSON Parsing Error: {json_err}", icon="📄")
            st.text_area("Response", response.text, height=150)
        except ValidationError as val_err:
            logger.error(f"Pydantic validation error: {val_err}\nData: {json.dumps(raw_data, indent=2)}")
            st.error(f"🚨 Invalid story structure: {val_err}", icon="🧬")
            st.json(raw_data)
        except Exception as e:
            logger.exception("Story generation error:")
            st.error(f"🚨 Error generating story: {e}", icon="💥")
        else:
            got_timelines = len(validated_data.timelines)
            got_scenes = validated_data.total_scenes_per_timeline
            if got_timelines != num_timelines or got_scenes != num_scenes:
                # Still usable, but not what the user asked for.
                logger.warning(
                    f"Model returned {got_timelines} timeline(s) x {got_scenes} scene(s); "
                    f"requested {num_timelines} x {num_scenes}."
                )
            st.success("Story structure validated successfully!")
            return validated_data
        return None

    async def generate_audio(self, text: str, output_filename: str, voice: Optional[str] = None) -> Optional[str]:
        """Narrate ``text`` and save the returned bytes as a WAV file.

        The blocking SDK call runs in a worker thread via ``asyncio.to_thread``
        so scenes can be processed concurrently. The file is written with
        ``wave_file_writer`` defaults (mono, 2-byte samples,
        ``AUDIO_SAMPLING_RATE``); this assumes the response bytes are raw
        LINEAR16 PCM at that rate (TODO confirm).

        Args:
            text: Sentence to narrate.
            output_filename: Destination ``.wav`` path.
            voice: Currently unused; kept for interface compatibility.

        Returns:
            ``output_filename`` on success, otherwise ``None`` (the failure is
            logged and shown in the UI).
        """
        task_id = os.path.splitext(os.path.basename(output_filename))[0]
        logger.info(f"🎙️ [{task_id}] Generating audio for text: '{text[:60]}...'")

        def sync_generate_audio():
            # NOTE(review): `audio_config`, the octet-stream MIME type and the
            # `audio_chunk` attribute read below are not, to this reviewer's
            # knowledge, part of the documented google.generativeai
            # GenerationConfig/response API — verify. Any SDK error lands in
            # the except branch below.
            return self.client_audio.generate_content(
                contents=f"Narrate directly: \"{text}\"",
                generation_config=genai.types.GenerationConfig(
                    response_mime_type="application/octet-stream",
                    temperature=0.7,
                    audio_config={"audio_encoding": "LINEAR16", "sample_rate_hertz": AUDIO_SAMPLING_RATE},
                ),
            )

        try:
            response = await asyncio.to_thread(sync_generate_audio)
            audio_chunk = getattr(response, "audio_chunk", None) if response else None
            audio_data = getattr(audio_chunk, "data", None)
            if not audio_data:
                logger.error(f"❌ [{task_id}] No audio data returned.")
                st.error(f"Audio generation failed for {task_id}: No audio data.", icon="🔇")
                return None
            with wave_file_writer(output_filename) as wf:
                wf.writeframes(audio_data)
            logger.info(f"✅ [{task_id}] Audio saved: {os.path.basename(output_filename)} ({len(audio_data)} bytes)")
            return output_filename
        except Exception as e:
            logger.exception(f"❌ [{task_id}] Audio generation error: {e}")
            st.error(f"Audio generation failed for {task_id}: {e}", icon="🔇")
            return None

    async def generate_image_async(self, prompt: str, aspect_ratio: str, task_id: str) -> Optional[Image.Image]:
        """Placeholder for image generation; always returns ``None``.

        Logs the request, reports that the feature is unimplemented, and
        returns ``None``. Replace once the Vertex AI SDK is integrated.
        """
        logger.info(f"🖼️ [{task_id}] Requesting image for prompt: '{prompt[:70]}...' (Aspect Ratio: {aspect_ratio})")
        logger.error(f"❌ [{task_id}] Image generation not implemented. Update required for Vertex AI.")
        st.error(f"Image generation for {task_id} skipped: Requires Vertex AI SDK implementation.", icon="🖼️")
        return None

    async def process_scene(
        self,
        timeline_id: int,
        segment: StorySegment,
        temp_dir: str,
        aspect_ratio: str,
        audio_voice: Optional[str] = None,
    ) -> Tuple[Optional[str], Optional[str], Optional[Any], List[str]]:
        """Generate one scene's image and narration, then build its clip.

        Image and audio are requested concurrently. A video clip is built only
        if both succeed. The returned clip keeps its ``AudioFileClip`` open
        because it is needed when the final video is rendered; the caller is
        responsible for closing it (``process_timeline`` does).

        Returns:
            ``(image_path, audio_path, video_clip, errors)``. Each path is
            ``None`` if that file does not exist, and ``video_clip`` is
            ``None`` unless a clip was built.
        """
        errors: List[str] = []
        scene_label = f"Scene {segment.scene_id + 1}"
        task_id = f"T{timeline_id}_S{segment.scene_id}"
        image_path = os.path.join(temp_dir, f"{task_id}_image.png")
        audio_path = os.path.join(temp_dir, f"{task_id}_audio.wav")
        video_clip = None

        image_prompt = f"{segment.image_prompt} Featuring: {segment.character_description} "
        if segment.timeline_visual_modifier:
            image_prompt += f"Style hint: {segment.timeline_visual_modifier}"

        image_result, audio_result = await asyncio.gather(
            self.generate_image_async(prompt=image_prompt, aspect_ratio=aspect_ratio, task_id=task_id),
            self.generate_audio(segment.audio_text, audio_path, audio_voice),
        )

        if image_result is None:
            errors.append(f"{scene_label}: Image generation failed.")
        else:
            try:
                image_result.save(image_path)
                st.image(image_result, width=180, caption=scene_label)
            except Exception as e:
                logger.error(f"❌ [{task_id}] Error saving image: {e}")
                errors.append(f"{scene_label}: Image save error.")

        if audio_result is None:
            errors.append(f"{scene_label}: Audio generation failed.")
        else:
            try:
                with open(audio_result, "rb") as ap:
                    st.audio(ap.read(), format="audio/wav")
            except Exception as e:
                logger.warning(f"⚠️ [{task_id}] Audio preview error: {e}")

        if not errors and os.path.exists(image_path) and os.path.exists(audio_path):
            audio_clip = None
            try:
                audio_clip = AudioFileClip(audio_path)
                # Context manager closes the PIL file handle once pixels are copied.
                with Image.open(image_path) as img:
                    frame = np.array(img)
                video_clip = ImageClip(frame).set_duration(audio_clip.duration).set_audio(audio_clip)
                logger.info(f"✅ [{task_id}] Video clip created (Duration: {audio_clip.duration:.2f}s).")
            except Exception as e:
                logger.exception(f"❌ [{task_id}] Failed to create video clip: {e}")
                errors.append(f"{scene_label}: Video clip creation failed.")
                video_clip = None
                # Release the audio reader only on failure. On success the
                # returned clip still reads from it during final rendering.
                if audio_clip is not None:
                    with contextlib.suppress(Exception):
                        audio_clip.close()

        return (
            image_path if os.path.exists(image_path) else None,
            audio_path if os.path.exists(audio_path) else None,
            video_clip,
            errors,
        )

    @staticmethod
    def _close_clips(clips: List[Any]) -> None:
        """Best-effort close of each clip and its attached audio clip."""
        for clip in clips:
            for resource in (getattr(clip, "audio", None), clip):
                if resource is not None:
                    with contextlib.suppress(Exception):
                        resource.close()

    async def process_timeline(
        self,
        timeline: Timeline,
        temp_dir: str,
        aspect_ratio: str,
        audio_voice: Optional[str] = None,
    ) -> Tuple[Optional[str], List[str]]:
        """Process all scenes of ``timeline`` concurrently and stitch one video.

        A video is assembled only if every segment produced a clip. All scene
        clips (and their audio readers) are closed before returning, whether
        or not assembly succeeded.

        Returns:
            ``(final_video_path or None, error messages)``.
        """
        timeline_id = timeline.timeline_id
        results = await asyncio.gather(
            *(
                self.process_scene(timeline_id, segment, temp_dir, aspect_ratio, audio_voice)
                for segment in timeline.segments
            )
        )

        timeline_errors: List[str] = []
        video_clips = []
        for _image_path, _audio_path, clip, errs in results:
            timeline_errors.extend(errs)
            if clip is not None:
                video_clips.append(clip)

        if not video_clips or len(video_clips) != len(timeline.segments):
            self._close_clips(video_clips)
            timeline_errors.append(f"Timeline {timeline_id}: Incomplete scenes; skipping video assembly.")
            return None, timeline_errors

        output_filename = os.path.join(temp_dir, f"timeline_{timeline_id}_final.mp4")
        final_video = None
        try:
            final_video = concatenate_videoclips(video_clips, method="compose")
            final_video.write_videofile(
                output_filename, fps=VIDEO_FPS, codec=VIDEO_CODEC, audio_codec=AUDIO_CODEC, logger=None
            )
            logger.info(f"✅ Timeline {timeline_id} video saved: {output_filename}")
            return output_filename, timeline_errors
        except Exception as e:
            logger.exception(f"❌ Timeline {timeline_id} video assembly failed: {e}")
            timeline_errors.append(f"Timeline {timeline_id}: Video assembly failed.")
            return None, timeline_errors
        finally:
            # Release ffmpeg readers on both the success and failure paths.
            if final_video is not None:
                with contextlib.suppress(Exception):
                    final_video.close()
            self._close_clips(video_clips)
# --- Streamlit UI and Main Process --- | |
def _load_api_key() -> Optional[str]:
    """Return the Google API key from Streamlit secrets, else the environment.

    ``st.secrets`` is tried first. Any failure there falls back to the
    ``GOOGLE_API_KEY`` environment variable. Returns ``None`` if neither
    source provides a key.
    """
    try:
        api_key = st.secrets["GOOGLE_API_KEY"]
    except Exception as exc:
        # KeyError for a missing entry. Streamlit may also raise when no
        # secrets file exists at all; the exact type varies by version.
        logger.debug(f"Streamlit secrets unavailable ({exc!r}); trying environment.")
    else:
        logger.info("Google API Key loaded from Streamlit secrets.")
        return api_key
    api_key = os.environ.get("GOOGLE_API_KEY")
    if api_key:
        logger.info("Google API Key loaded from environment variable.")
    return api_key


def _render_timeline_videos(completed: List[Tuple["Timeline", str, List[str]]]) -> None:
    """Show each finished timeline video, its download button and scene issues.

    Videos are laid out in at most three columns.
    """
    num_cols = min(len(completed), 3)
    cols = st.columns(num_cols)
    for idx, (timeline, video_path, scene_errs) in enumerate(completed):
        timeline_id = timeline.timeline_id
        with cols[idx % num_cols]:
            st.subheader(f"Timeline {timeline_id}")
            st.caption(f"Divergence: {timeline.divergence_reason}")
            try:
                with open(video_path, "rb") as vf:
                    video_bytes = vf.read()
                st.video(video_bytes)
                st.download_button(
                    f"Download Timeline {timeline_id}",
                    video_bytes,
                    file_name=f"timeline_{timeline_id}.mp4",
                    mime="video/mp4",
                    # Position-based key: widget keys must stay unique even if
                    # the model repeated a timeline_id.
                    key=f"dl_{idx}",
                )
                if scene_errs:
                    with st.expander(f"⚠️ View Scene Issues ({len(scene_errs)})"):
                        for err in scene_errs:
                            st.warning(f"- {err}")
            except FileNotFoundError:
                st.error(f"Error: Video for Timeline {timeline_id} is missing.", icon="🚨")
            except Exception as e:
                st.error(f"Display error for Timeline {timeline_id}: {e}", icon="🚨")


def _cleanup_temp_dir(temp_dir: str) -> None:
    """Delete the per-run scratch directory, reporting the outcome in UI and log."""
    st.info(f"Cleaning up temporary files: {temp_dir}")
    try:
        shutil.rmtree(temp_dir)
        st.success("✅ Temporary files cleaned up.")
        logger.info(f"Temporary directory removed: {temp_dir}")
    except Exception as e:
        st.warning(f"Could not remove temporary files at: {temp_dir}", icon="⚠️")
        logger.error(f"Failed to remove temporary directory {temp_dir}: {e}")


def _run_generation(
    api_key: str,
    temp_dir: str,
    theme: str,
    num_scenes: int,
    num_timelines: int,
    divergence_prompt: str,
    aspect_ratio: str,
    audio_voice: Optional[str],
) -> None:
    """Generate the story, render every timeline, and display the results."""
    generator = ChronoWeaveGenerator(api_key)
    with st.spinner("Generating narrative structure... 🤖"):
        chrono_response = generator.generate_story_structure(theme, num_scenes, num_timelines, divergence_prompt)
    if not chrono_response:
        logger.error("Story generation or validation failed.")
        return

    overall_start_time = time.time()

    async def process_all_timelines():
        # Gather over a list, not a dict keyed by timeline_id, so results stay
        # aligned with chrono_response.timelines even if an ID is repeated.
        return await asyncio.gather(
            *(
                generator.process_timeline(timeline, temp_dir, aspect_ratio, audio_voice)
                for timeline in chrono_response.timelines
            )
        )

    with st.spinner("Processing scenes and assembling videos..."):
        timeline_results = asyncio.run(process_all_timelines())

    outcomes = list(zip(chrono_response.timelines, timeline_results))
    overall_duration = time.time() - overall_start_time
    # Stable sort by ID, matching the original display order.
    completed = sorted(
        (
            (timeline, video_path, errors)
            for timeline, (video_path, errors) in outcomes
            if video_path
        ),
        key=lambda item: item[0].timeline_id,
    )

    if completed:
        st.success(f"Complete! ({len(completed)} video(s) created in {overall_duration:.2f}s)")
    else:
        st.error(f"Failed. No final videos generated in {overall_duration:.2f}s")

    st.header("🎬 Generated Timelines")
    if completed:
        _render_timeline_videos(completed)
    else:
        st.warning("No final videos were successfully generated.")
        with st.expander("View All Generation Errors", expanded=True):
            for timeline, (_video_path, errors) in outcomes:
                if errors:
                    st.error(f"Timeline {timeline.timeline_id}:")
                    for msg in errors:
                        st.error(f" - {msg}")


def main():
    """Streamlit entry point: collect settings, run generation, show results."""
    # set_page_config must be the first Streamlit command on the page.
    st.set_page_config(page_title="ChronoWeave", layout="wide", initial_sidebar_state="expanded")

    google_api_key = _load_api_key()
    if not google_api_key:
        st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
        st.stop()

    st.title("🌀 ChronoWeave: Advanced Branching Narrative Generator")
    st.markdown(
        """
Generate multiple, branching story timelines from a single theme using AI — complete with images and narration.
*Based on work by Yousif Ahmed. Copyright 2025 Google LLC.*
"""
    )

    sidebar = st.sidebar
    sidebar.header("⚙️ Configuration")
    sidebar.success("Google API Key Loaded", icon="✅")
    theme = sidebar.text_input("📖 Story Theme:", "A curious squirrel finds a mysterious, glowing acorn")
    num_scenes = sidebar.slider("🎬 Scenes per Timeline:", min_value=2, max_value=7, value=3)
    num_timelines = sidebar.slider("🌿 Number of Timelines:", min_value=1, max_value=4, value=2)
    divergence_prompt = sidebar.text_input(
        "✏️ Divergence Hint (Optional):", placeholder="e.g., What if a bird tried to steal it?"
    )
    sidebar.subheader("🎨 Visual & Audio Settings")
    aspect_ratios = ["1:1", "16:9", "9:16"]
    aspect_ratio = sidebar.selectbox(
        "🖼️ Image Aspect Ratio:", aspect_ratios, index=aspect_ratios.index(DEFAULT_ASPECT_RATIO)
    )
    audio_voice = None  # voice selection is not exposed yet
    generate_button = sidebar.button(
        "✨ Generate ChronoWeave ✨", type="primary", disabled=not google_api_key, use_container_width=True
    )
    sidebar.markdown("---")
    sidebar.info("⏳ Generation may take several minutes.")
    sidebar.markdown(
        f"<small>Txt: {TEXT_MODEL_ID}, Img: {IMAGE_MODEL_ID}, Aud: {AUDIO_MODEL_ID}</small>", unsafe_allow_html=True
    )

    if not generate_button:
        st.info("Configure settings and click '✨ Generate ChronoWeave ✨' to start.")
        return
    if not theme:
        st.error("Please enter a story theme.", icon="📝")
        return

    run_id = uuid.uuid4().hex[:8]  # same 8 hex chars as the first UUID group
    temp_dir = os.path.join(TEMP_DIR_BASE, f"run_{run_id}")
    try:
        os.makedirs(temp_dir, exist_ok=True)
        logger.info(f"Created temporary directory: {temp_dir}")
    except OSError as e:
        st.error(f"🚨 Failed to create temporary directory {temp_dir}: {e}", icon="📁")
        st.stop()

    # Always clean up, including when story generation fails early or raises.
    try:
        _run_generation(
            google_api_key,
            temp_dir,
            theme,
            num_scenes,
            num_timelines,
            divergence_prompt,
            aspect_ratio,
            audio_voice,
        )
    finally:
        _cleanup_temp_dir(temp_dir)


if __name__ == "__main__":
    main()