Update app.py

app.py CHANGED
@@ -1,345 +1,511 @@
-# Copyright 2025 Google LLC.
-#
-#
-#
-# You may …

-import streamlit as st
-import google.generativeai as genai
 import os
 import json
-import numpy as np
-from io import BytesIO
 import time
-import …
-import contextlib
 import asyncio
-import …
-import shutil
-import …

-
 from PIL import Image
 # Pydantic for data validation
 from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
-from typing import List, Optional, Literal, Dict, Any

 # Video and audio processing
 from moviepy.editor import ImageClip, AudioFileClip, concatenate_videoclips
-# from moviepy.config import change_settings # Potential for setting imagemagick path if needed

-#
-import …
-
-# Async support for Streamlit/Google API
 import nest_asyncio
-nest_asyncio.apply()

 # --- Logging Setup ---
-logging.basicConfig(level=logging.INFO, format=…
 logger = logging.getLogger(__name__)

-# --- …
-
-
-st.markdown("""
-Generate multiple, branching story timelines from a single theme using AI, complete with images and narration.
-*Based on the work by Yousif Ahmed. Copyright 2025 Google LLC.*
-""")
-
-# --- Constants ---
-# Text/JSON Model
-TEXT_MODEL_ID = "models/gemini-1.5-flash" # Or "gemini-1.5-pro"
-# Audio Model Config
-AUDIO_MODEL_ID = "models/gemini-1.5-flash" # Model used for audio tasks
 AUDIO_SAMPLING_RATE = 24000
-#
-IMAGE_MODEL_ID = "imagen-3" # <<< NOTE: Likely needs Vertex AI SDK access
 DEFAULT_ASPECT_RATIO = "1:1"
-# Video Config
 VIDEO_FPS = 24
 VIDEO_CODEC = "libx264"
 AUDIO_CODEC = "aac"
-# File Management
 TEMP_DIR_BASE = ".chrono_temp"

-
-
-try:
-    GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
-    logger.info("Google API Key loaded from Streamlit secrets.")
-except KeyError:
-    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
-    if GOOGLE_API_KEY:
-        logger.info("Google API Key loaded from environment variable.")
-    else:
-        st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨"); st.stop()
-
-# --- Initialize Google Clients ---
-try:
-    genai.configure(api_key=GOOGLE_API_KEY)
-    logger.info("Configured google-generativeai with API key.")
-    client_standard = genai.GenerativeModel(TEXT_MODEL_ID)
-    logger.info(f"Initialized text/JSON model handle: {TEXT_MODEL_ID}.")
-    live_model = genai.GenerativeModel(AUDIO_MODEL_ID)
-    logger.info(f"Initialized audio model handle: {AUDIO_MODEL_ID}.")
-    image_model_genai = genai.GenerativeModel(IMAGE_MODEL_ID) # Retained but likely needs Vertex SDK
-    logger.info(f"Initialized google-generativeai handle for image model: {IMAGE_MODEL_ID} (May require Vertex AI SDK).")
-    # ---> TODO: Initialize Vertex AI client here if switching SDK <---
-
-except AttributeError as ae:
-    logger.exception("AttributeError during Client Init."); st.error(f"🚨 Init Error: {ae}. Update library?", icon="🚨"); st.stop()
-except Exception as e:
-    logger.exception("Failed to initialize Google Clients/Models."); st.error(f"🚨 Failed Init: {e}", icon="🚨"); st.stop()
-
-# --- Define Pydantic Schemas (Using V2 Syntax) ---
 class StorySegment(BaseModel):
     scene_id: int = Field(..., ge=0)
     image_prompt: str = Field(..., min_length=10, max_length=250)
     audio_text: str = Field(..., min_length=5, max_length=150)
     character_description: str = Field(..., max_length=250)
     timeline_visual_modifier: Optional[str] = Field(None, max_length=50)
-
     @classmethod
     def image_prompt_no_humans(cls, v: str) -> str:
-        if any(…
         return v
 class Timeline(BaseModel):
     timeline_id: int = Field(..., ge=0)
     divergence_reason: str = Field(..., min_length=5)
     segments: List[StorySegment] = Field(..., min_items=1)
 class ChronoWeaveResponse(BaseModel):
     core_theme: str = Field(..., min_length=5)
     timelines: List[Timeline] = Field(..., min_items=1)
     total_scenes_per_timeline: int = Field(..., gt=0)
-
-
         expected = self.total_scenes_per_timeline
-        for i, …
-            if len(…
         return self

 # --- Helper Functions ---
 @contextlib.contextmanager
 def wave_file_writer(filename: str, channels: int = 1, rate: int = AUDIO_SAMPLING_RATE, sample_width: int = 2):
-    """…
     wf = None
     try:
-        wf = wave.open(filename, "wb")
         yield wf
-    except Exception as …
     finally:
-        if wf:
-
-
-
-
-    """Generates audio using Gemini Live API (async version) via the GenerativeModel."""
-    collected_audio = bytearray(); task_id = os.path.basename(output_filename).split('.')[0]
-    logger.info(f"🎙️ [{task_id}] Requesting audio: '{api_text[:60]}...'")
-    try:
-        # CORRECTED config structure for audio generation <<<<<<-------
-        config = {
-            "response_modalities": ["AUDIO"],
-            # Removed 'audio_config' nesting
-            "audio_encoding": "LINEAR16",
-            "sample_rate_hertz": AUDIO_SAMPLING_RATE,
-            # Add other parameters like "voice" here directly if needed
-        }
-        directive_prompt = f"Narrate directly: \"{api_text}\""
-        async with live_model.connect(config=config) as session: # Pass corrected config
-            await session.send_request([directive_prompt])
-            async for response in session.stream_content():
-                if response.audio_chunk and response.audio_chunk.data: collected_audio.extend(response.audio_chunk.data)
-                if hasattr(response, 'error') and response.error: logger.error(f" ❌ [{task_id}] Audio stream error: {response.error}"); st.error(f"Audio stream error {task_id}: {response.error}", icon="🔊"); return None
-        if not collected_audio: logger.warning(f"⚠️ [{task_id}] No audio data received."); st.warning(f"No audio data for {task_id}.", icon="🔊"); return None
-        with wave_file_writer(output_filename, rate=AUDIO_SAMPLING_RATE) as wf: wf.writeframes(bytes(collected_audio))
-        logger.info(f" ✅ [{task_id}] Audio saved: {os.path.basename(output_filename)} ({len(collected_audio)} bytes)")
-        return output_filename
-    except genai.types.generation_types.BlockedPromptException as bpe: logger.error(f" ❌ [{task_id}] Audio blocked: {bpe}"); st.error(f"Audio blocked {task_id}.", icon="🔇"); return None
-    # Catch TypeError specifically for config issues
-    except TypeError as te:
-        logger.exception(f" ❌ [{task_id}] Audio config TypeError: {te}")
-        st.error(f"Audio configuration error for {task_id} (TypeError): {te}. Check library version/config structure.", icon="⚙️")
-        return None
-    except Exception as e: logger.exception(f" ❌ [{task_id}] Audio failed: {e}"); st.error(f"Audio failed {task_id}: {e}", icon="🔊"); return None
-
-
-def generate_story_sequence_chrono(theme: str, num_scenes: int, num_timelines: int, divergence_prompt: str = "") -> Optional[ChronoWeaveResponse]:
-    """Generates branching story sequences using Gemini structured output and validates with Pydantic."""
-    st.info(f"📚 Generating {num_timelines} timeline(s) x {num_scenes} scenes for: '{theme}'...")
-    logger.info(f"Requesting story structure: Theme='{theme}', Timelines={num_timelines}, Scenes={num_scenes}")
-    divergence_instruction = (f"Introduce clear points of divergence between timelines, after first scene if possible. Hint: '{divergence_prompt}'. State divergence reason clearly. **For timeline_id 0, use 'Initial path' or 'Baseline scenario'.**")
-    prompt = f"""Act as narrative designer. Create story for theme: "{theme}". Instructions: 1. Exactly **{num_timelines}** timelines. 2. Each timeline exactly **{num_scenes}** scenes. 3. **NO humans/humanoids**. Focus: animals, fantasy creatures, animated objects, nature. 4. {divergence_instruction}. 5. Style: **'Simple, friendly kids animation, bright colors, rounded shapes'**, unless `timeline_visual_modifier` alters. 6. `audio_text`: single concise sentence (max 30 words). 7. `image_prompt`: descriptive, concise (target 15-35 words MAX). Focus on scene elements. **AVOID repeating general style**. 8. `character_description`: VERY brief (name, features). Target < 20 words. Output: ONLY valid JSON object adhering to schema. No text before/after. JSON Schema: ```json\n{json.dumps(ChronoWeaveResponse.model_json_schema(), indent=2)}\n```"""
-    try:
-        response = client_standard.generate_content(contents=prompt, generation_config=genai.types.GenerationConfig(response_mime_type="application/json", temperature=0.7))
-        try: raw_data = json.loads(response.text)
-        except json.JSONDecodeError as json_err: logger.error(f"Failed JSON decode: {json_err}\nResponse:\n{response.text}"); st.error(f"🚨 Failed parse story: {json_err}", icon="📄"); st.text_area("Problem Response:", response.text, height=150); return None
-        except Exception as e: logger.error(f"Error processing text: {e}"); st.error(f"🚨 Error processing AI response: {e}", icon="📄"); return None
-        try: validated_data = ChronoWeaveResponse.model_validate(raw_data); logger.info("✅ Story structure OK!"); st.success("✅ Story structure OK!"); return validated_data
-        except ValidationError as val_err: logger.error(f"JSON validation failed: {val_err}\nData:\n{json.dumps(raw_data, indent=2)}"); st.error(f"🚨 Gen structure invalid: {val_err}", icon="🧬"); st.json(raw_data); return None
-    except genai.types.generation_types.BlockedPromptException as bpe: logger.error(f"Story gen blocked: {bpe}"); st.error("🚨 Story prompt blocked.", icon="🚫"); return None
-    except Exception as e: logger.exception("Error during story gen:"); st.error(f"🚨 Story gen error: {e}", icon="💥"); return None


-
 """
-
-
-(google-cloud-aiplatform) to correctly call Imagen models. >>>
 """
[… old lines 191–213 were removed; their content is not recoverable from this diff view …]
 else:
[… old lines 215–219 were removed; their content is not recoverable from this diff view …]
-    chrono_response: Optional[ChronoWeaveResponse] = None
-    with st.spinner("Generating narrative structure... 🤔"): chrono_response = generate_story_sequence_chrono(theme, num_scenes, num_timelines, divergence_prompt)
-
-    if chrono_response:
-        overall_start_time = time.time(); all_timelines_successful = True
-        with st.status("Generating assets and composing videos...", expanded=True) as status:
-            for timeline_index, timeline in enumerate(chrono_response.timelines):
-                timeline_id, divergence, segments = timeline.timeline_id, timeline.divergence_reason, timeline.segments
-                timeline_label = f"Timeline {timeline_id}"; st.subheader(f"Processing {timeline_label}: {divergence}")
-                logger.info(f"--- Processing {timeline_label} (Idx: {timeline_index}) ---"); generation_errors[timeline_id] = []
-                temp_image_files, temp_audio_files, video_clips = {}, {}, []
-                timeline_start_time = time.time(); scene_success_count = 0
-
-                for scene_index, segment in enumerate(segments):
-                    scene_id = segment.scene_id; task_id = f"T{timeline_id}_S{scene_id}"
-                    status.update(label=f"Processing {timeline_label}, Scene {scene_id + 1}/{len(segments)}...")
-                    st.markdown(f"--- **Scene {scene_id + 1} ({task_id})** ---")
-                    logger.info(f"Processing {timeline_label}, Scene {scene_id + 1}/{len(segments)}...")
-                    scene_has_error = False
-                    st.write(f" *Img Prompt:* {segment.image_prompt}" + (f" *(Mod: {segment.timeline_visual_modifier})*" if segment.timeline_visual_modifier else "")); st.write(f" *Audio Text:* {segment.audio_text}")
-
-                    # --- 2a. Image Generation ---
-                    generated_image: Optional[Image.Image] = None
-                    with st.spinner(f"[{task_id}] Generating image... 🎨"):
-                        combined_prompt = segment.image_prompt
-                        if segment.character_description: combined_prompt += f" Featuring: {segment.character_description}"
-                        if segment.timeline_visual_modifier: combined_prompt += f" Style hint: {segment.timeline_visual_modifier}."
-                        generated_image = generate_image_imagen(combined_prompt, aspect_ratio, task_id) # <<< Needs Vertex AI SDK update
-                    if generated_image:
-                        image_path = os.path.join(temp_dir, f"{task_id}_image.png")
-                        try: generated_image.save(image_path); temp_image_files[scene_id] = image_path; st.image(generated_image, width=180, caption=f"Scene {scene_id+1}")
-                        except Exception as e: logger.error(f" ❌ [{task_id}] Img save error: {e}"); st.error(f"Save image {task_id} failed.", icon="💾"); scene_has_error = True; generation_errors[timeline_id].append(f"S{scene_id+1}: Img save fail.")
-                    else: scene_has_error = True; generation_errors[timeline_id].append(f"S{scene_id+1}: Img gen fail."); continue
-
-                    # --- 2b. Audio Generation ---
-                    generated_audio_path: Optional[str] = None
-                    if not scene_has_error: # Should not be reached currently due to image fail
-                        with st.spinner(f"[{task_id}] Generating audio... 🔊"):
-                            audio_path_temp = os.path.join(temp_dir, f"{task_id}_audio.wav")
-                            try: generated_audio_path = asyncio.run(generate_audio_live_async(segment.audio_text, audio_path_temp, audio_voice))
-                            except RuntimeError as e: logger.error(f" ❌ [{task_id}] Asyncio error: {e}"); st.error(f"Asyncio audio error {task_id}: {e}", icon="⚡"); scene_has_error = True; generation_errors[timeline_id].append(f"S{scene_id+1}: Audio async err.")
-                            except Exception as e: logger.exception(f" ❌ [{task_id}] Audio error: {e}"); st.error(f"Audio error {task_id}: {e}", icon="💥"); scene_has_error = True; generation_errors[timeline_id].append(f"S{scene_id+1}: Audio gen err.")
-                        if generated_audio_path:
-                            temp_audio_files[scene_id] = generated_audio_path; try: open(generated_audio_path,'rb') as ap: st.audio(ap.read(), format='audio/wav')
-                            except Exception as e: logger.warning(f" ⚠️ [{task_id}] Audio preview error: {e}")
-                        else: scene_has_error = True; generation_errors[timeline_id].append(f"S{scene_id+1}: Audio gen fail."); continue
-
-                    # --- 2c. Create Video Clip ---
-                    if not scene_has_error and scene_id in temp_image_files and scene_id in temp_audio_files: # Should not be reached currently
-                        st.write(f" 🎬 Creating clip S{scene_id+1}..."); img_path, aud_path = temp_image_files[scene_id], temp_audio_files[scene_id]
-                        audio_clip_instance, image_clip_instance, composite_clip = None, None, None
-                        try:
-                            if not os.path.exists(img_path): raise FileNotFoundError(f"Img missing: {img_path}")
-                            if not os.path.exists(aud_path): raise FileNotFoundError(f"Aud missing: {aud_path}")
-                            audio_clip_instance = AudioFileClip(aud_path); np_image = np.array(Image.open(img_path))
-                            image_clip_instance = ImageClip(np_image).set_duration(audio_clip_instance.duration)
-                            composite_clip = image_clip_instance.set_audio(audio_clip_instance); video_clips.append(composite_clip)
-                            logger.info(f" ✅ [{task_id}] Clip created (Dur: {audio_clip_instance.duration:.2f}s)."); st.write(f" ✅ Clip created (Dur: {audio_clip_instance.duration:.2f}s)."); scene_success_count += 1
-                        except Exception as e: logger.exception(f" ❌ [{task_id}] Failed clip creation: {e}"); st.error(f"Failed clip {task_id}: {e}", icon="🎬"); scene_has_error = True; generation_errors[timeline_id].append(f"S{scene_id+1}: Clip fail.")
-                        finally:
-                            if audio_clip_instance: audio_clip_instance.close();
-                            if image_clip_instance: image_clip_instance.close()
-
-                # --- 2d. Assemble Timeline Video ---
-                timeline_duration = time.time() - timeline_start_time
-                if video_clips and scene_success_count == len(segments):
-                    status.update(label=f"Composing video {timeline_label}..."); st.write(f"🎞️ Assembling video {timeline_label}..."); logger.info(f"🎞️ Assembling video {timeline_label}...")
-                    output_filename = os.path.join(temp_dir, f"timeline_{timeline_id}_final.mp4"); final_timeline_video = None
-                    try: final_timeline_video = concatenate_videoclips(video_clips, method="compose"); final_timeline_video.write_videofile(output_filename, fps=VIDEO_FPS, codec=VIDEO_CODEC, audio_codec=AUDIO_CODEC, logger=None); final_video_paths[timeline_id] = output_filename; logger.info(f" ✅ [{timeline_label}] Video saved: {os.path.basename(output_filename)}"); st.success(f"✅ Video {timeline_label} completed in {timeline_duration:.2f}s.")
-                    except Exception as e: logger.exception(f" ❌ [{timeline_label}] Video assembly failed: {e}"); st.error(f"Assemble video {timeline_label} failed: {e}", icon="📼"); all_timelines_successful = False; generation_errors[timeline_id].append(f"T{timeline_id}: Assembly fail.")
-                    finally:
-                        logger.debug(f"[{timeline_label}] Closing {len(video_clips)} clips...");
-                        for i, clip in enumerate(video_clips): try: clip.close() except Exception as e_close: logger.warning(f" ⚠️ [{timeline_label}] Clip close err {i}: {e_close}")
-                        if final_timeline_video: try: final_timeline_video.close() except Exception as e_close_final: logger.warning(f" ⚠️ [{timeline_label}] Final vid close err: {e_close_final}")
-                elif not video_clips: logger.warning(f"[{timeline_label}] No clips. Skip assembly."); st.warning(f"No scenes for {timeline_label}. No video.", icon="🚫"); all_timelines_successful = False
-                else: error_count = len(generation_errors[timeline_id]); logger.warning(f"[{timeline_label}] {error_count} scene err(s). Skip assembly."); st.warning(f"{timeline_label}: {error_count} err(s). Video not assembled.", icon="⚠️"); all_timelines_successful = False
-                if generation_errors[timeline_id]: logger.error(f"Errors {timeline_label}: {generation_errors[timeline_id]}")
-
-            # --- End of Timelines Loop ---
-            overall_duration = time.time() - overall_start_time
-            if all_timelines_successful and final_video_paths: status_msg = f"Complete! ({len(final_video_paths)} videos in {overall_duration:.2f}s)"; status.update(label=status_msg, state="complete", expanded=False); logger.info(status_msg)
-            elif final_video_paths: status_msg = f"Partially Complete ({len(final_video_paths)} videos, errors). {overall_duration:.2f}s"; status.update(label=status_msg, state="warning", expanded=True); logger.warning(status_msg)
-            else: status_msg = f"Failed. No videos. {overall_duration:.2f}s"; status.update(label=status_msg, state="error", expanded=True); logger.error(status_msg)
-
-        # --- 3. Display Results ---
-        st.header("🎬 Generated Timelines")
-        if final_video_paths:
-            sorted_timeline_ids = sorted(final_video_paths.keys()); num_cols = min(len(sorted_timeline_ids), 3); cols = st.columns(num_cols)
-            for idx, timeline_id in enumerate(sorted_timeline_ids):
-                col = cols[idx % num_cols]; video_path = final_video_paths[timeline_id]
-                timeline_data = next((t for t in chrono_response.timelines if t.timeline_id == timeline_id), None)
-                reason = timeline_data.divergence_reason if timeline_data else "Unknown"
-                with col:
-                    st.subheader(f"Timeline {timeline_id}"); st.caption(f"Divergence: {reason}")
-                    try:
-                        with open(video_path, 'rb') as vf: video_bytes = vf.read()
-                        st.video(video_bytes); logger.info(f"Displaying T{timeline_id}")
-                        st.download_button(f"Download T{timeline_id}", video_bytes, f"timeline_{timeline_id}.mp4", "video/mp4", key=f"dl_{timeline_id}")
-                        if generation_errors.get(timeline_id):
-                            scene_errors = [err for err in generation_errors[timeline_id] if not err.startswith(f"T{timeline_id}:")]
-                            if scene_errors:
-                                with st.expander(f"⚠️ View {len(scene_errors)} Scene Issues"):
-                                    for err in scene_errors: st.warning(f"- {err}")
-                    except FileNotFoundError: logger.error(f"Video missing: {video_path}"); st.error(f"Error: Video missing T{timeline_id}.", icon="🚨")
-                    except Exception as e: logger.exception(f"Display error {video_path}: {e}"); st.error(f"Display error T{timeline_id}: {e}", icon="🚨")
-        else: # No videos generated
-            st.warning("No final videos were successfully generated.")
-            st.subheader("Summary of Generation Issues")
-            has_errors = any(generation_errors.values())
-            if has_errors:
-                with st.expander("View All Errors", expanded=True):
-                    for tid, errors in generation_errors.items():
-                        if errors:
-                            st.error(f"**Timeline {tid}:**")
-                            for msg in errors: st.error(f" - {msg}") # Use standard loop
-            else: st.info("No generation errors recorded.")
-
-        # --- 4. Cleanup ---
-        st.info(f"Attempting cleanup: {temp_dir}")
-        try: shutil.rmtree(temp_dir); logger.info(f"✅ Temp dir removed: {temp_dir}"); st.success("✅ Temp files cleaned.")
-        except Exception as e: logger.error(f"⚠️ Failed remove temp dir {temp_dir}: {e}"); st.warning(f"Could not remove temp files: {temp_dir}.", icon="⚠️")
-
-    elif not chrono_response: logger.error("Story gen/validation failed.")
-    else: st.error("Unexpected issue post-gen.", icon="🛑"); logger.error("Chrono_response truthy but invalid.")
-
-else: st.info("Configure settings and click '✨ Generate ChronoWeave ✨' to start.")
+# Copyright 2025 Google LLC.
+# Based on work by Yousif Ahmed.
+# Concept: ChronoWeave – Branching Narrative Generation
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# Obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0

 import os
 import json
 import time
+import uuid
 import asyncio
+import logging
+import shutil
+import contextlib
+import wave
+from io import BytesIO
+from typing import List, Optional, Tuple, Dict, Any

+import streamlit as st
+import numpy as np
 from PIL import Image
+
 # Pydantic for data validation
 from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator

 # Video and audio processing
 from moviepy.editor import ImageClip, AudioFileClip, concatenate_videoclips

+# Google Generative AI library and async patch
+import google.generativeai as genai
 import nest_asyncio
+nest_asyncio.apply()  # Ensure asyncio works correctly in Streamlit/Jupyter

 # --- Logging Setup ---
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)

+# --- Constants & Configurations ---
+TEXT_MODEL_ID = "models/gemini-1.5-flash"  # Alternatively "gemini-1.5-pro"
+AUDIO_MODEL_ID = "models/gemini-1.5-flash"  # Synchronous generation for audio now
 AUDIO_SAMPLING_RATE = 24000
+IMAGE_MODEL_ID = "imagen-3"  # NOTE: Requires Vertex AI SDK integration in the future
 DEFAULT_ASPECT_RATIO = "1:1"
 VIDEO_FPS = 24
 VIDEO_CODEC = "libx264"
 AUDIO_CODEC = "aac"
 TEMP_DIR_BASE = ".chrono_temp"

+
+# --- Pydantic Schemas ---
 class StorySegment(BaseModel):
     scene_id: int = Field(..., ge=0)
     image_prompt: str = Field(..., min_length=10, max_length=250)
     audio_text: str = Field(..., min_length=5, max_length=150)
     character_description: str = Field(..., max_length=250)
     timeline_visual_modifier: Optional[str] = Field(None, max_length=50)
+
+    @field_validator("image_prompt")
     @classmethod
     def image_prompt_no_humans(cls, v: str) -> str:
+        if any(word in v.lower() for word in ["person", "people", "human", "man", "woman", "boy", "girl", "child"]):
+            logger.warning(f"Image prompt '{v[:50]}...' may include human-related descriptions.")
         return v
+
+
 class Timeline(BaseModel):
     timeline_id: int = Field(..., ge=0)
     divergence_reason: str = Field(..., min_length=5)
     segments: List[StorySegment] = Field(..., min_items=1)
+
+
 class ChronoWeaveResponse(BaseModel):
     core_theme: str = Field(..., min_length=5)
     timelines: List[Timeline] = Field(..., min_items=1)
     total_scenes_per_timeline: int = Field(..., gt=0)
+
+    @model_validator(mode="after")
+    def check_timeline_segment_count(self) -> "ChronoWeaveResponse":
         expected = self.total_scenes_per_timeline
+        for i, timeline in enumerate(self.timelines):
+            if len(timeline.segments) != expected:
+                raise ValueError(f"Timeline {i} (ID: {timeline.timeline_id}): Expected {expected} segments, got {len(timeline.segments)}.")
         return self

+
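As a quick reference, here is a minimal sketch (an editor aside, not part of the diff) of how these schemas validate a hand-built payload; all values are illustrative only:

```python
sample = {
    "core_theme": "A curious squirrel finds a glowing acorn",
    "total_scenes_per_timeline": 1,
    "timelines": [{
        "timeline_id": 0,
        "divergence_reason": "Initial path",
        "segments": [{
            "scene_id": 0,
            "image_prompt": "A small red squirrel holds a softly glowing acorn in a sunlit forest clearing",
            "audio_text": "The squirrel found a glowing acorn.",
            "character_description": "Pip, a small red squirrel",
        }],
    }],
}
story = ChronoWeaveResponse.model_validate(sample)  # raises ValidationError on any mismatch
```

The same models also feed the generation prompt later in the file via `ChronoWeaveResponse.model_json_schema()`, so the schema shown to Gemini and the schema used for validation cannot drift apart.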
 # --- Helper Functions ---
 @contextlib.contextmanager
 def wave_file_writer(filename: str, channels: int = 1, rate: int = AUDIO_SAMPLING_RATE, sample_width: int = 2):
+    """
+    Safely writes a WAV file using a context manager.
+    """
     wf = None
     try:
+        wf = wave.open(filename, "wb")
+        wf.setnchannels(channels)
+        wf.setsampwidth(sample_width)  # 16-bit audio (2 bytes)
+        wf.setframerate(rate)
         yield wf
+    except Exception as exc:
+        logger.error(f"Error writing wave file {filename}: {exc}")
+        raise
     finally:
+        if wf:
+            try:
+                wf.close()
+            except Exception as e_close:
+                logger.error(f"Error closing wave file {filename}: {e_close}")
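A short usage sketch for this helper (aside, not part of the diff): writing half a second of 16-bit silence at the configured sample rate.

```python
silence = np.zeros(AUDIO_SAMPLING_RATE // 2, dtype=np.int16)  # 0.5 s of mono silence
with wave_file_writer("silence.wav") as wf:
    wf.writeframes(silence.tobytes())
```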


+# --- ChronoWeave Generator Class ---
+class ChronoWeaveGenerator:
     """
+    Encapsulates the logic for generating branching narratives,
+    processing audio, images, and assembling video outputs.
     """
+
+    def __init__(self, api_key: str):
+        self.api_key = api_key
+        genai.configure(api_key=self.api_key)
+
+        try:
+            self.client_text = genai.GenerativeModel(TEXT_MODEL_ID)
+            logger.info(f"Initialized text model: {TEXT_MODEL_ID}")
+            self.client_audio = genai.GenerativeModel(AUDIO_MODEL_ID)
+            logger.info(f"Initialized audio model: {AUDIO_MODEL_ID}")
+            self.client_image = genai.GenerativeModel(IMAGE_MODEL_ID)
+            logger.info(f"Initialized image model: {IMAGE_MODEL_ID} (Placeholder: Update to Vertex AI SDK)")
+        except Exception as exc:
+            logger.exception("Failed to initialize Google Clients/Models.")
+            raise exc
+
+    def generate_story_structure(
+        self, theme: str, num_scenes: int, num_timelines: int, divergence_prompt: str = ""
+    ) -> Optional[ChronoWeaveResponse]:
+        """
+        Generates a story structure as JSON using the text model and validates it via Pydantic.
+        """
+        st.info(f"Generating {num_timelines} timeline(s) with {num_scenes} scene(s) for theme: '{theme}'")
+        logger.info(f"Story generation request: Theme='{theme}', Timelines={num_timelines}, Scenes={num_scenes}")
+
+        divergence_instruction = (
+            f"Introduce clear divergence after the first scene. Hint: '{divergence_prompt}'. "
+            f"For timeline_id 0, use 'Initial path' or 'Baseline scenario'."
+        )
+
+        prompt = f"""Act as a narrative designer. Create a story for the theme: "{theme}".
+Instructions:
+1. Exactly **{num_timelines}** timelines.
+2. Each timeline must consist of exactly **{num_scenes}** scenes.
+3. **NO humans/humanoids**; focus on animals, fantasy creatures, animated objects, and nature.
+4. {divergence_instruction}
+5. Style: **'Simple, friendly kids animation, bright colors, rounded shapes'** unless modified by `timeline_visual_modifier`.
+6. `audio_text`: One concise sentence (max 30 words).
+7. `image_prompt`: Descriptive prompt (15–35 words) emphasizing scene elements. **Avoid repeating general style.**
+8. `character_description`: Very brief (name and features; < 20 words).
+
+Output only a valid JSON object conforming exactly to this schema:
+JSON Schema: ```json
+{json.dumps(ChronoWeaveResponse.model_json_schema(), indent=2)}
+```"""
+
+        try:
+            response = self.client_text.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(
+                    response_mime_type="application/json", temperature=0.7
+                ),
+            )
+            raw_data = json.loads(response.text)
+            validated_data = ChronoWeaveResponse.model_validate(raw_data)
+            st.success("Story structure validated successfully!")
+            return validated_data
+
+        except json.JSONDecodeError as json_err:
+            logger.error(f"JSON decode failed: {json_err}\nResponse: {response.text}")
+            st.error(f"🚨 JSON Parsing Error: {json_err}", icon="📄")
+            st.text_area("Response", response.text, height=150)
+        except ValidationError as val_err:
+            logger.error(f"Pydantic validation error: {val_err}\nData: {json.dumps(raw_data, indent=2)}")
+            st.error(f"🚨 Invalid story structure: {val_err}", icon="🧬")
+            st.json(raw_data)
+        except Exception as e:
+            logger.exception("Story generation error:")
+            st.error(f"🚨 Error generating story: {e}", icon="💥")
+        return None
+
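    # Aside (not part of the diff): response_mime_type="application/json" puts Gemini
    # into JSON-only output mode, and the prompt embeds the exact Pydantic schema via
    #
    #     json.dumps(ChronoWeaveResponse.model_json_schema(), indent=2)
    #
    # so the json.loads() + model_validate() pair above checks the reply against the
    # very schema the model was shown.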
+    async def generate_audio(self, text: str, output_filename: str, voice: Optional[str] = None) -> Optional[str]:
+        """
+        Asynchronously generates audio by wrapping the synchronous generate_content call.
+        The call is executed using asyncio.to_thread to avoid blocking.
+        """
+        task_id = os.path.basename(output_filename).split(".")[0]
+        logger.info(f"🎙️ [{task_id}] Generating audio for text: '{text[:60]}...'")
+
+        try:
+            # Define a synchronous function for audio generation.
+            def sync_generate_audio():
+                prompt = f"Narrate directly: \"{text}\""
+                response = self.client_audio.generate_content(
+                    contents=prompt,
+                    generation_config=genai.types.GenerationConfig(
+                        response_mime_type="application/octet-stream",
+                        temperature=0.7,
+                        audio_config={"audio_encoding": "LINEAR16", "sample_rate_hertz": AUDIO_SAMPLING_RATE}
+                    )
+                )
+                return response
+
+            # Execute the synchronous call in a separate thread.
+            response = await asyncio.to_thread(sync_generate_audio)
+
+            # Process the response. Adjust as necessary based on the API's actual response structure.
+            if not response or not hasattr(response, "audio_chunk") or not response.audio_chunk.data:
+                logger.error(f"❌ [{task_id}] No audio data returned.")
+                st.error(f"Audio generation failed for {task_id}: No audio data.", icon="🔊")
+                return None
+
+            audio_data = response.audio_chunk.data
+            with wave_file_writer(output_filename) as wf:
+                wf.writeframes(audio_data)
+            logger.info(f"✅ [{task_id}] Audio saved: {os.path.basename(output_filename)} ({len(audio_data)} bytes)")
+            return output_filename
+
+        except Exception as e:
+            logger.exception(f"❌ [{task_id}] Audio generation error: {e}")
+            st.error(f"Audio generation failed for {task_id}: {e}", icon="🔊")
+            return None
+
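    # Aside (not part of the diff): the asyncio.to_thread pattern above generalizes
    # to any blocking SDK call, keeping the event loop free for concurrent scene tasks:
    #
    #     async def run_without_blocking():
    #         return await asyncio.to_thread(some_blocking_sdk_call)
    #
    # One caution: audio_config does not appear to be a documented
    # genai.types.GenerationConfig field, and gemini-1.5-flash responses expose
    # text/parts rather than an audio_chunk attribute, so both the config and the
    # response handling above may need adjusting to the installed SDK version.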
+    async def generate_image_async(self, prompt: str, aspect_ratio: str, task_id: str) -> Optional[Image.Image]:
+        """
+        Placeholder for image generation.
+        Currently logs an error and returns None. Update this function once integrating Vertex AI SDK.
+        """
+        logger.info(f"🖼️ [{task_id}] Requesting image for prompt: '{prompt[:70]}...' (Aspect Ratio: {aspect_ratio})")
+        logger.error(f"❌ [{task_id}] Image generation not implemented. Update required for Vertex AI.")
+        st.error(f"Image generation for {task_id} skipped: Requires Vertex AI SDK implementation.", icon="🖼️")
+        return None
+
+    async def process_scene(
+        self,
+        timeline_id: int,
+        segment: StorySegment,
+        temp_dir: str,
+        aspect_ratio: str,
+        audio_voice: Optional[str] = None,
+    ) -> Tuple[Optional[str], Optional[str], Optional[Any], List[str]]:
+        """
+        Processes a single scene: concurrently generates image and audio,
+        and then creates a video clip if both outputs are available.
+        Returns a tuple of (image_path, audio_path, video_clip, [error messages]).
+        """
+        errors: List[str] = []
+        task_id = f"T{timeline_id}_S{segment.scene_id}"
+        image_path = os.path.join(temp_dir, f"{task_id}_image.png")
+        audio_path = os.path.join(temp_dir, f"{task_id}_audio.wav")
+        video_clip = None
+
+        # Launch image and audio generation concurrently.
+        image_future = asyncio.create_task(
+            self.generate_image_async(
+                prompt=f"{segment.image_prompt} Featuring: {segment.character_description} " +
+                       (f"Style hint: {segment.timeline_visual_modifier}" if segment.timeline_visual_modifier else ""),
+                aspect_ratio=aspect_ratio,
+                task_id=task_id,
+            )
+        )
+        audio_future = asyncio.create_task(self.generate_audio(segment.audio_text, audio_path, audio_voice))
+
+        image_result, audio_result = await asyncio.gather(image_future, audio_future)
+
+        if image_result:
+            try:
+                image_result.save(image_path)
+                st.image(image_result, width=180, caption=f"Scene {segment.scene_id + 1}")
+            except Exception as e:
+                logger.error(f"❌ [{task_id}] Error saving image: {e}")
+                errors.append(f"Scene {segment.scene_id + 1}: Image save error.")
+        else:
+            errors.append(f"Scene {segment.scene_id + 1}: Image generation failed.")
+
+        if audio_result:
+            try:
+                with open(audio_result, "rb") as ap:
+                    st.audio(ap.read(), format="audio/wav")
+            except Exception as e:
+                logger.warning(f"⚠️ [{task_id}] Audio preview error: {e}")
+        else:
+            errors.append(f"Scene {segment.scene_id + 1}: Audio generation failed.")
+
+        if not errors and os.path.exists(image_path) and os.path.exists(audio_path):
+            try:
+                audio_clip = AudioFileClip(audio_path)
+                np_img = np.array(Image.open(image_path))
+                img_clip = ImageClip(np_img).set_duration(audio_clip.duration)
+                video_clip = img_clip.set_audio(audio_clip)
+                logger.info(f"✅ [{task_id}] Video clip created (Duration: {audio_clip.duration:.2f}s).")
+            except Exception as e:
+                logger.exception(f"❌ [{task_id}] Failed to create video clip: {e}")
+                errors.append(f"Scene {segment.scene_id + 1}: Video clip creation failed.")
+            finally:
+                try:
+                    if 'audio_clip' in locals():
+                        audio_clip.close()
+                    if 'img_clip' in locals():
+                        img_clip.close()
+                except Exception:
+                    pass
+
+        return (
+            image_path if os.path.exists(image_path) else None,
+            audio_path if os.path.exists(audio_path) else None,
+            video_clip,
+            errors,
+        )
+
+    async def process_timeline(
+        self,
+        timeline: Timeline,
+        temp_dir: str,
+        aspect_ratio: str,
+        audio_voice: Optional[str] = None,
+    ) -> Tuple[Optional[str], List[str]]:
+        """
+        Processes an entire timeline by concurrently processing all its scenes,
+        then assembling a final video if every scene produced a valid clip.
+        Returns a tuple of (final video path, list of error messages).
+        """
+        timeline_id = timeline.timeline_id
+        scene_tasks = [
+            self.process_scene(timeline_id, segment, temp_dir, aspect_ratio, audio_voice)
+            for segment in timeline.segments
+        ]
+        results = await asyncio.gather(*scene_tasks)
+        video_clips = []
+        timeline_errors: List[str] = []
+        for idx, (img_path, aud_path, clip, errs) in enumerate(results):
+            if errs:
+                timeline_errors.extend(errs)
+            if clip is not None:
+                video_clips.append(clip)
+
+        if video_clips and len(video_clips) == len(timeline.segments):
+            output_filename = os.path.join(temp_dir, f"timeline_{timeline_id}_final.mp4")
+            try:
+                final_video = concatenate_videoclips(video_clips, method="compose")
+                final_video.write_videofile(
+                    output_filename, fps=VIDEO_FPS, codec=VIDEO_CODEC, audio_codec=AUDIO_CODEC, logger=None
+                )
+                logger.info(f"✅ Timeline {timeline_id} video saved: {output_filename}")
+                for clip in video_clips:
+                    clip.close()
+                final_video.close()
+                return output_filename, timeline_errors
+            except Exception as e:
+                logger.exception(f"❌ Timeline {timeline_id} video assembly failed: {e}")
+                timeline_errors.append(f"Timeline {timeline_id}: Video assembly failed.")
+        else:
+            timeline_errors.append(f"Timeline {timeline_id}: Incomplete scenes; skipping video assembly.")
+        return None, timeline_errors
+
+
+# --- Streamlit UI and Main Process ---
+def main():
+    # st.set_page_config must be the first Streamlit command, so it runs before
+    # any st.error/st.stop call in the key-retrieval logic below.
+    st.set_page_config(page_title="ChronoWeave", layout="wide", initial_sidebar_state="expanded")
+
+    # API Key Retrieval
+    GOOGLE_API_KEY: Optional[str] = None
+    try:
+        GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
+        logger.info("Google API Key loaded from Streamlit secrets.")
+    except KeyError:
+        GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
+        if GOOGLE_API_KEY:
+            logger.info("Google API Key loaded from environment variable.")
+        else:
+            st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
+            st.stop()
+
+    st.title("🌀 ChronoWeave: Advanced Branching Narrative Generator")
+    st.markdown("""
+    Generate multiple, branching story timelines from a single theme using AI – complete with images and narration.
+    *Based on work by Yousif Ahmed. Copyright 2025 Google LLC.*
+    """)
+
+    st.sidebar.header("⚙️ Configuration")
+    if GOOGLE_API_KEY:
+        st.sidebar.success("Google API Key Loaded", icon="✅")
+    else:
+        st.sidebar.error("Google API Key Missing!", icon="🚨")
+
+    theme = st.sidebar.text_input("📖 Story Theme:", "A curious squirrel finds a mysterious, glowing acorn")
+    num_scenes = st.sidebar.slider("🎬 Scenes per Timeline:", min_value=2, max_value=7, value=3)
+    num_timelines = st.sidebar.slider("🌿 Number of Timelines:", min_value=1, max_value=4, value=2)
+    divergence_prompt = st.sidebar.text_input("↔️ Divergence Hint (Optional):", placeholder="e.g., What if a bird tried to steal it?")
+    st.sidebar.subheader("🎨 Visual & Audio Settings")
+    aspect_ratio = st.sidebar.selectbox("🖼️ Image Aspect Ratio:", ["1:1", "16:9", "9:16"], index=0)
+    audio_voice = None
+
+    generate_button = st.sidebar.button("✨ Generate ChronoWeave ✨", type="primary", disabled=(not GOOGLE_API_KEY), use_container_width=True)
+    st.sidebar.markdown("---")
+    st.sidebar.info("⏳ Generation may take several minutes.")
+    st.sidebar.markdown(f"<small>Txt: {TEXT_MODEL_ID}, Img: {IMAGE_MODEL_ID}, Aud: {AUDIO_MODEL_ID}</small>", unsafe_allow_html=True)
+
+    if generate_button:
+        if not theme:
+            st.error("Please enter a story theme.", icon="👈")
+            return
+
+        run_id = str(uuid.uuid4()).split('-')[0]
+        temp_dir = os.path.join(TEMP_DIR_BASE, f"run_{run_id}")
+        try:
+            os.makedirs(temp_dir, exist_ok=True)
+            logger.info(f"Created temporary directory: {temp_dir}")
+        except OSError as e:
+            st.error(f"🚨 Failed to create temporary directory {temp_dir}: {e}", icon="📂")
+            st.stop()
+
+        # Instantiate ChronoWeaveGenerator and generate story structure.
+        generator = ChronoWeaveGenerator(GOOGLE_API_KEY)
+        chrono_response = None
+        with st.spinner("Generating narrative structure... 🤔"):
+            chrono_response = generator.generate_story_structure(theme, num_scenes, num_timelines, divergence_prompt)
+
+        if not chrono_response:
+            logger.error("Story generation or validation failed.")
+            return
+
+        overall_start_time = time.time()
+        final_video_paths: Dict[int, str] = {}
+        generation_errors: Dict[int, List[str]] = {}
+
+        async def process_all_timelines():
+            timeline_tasks = {
+                timeline.timeline_id: asyncio.create_task(
+                    generator.process_timeline(timeline, temp_dir, aspect_ratio, audio_voice)
+                )
+                for timeline in chrono_response.timelines
+            }
+            results = await asyncio.gather(*timeline_tasks.values(), return_exceptions=False)
+            return results
+
+        with st.spinner("Processing scenes and assembling videos..."):
+            timeline_results = asyncio.run(process_all_timelines())
+
+        for timeline, (video_path, errors) in zip(chrono_response.timelines, timeline_results):
+            generation_errors[timeline.timeline_id] = errors
+            if video_path:
+                final_video_paths[timeline.timeline_id] = video_path
+
+        overall_duration = time.time() - overall_start_time
+        if final_video_paths:
+            st.success(f"Complete! ({len(final_video_paths)} video(s) created in {overall_duration:.2f}s)")
+        else:
+            st.error(f"Failed. No final videos generated in {overall_duration:.2f}s")
+
+        st.header("🎬 Generated Timelines")
+        if final_video_paths:
+            sorted_ids = sorted(final_video_paths.keys())
+            num_cols = min(len(sorted_ids), 3)
+            cols = st.columns(num_cols)
+            for idx, timeline_id in enumerate(sorted_ids):
+                video_path = final_video_paths[timeline_id]
+                timeline_data = next((t for t in chrono_response.timelines if t.timeline_id == timeline_id), None)
+                divergence = timeline_data.divergence_reason if timeline_data else "Unknown"
+                with cols[idx % num_cols]:
+                    st.subheader(f"Timeline {timeline_id}")
+                    st.caption(f"Divergence: {divergence}")
+                    try:
+                        with open(video_path, "rb") as vf:
+                            video_bytes = vf.read()
+                        st.video(video_bytes)
+                        st.download_button(
+                            f"Download Timeline {timeline_id}",
+                            video_bytes,
+                            file_name=f"timeline_{timeline_id}.mp4",
+                            mime="video/mp4",
+                            key=f"dl_{timeline_id}"
+                        )
+                        if generation_errors.get(timeline_id):
+                            scene_errs = generation_errors[timeline_id]
+                            if scene_errs:
+                                with st.expander(f"⚠️ View Scene Issues ({len(scene_errs)})"):
+                                    for err in scene_errs:
+                                        st.warning(f"- {err}")
+                    except FileNotFoundError:
+                        st.error(f"Error: Video for Timeline {timeline_id} is missing.", icon="🚨")
+                    except Exception as e:
+                        st.error(f"Display error for Timeline {timeline_id}: {e}", icon="🚨")
+        else:
+            st.warning("No final videos were successfully generated.")
+            with st.expander("View All Generation Errors", expanded=True):
+                for tid, errs in generation_errors.items():
+                    if errs:
+                        st.error(f"Timeline {tid}:")
+                        for msg in errs:
+                            st.error(f" - {msg}")
+
+        st.info(f"Cleaning up temporary files: {temp_dir}")
+        try:
+            shutil.rmtree(temp_dir)
+            st.success("✅ Temporary files cleaned up.")
+            logger.info(f"Temporary directory removed: {temp_dir}")
+        except Exception as e:
+            st.warning(f"Could not remove temporary files at: {temp_dir}", icon="⚠️")
+            logger.error(f"Failed to remove temporary directory {temp_dir}: {e}")
     else:
+        st.info("Configure settings and click '✨ Generate ChronoWeave ✨' to start.")
+
+
+if __name__ == "__main__":
+    main()
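The image path in this revision is deliberately stubbed out in `generate_image_async`. For reference, one plausible shape for the future Vertex AI integration (an aside, not part of the diff; it assumes the `google-cloud-aiplatform` package's `vertexai.preview.vision_models` interface, and the project, location, and model name below are placeholders):

```python
# Sketch only: requires `pip install google-cloud-aiplatform` and GCP credentials.
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel

vertexai.init(project="your-gcp-project", location="us-central1")  # placeholder values
imagen = ImageGenerationModel.from_pretrained("imagegeneration@006")
result = imagen.generate_images(
    prompt="A small red squirrel holds a glowing acorn, simple friendly kids animation style",
    number_of_images=1,
    aspect_ratio="1:1",
)
result.images[0].save("T0_S0_image.png")
```

Wrapped in `asyncio.to_thread`, a call like this would slot into `generate_image_async` the same way `generate_audio` wraps its synchronous call.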