mgbam committed on
Commit
a66ce42
·
verified ·
1 Parent(s): 463d8c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -29
app.py CHANGED
@@ -23,11 +23,11 @@ from PIL import Image
23
 
24
  # Pydantic for data validation
25
  from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
26
- from typing import List, Optional, Literal, Dict, Any
27
 
28
  # Video and audio processing
29
  from moviepy.editor import ImageClip, AudioFileClip, concatenate_videoclips
30
- # from moviepy.config import change_settings # Potential for setting ImageMagick path if needed
31
 
32
  # Type hints
33
  import typing_extensions as typing
@@ -36,6 +36,9 @@ import typing_extensions as typing
36
  import nest_asyncio
37
  nest_asyncio.apply()
38
 
 
 
 
39
  # --- Logging Setup ---
40
  logging.basicConfig(
41
  level=logging.INFO,
@@ -55,14 +58,14 @@ Generate multiple, branching story timelines from a single theme using AI, compl
55
  TEXT_MODEL_ID = "models/gemini-1.5-flash"
56
  AUDIO_MODEL_ID = "models/gemini-1.5-flash"
57
  AUDIO_SAMPLING_RATE = 24000
58
- IMAGE_MODEL_ID = "imagen-3" # NOTE: Requires Vertex AI SDK access
59
  DEFAULT_ASPECT_RATIO = "1:1"
60
  VIDEO_FPS = 24
61
  VIDEO_CODEC = "libx264"
62
  AUDIO_CODEC = "aac"
63
  TEMP_DIR_BASE = ".chrono_temp"
64
 
65
- # --- API Key Handling ---
66
  GOOGLE_API_KEY = None
67
  try:
68
  GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
@@ -75,24 +78,25 @@ except KeyError:
75
  st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
76
  st.stop()
77
 
78
- # --- Initialize Google Clients ---
 
 
 
 
 
 
 
 
 
 
 
79
  try:
80
  genai.configure(api_key=GOOGLE_API_KEY)
81
  logger.info("Configured google-generativeai with API key.")
82
-
83
- # Initialize text/JSON model
84
  client_standard = genai.GenerativeModel(TEXT_MODEL_ID)
85
  logger.info(f"Initialized text/JSON model handle: {TEXT_MODEL_ID}.")
86
-
87
- # Initialize audio model
88
  live_model = genai.GenerativeModel(AUDIO_MODEL_ID)
89
  logger.info(f"Initialized audio model handle: {AUDIO_MODEL_ID}.")
90
-
91
- # Initialize image model (placeholder for future Vertex AI SDK integration)
92
- image_model_genai = genai.GenerativeModel(IMAGE_MODEL_ID)
93
- logger.info(f"Initialized google-generativeai handle for image model: {IMAGE_MODEL_ID} (May require Vertex AI SDK).")
94
-
95
- # ---> TODO: Initialize Vertex AI client here if switching SDK <---
96
  except AttributeError as ae:
97
  logger.exception("AttributeError during Client Init.")
98
  st.error(f"🚨 Init Error: {ae}. Update library?", icon="🚨")
@@ -158,14 +162,11 @@ def wave_file_writer(filename: str, channels: int = 1, rate: int = AUDIO_SAMPLIN
158
  logger.error(f"Error closing wave file {filename}: {e_close}")
159
 
160
  async def generate_audio_live_async(api_text: str, output_filename: str, voice: Optional[str] = None) -> Optional[str]:
161
- """
162
- Generates audio using Gemini Live API (async version) via the GenerativeModel.
163
- """
164
  collected_audio = bytearray()
165
  task_id = os.path.basename(output_filename).split('.')[0]
166
  logger.info(f"πŸŽ™οΈ [{task_id}] Requesting audio: '{api_text[:60]}...'")
167
  try:
168
- # Corrected config structure for audio generation
169
  config = {
170
  "response_modalities": ["AUDIO"],
171
  "audio_encoding": "LINEAR16",
@@ -203,16 +204,24 @@ async def generate_audio_live_async(api_text: str, output_filename: str, voice:
203
  return None
204
 
205
  def generate_story_sequence_chrono(theme: str, num_scenes: int, num_timelines: int, divergence_prompt: str = "") -> Optional[ChronoWeaveResponse]:
206
- """
207
- Generates branching story sequences using Gemini structured output and validates with Pydantic.
208
- """
209
  st.info(f"πŸ“š Generating {num_timelines} timeline(s) x {num_scenes} scenes for: '{theme}'...")
210
  logger.info(f"Requesting story structure: Theme='{theme}', Timelines={num_timelines}, Scenes={num_scenes}")
211
  divergence_instruction = (
212
  f"Introduce clear points of divergence between timelines, after first scene if possible. "
213
  f"Hint: '{divergence_prompt}'. State divergence reason clearly. **For timeline_id 0, use 'Initial path' or 'Baseline scenario'.**"
214
  )
215
- prompt = f"""Act as narrative designer. Create story for theme: "{theme}". Instructions: 1. Exactly **{num_timelines}** timelines. 2. Each timeline exactly **{num_scenes}** scenes. 3. **NO humans/humanoids**. Focus: animals, fantasy creatures, animated objects, nature. 4. {divergence_instruction}. 5. Style: **'Simple, friendly kids animation, bright colors, rounded shapes'**, unless `timeline_visual_modifier` alters. 6. `audio_text`: single concise sentence (max 30 words). 7. `image_prompt`: descriptive, concise (target 15-35 words MAX). Focus on scene elements. **AVOID repeating general style**. 8. `character_description`: VERY brief (name, features). Target < 20 words. Output: ONLY valid JSON object adhering to schema. No text before/after. JSON Schema: ```json
 
 
 
 
 
 
 
 
 
 
216
  {json.dumps(ChronoWeaveResponse.model_json_schema(), indent=2)}
217
  ```"""
218
  try:
@@ -254,14 +263,37 @@ def generate_story_sequence_chrono(theme: str, num_scenes: int, num_timelines: i
254
 
255
  def generate_image_imagen(prompt: str, aspect_ratio: str = "1:1", task_id: str = "IMG") -> Optional[Image.Image]:
256
  """
257
- Generates an image.
258
- <<< IMPORTANT: This function needs to be rewritten using the Vertex AI SDK
259
- (google-cloud-aiplatform) to correctly call Imagen models. >>>
 
 
 
260
  """
261
  logger.info(f"πŸ–ΌοΈ [{task_id}] Requesting image: '{prompt[:70]}...' (Aspect: {aspect_ratio})")
262
- logger.error(f"❌ [{task_id}] Image generation skipped: Function needs update to use Vertex AI SDK for Imagen.")
263
- st.error(f"Image generation for {task_id} skipped: Requires Vertex AI SDK implementation.", icon="πŸ–ΌοΈ")
264
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  # --- Streamlit UI Elements ---
267
  st.sidebar.header("βš™οΈ Configuration")
 
23
 
24
  # Pydantic for data validation
25
  from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
26
+ from typing import List, Optional, Dict, Any
27
 
28
  # Video and audio processing
29
  from moviepy.editor import ImageClip, AudioFileClip, concatenate_videoclips
30
+ # from moviepy.config import change_settings # Uncomment if you need to change settings
31
 
32
  # Type hints
33
  import typing_extensions as typing
 
36
  import nest_asyncio
37
  nest_asyncio.apply()
38
 
39
+ # Import Vertex AI SDK
40
+ from google.cloud import aiplatform
41
+
42
  # --- Logging Setup ---
43
  logging.basicConfig(
44
  level=logging.INFO,
 
58
  TEXT_MODEL_ID = "models/gemini-1.5-flash"
59
  AUDIO_MODEL_ID = "models/gemini-1.5-flash"
60
  AUDIO_SAMPLING_RATE = 24000
61
+ IMAGE_MODEL_ID = "imagen-3" # Now used with Vertex AI
62
  DEFAULT_ASPECT_RATIO = "1:1"
63
  VIDEO_FPS = 24
64
  VIDEO_CODEC = "libx264"
65
  AUDIO_CODEC = "aac"
66
  TEMP_DIR_BASE = ".chrono_temp"
67
 
68
+ # --- API Key and Vertex AI Config Handling ---
69
  GOOGLE_API_KEY = None
70
  try:
71
  GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
 
78
  st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
79
  st.stop()
80
 
81
+ # --- Vertex AI Configuration ---
82
+ # Set up environment variables for Vertex AI; ensure these are in your Streamlit secrets or environment.
83
+ PROJECT_ID = st.secrets.get("PROJECT_ID") or os.environ.get("PROJECT_ID")
84
+ LOCATION = st.secrets.get("LOCATION") or os.environ.get("LOCATION", "us-central1")
85
+ IMAGE_ENDPOINT_ID = st.secrets.get("IMAGE_ENDPOINT_ID") or os.environ.get("IMAGE_ENDPOINT_ID")
86
+
87
+ if not PROJECT_ID or not IMAGE_ENDPOINT_ID:
88
+ st.error("🚨 **Vertex AI is not configured properly!** "
89
+ "Please set PROJECT_ID and IMAGE_ENDPOINT_ID in your secrets.", icon="🚨")
90
+ st.stop()
91
+
92
+ # --- Initialize Google Clients for text/audio ---
93
  try:
94
  genai.configure(api_key=GOOGLE_API_KEY)
95
  logger.info("Configured google-generativeai with API key.")
 
 
96
  client_standard = genai.GenerativeModel(TEXT_MODEL_ID)
97
  logger.info(f"Initialized text/JSON model handle: {TEXT_MODEL_ID}.")
 
 
98
  live_model = genai.GenerativeModel(AUDIO_MODEL_ID)
99
  logger.info(f"Initialized audio model handle: {AUDIO_MODEL_ID}.")
 
 
 
 
 
 
100
  except AttributeError as ae:
101
  logger.exception("AttributeError during Client Init.")
102
  st.error(f"🚨 Init Error: {ae}. Update library?", icon="🚨")
 
162
  logger.error(f"Error closing wave file {filename}: {e_close}")
163
 
164
  async def generate_audio_live_async(api_text: str, output_filename: str, voice: Optional[str] = None) -> Optional[str]:
165
+ """Generates audio using Gemini Live API (async version) via the GenerativeModel."""
 
 
166
  collected_audio = bytearray()
167
  task_id = os.path.basename(output_filename).split('.')[0]
168
  logger.info(f"πŸŽ™οΈ [{task_id}] Requesting audio: '{api_text[:60]}...'")
169
  try:
 
170
  config = {
171
  "response_modalities": ["AUDIO"],
172
  "audio_encoding": "LINEAR16",
 
204
  return None
205
 
206
  def generate_story_sequence_chrono(theme: str, num_scenes: int, num_timelines: int, divergence_prompt: str = "") -> Optional[ChronoWeaveResponse]:
207
+ """Generates branching story sequences using Gemini structured output and validates with Pydantic."""
 
 
208
  st.info(f"πŸ“š Generating {num_timelines} timeline(s) x {num_scenes} scenes for: '{theme}'...")
209
  logger.info(f"Requesting story structure: Theme='{theme}', Timelines={num_timelines}, Scenes={num_scenes}")
210
  divergence_instruction = (
211
  f"Introduce clear points of divergence between timelines, after first scene if possible. "
212
  f"Hint: '{divergence_prompt}'. State divergence reason clearly. **For timeline_id 0, use 'Initial path' or 'Baseline scenario'.**"
213
  )
214
+ prompt = f"""Act as narrative designer. Create story for theme: "{theme}". Instructions:
215
+ 1. Exactly **{num_timelines}** timelines.
216
+ 2. Each timeline exactly **{num_scenes}** scenes.
217
+ 3. **NO humans/humanoids**; focus on animals, fantasy creatures, animated objects, nature.
218
+ 4. {divergence_instruction}.
219
+ 5. Style: **'Simple, friendly kids animation, bright colors, rounded shapes'**, unless `timeline_visual_modifier` alters.
220
+ 6. `audio_text`: single concise sentence (max 30 words).
221
+ 7. `image_prompt`: descriptive, concise (target 15-35 words MAX). Focus on scene elements. **AVOID repeating general style**.
222
+ 8. `character_description`: VERY brief (name, features). Target < 20 words.
223
+ Output: ONLY valid JSON object adhering to schema. No text before/after.
224
+ JSON Schema: ```json
225
  {json.dumps(ChronoWeaveResponse.model_json_schema(), indent=2)}
226
  ```"""
227
  try:
 
263
 
264
  def generate_image_imagen(prompt: str, aspect_ratio: str = "1:1", task_id: str = "IMG") -> Optional[Image.Image]:
265
  """
266
+ Generates an image using Vertex AI's Imagen model via the Google Cloud AI Platform SDK.
267
+
268
+ Ensure that the following environment variables or Streamlit secrets are set:
269
+ - PROJECT_ID: Your Google Cloud project ID.
270
+ - LOCATION: The Vertex AI region (e.g., "us-central1").
271
+ - IMAGE_ENDPOINT_ID: The resource ID of your deployed Imagen endpoint.
272
  """
273
  logger.info(f"πŸ–ΌοΈ [{task_id}] Requesting image: '{prompt[:70]}...' (Aspect: {aspect_ratio})")
274
+ try:
275
+ # Initialize Vertex AI with your project and location.
276
+ aiplatform.init(project=PROJECT_ID, location=LOCATION)
277
+ # Retrieve your deployed endpoint.
278
+ endpoint = aiplatform.Endpoint(IMAGE_ENDPOINT_ID)
279
+ # Create a prediction instance. (The instance structure depends on your model.)
280
+ instance = {"prompt": prompt, "aspect_ratio": aspect_ratio}
281
+ prediction_response = endpoint.predict(instances=[instance])
282
+ # Assume the prediction returns a base64-encoded image string under the key "image".
283
+ import base64
284
+ image_base64 = prediction_response.predictions[0].get("image")
285
+ if not image_base64:
286
+ logger.error(f"❌ [{task_id}] No image returned in prediction.")
287
+ st.error(f"Image prediction failed for {task_id}: No image returned.", icon="πŸ–ΌοΈ")
288
+ return None
289
+ image_data = base64.b64decode(image_base64)
290
+ image = Image.open(BytesIO(image_data))
291
+ logger.info(f"βœ… [{task_id}] Image generated successfully.")
292
+ return image
293
+ except Exception as e:
294
+ logger.exception(f"❌ [{task_id}] Image generation failed: {e}")
295
+ st.error(f"Image generation failed for {task_id}: {e}", icon="πŸ–ΌοΈ")
296
+ return None
297
 
298
  # --- Streamlit UI Elements ---
299
  st.sidebar.header("βš™οΈ Configuration")