mgbam commited on
Commit
9728f29
Β·
verified Β·
1 Parent(s): 7401371

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -37
app.py CHANGED
@@ -16,7 +16,7 @@ import contextlib
16
  import asyncio
17
  import uuid # For unique identifiers
18
  import shutil # For directory operations
19
- import logging # For improved logging
20
 
21
  # Image handling
22
  from PIL import Image
@@ -27,7 +27,6 @@ from typing import List, Optional, Dict, Any
27
 
28
  # Video and audio processing
29
  from moviepy.editor import ImageClip, AudioFileClip, concatenate_videoclips
30
- # from moviepy.config import change_settings # Uncomment if you need to change settings
31
 
32
  # Type hints
33
  import typing_extensions as typing
@@ -36,8 +35,9 @@ import typing_extensions as typing
36
  import nest_asyncio
37
  nest_asyncio.apply()
38
 
39
- # Import Vertex AI SDK
40
- from google.cloud import aiplatform
 
41
 
42
  # --- Logging Setup ---
43
  logging.basicConfig(
@@ -58,14 +58,15 @@ Generate multiple, branching story timelines from a single theme using AI, compl
58
  TEXT_MODEL_ID = "models/gemini-1.5-flash"
59
  AUDIO_MODEL_ID = "models/gemini-1.5-flash"
60
  AUDIO_SAMPLING_RATE = 24000
61
- IMAGE_MODEL_ID = "imagen-3" # Now used with Vertex AI
 
62
  DEFAULT_ASPECT_RATIO = "1:1"
63
  VIDEO_FPS = 24
64
  VIDEO_CODEC = "libx264"
65
  AUDIO_CODEC = "aac"
66
  TEMP_DIR_BASE = ".chrono_temp"
67
 
68
- # --- API Key and Vertex AI Config Handling ---
69
  GOOGLE_API_KEY = None
70
  try:
71
  GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
@@ -78,17 +79,16 @@ except KeyError:
78
  st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
79
  st.stop()
80
 
81
- # --- Vertex AI Configuration ---
82
- # Set up environment variables for Vertex AI; ensure these are in your Streamlit secrets or environment.
83
  PROJECT_ID = st.secrets.get("PROJECT_ID") or os.environ.get("PROJECT_ID")
84
  LOCATION = st.secrets.get("LOCATION") or os.environ.get("LOCATION", "us-central1")
85
- IMAGE_ENDPOINT_ID = st.secrets.get("IMAGE_ENDPOINT_ID") or os.environ.get("IMAGE_ENDPOINT_ID")
86
-
87
- if not PROJECT_ID or not IMAGE_ENDPOINT_ID:
88
- st.error("🚨 **Vertex AI is not configured properly!** "
89
- "Please set PROJECT_ID and IMAGE_ENDPOINT_ID in your secrets.", icon="🚨")
90
  st.stop()
91
 
 
 
 
92
  # --- Initialize Google Clients for text/audio ---
93
  try:
94
  genai.configure(api_key=GOOGLE_API_KEY)
@@ -117,7 +117,7 @@ class StorySegment(BaseModel):
117
  @field_validator('image_prompt')
118
  @classmethod
119
  def image_prompt_no_humans(cls, v: str) -> str:
120
- if any(w in v.lower() for w in ["person", "people", "human", "man", "woman", "boy", "girl", "child"]):
121
  logger.warning(f"Prompt '{v[:50]}...' may contain humans.")
122
  return v
123
 
@@ -263,36 +263,31 @@ JSON Schema: ```json
263
 
264
  def generate_image_imagen(prompt: str, aspect_ratio: str = "1:1", task_id: str = "IMG") -> Optional[Image.Image]:
265
  """
266
- Generates an image using Vertex AI's Imagen model via the Google Cloud AI Platform SDK.
267
 
268
- Ensure that the following environment variables or Streamlit secrets are set:
269
- - PROJECT_ID: Your Google Cloud project ID.
270
- - LOCATION: The Vertex AI region (e.g., "us-central1").
271
- - IMAGE_ENDPOINT_ID: The resource ID of your deployed Imagen endpoint.
272
  """
273
  logger.info(f"πŸ–ΌοΈ [{task_id}] Requesting image: '{prompt[:70]}...' (Aspect: {aspect_ratio})")
274
  try:
275
- # Initialize Vertex AI with your project and location.
276
- aiplatform.init(project=PROJECT_ID, location=LOCATION)
277
- # Retrieve your deployed endpoint.
278
- endpoint = aiplatform.Endpoint(IMAGE_ENDPOINT_ID)
279
- # Create a prediction instance. (The instance structure depends on your model.)
280
- instance = {"prompt": prompt, "aspect_ratio": aspect_ratio}
281
- prediction_response = endpoint.predict(instances=[instance])
282
- # Assume the prediction returns a base64-encoded image string under the key "image".
283
- import base64
284
- image_base64 = prediction_response.predictions[0].get("image")
285
- if not image_base64:
286
- logger.error(f"❌ [{task_id}] No image returned in prediction.")
287
- st.error(f"Image prediction failed for {task_id}: No image returned.", icon="πŸ–ΌοΈ")
288
- return None
289
- image_data = base64.b64decode(image_base64)
290
- image = Image.open(BytesIO(image_data))
291
  logger.info(f"βœ… [{task_id}] Image generated successfully.")
292
  return image
293
  except Exception as e:
294
  logger.exception(f"❌ [{task_id}] Image generation failed: {e}")
295
- st.error(f"Image generation failed for {task_id}: {e}", icon="πŸ–ΌοΈ")
296
  return None
297
 
298
  # --- Streamlit UI Elements ---
@@ -325,7 +320,7 @@ if generate_button:
325
  os.makedirs(temp_dir, exist_ok=True)
326
  logger.info(f"Created temp dir: {temp_dir}")
327
  except OSError as e:
328
- st.error(f"🚨 Failed create temp dir {temp_dir}: {e}", icon="πŸ“‚")
329
  st.stop()
330
  final_video_paths, generation_errors = {}, {}
331
 
 
16
  import asyncio
17
  import uuid # For unique identifiers
18
  import shutil # For directory operations
19
+ import logging # For logging
20
 
21
  # Image handling
22
  from PIL import Image
 
27
 
28
  # Video and audio processing
29
  from moviepy.editor import ImageClip, AudioFileClip, concatenate_videoclips
 
30
 
31
  # Type hints
32
  import typing_extensions as typing
 
35
  import nest_asyncio
36
  nest_asyncio.apply()
37
 
38
+ # Import Vertex AI SDK for image generation (Preview API)
39
+ import vertexai
40
+ from vertexai.preview.vision_models import ImageGenerationModel
41
 
42
  # --- Logging Setup ---
43
  logging.basicConfig(
 
58
  TEXT_MODEL_ID = "models/gemini-1.5-flash"
59
  AUDIO_MODEL_ID = "models/gemini-1.5-flash"
60
  AUDIO_SAMPLING_RATE = 24000
61
+ # IMAGE_MODEL_ID is now used with the preview API
62
+ IMAGE_MODEL_ID = "imagen-3.0-generate-002"
63
  DEFAULT_ASPECT_RATIO = "1:1"
64
  VIDEO_FPS = 24
65
  VIDEO_CODEC = "libx264"
66
  AUDIO_CODEC = "aac"
67
  TEMP_DIR_BASE = ".chrono_temp"
68
 
69
+ # --- Secrets and Environment Variables ---
70
  GOOGLE_API_KEY = None
71
  try:
72
  GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
 
79
  st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
80
  st.stop()
81
 
82
+ # For Vertex AI, we also need PROJECT_ID and LOCATION.
 
83
  PROJECT_ID = st.secrets.get("PROJECT_ID") or os.environ.get("PROJECT_ID")
84
  LOCATION = st.secrets.get("LOCATION") or os.environ.get("LOCATION", "us-central1")
85
+ if not PROJECT_ID:
86
+ st.error("🚨 **PROJECT_ID not set!** Please add PROJECT_ID to your secrets.", icon="🚨")
 
 
 
87
  st.stop()
88
 
89
+ # Initialize Vertex AI (used for image generation)
90
+ vertexai.init(project=PROJECT_ID, location=LOCATION)
91
+
92
  # --- Initialize Google Clients for text/audio ---
93
  try:
94
  genai.configure(api_key=GOOGLE_API_KEY)
 
117
  @field_validator('image_prompt')
118
  @classmethod
119
  def image_prompt_no_humans(cls, v: str) -> str:
120
+ if any(word in v.lower() for word in ["person", "people", "human", "man", "woman", "boy", "girl", "child"]):
121
  logger.warning(f"Prompt '{v[:50]}...' may contain humans.")
122
  return v
123
 
 
263
 
264
  def generate_image_imagen(prompt: str, aspect_ratio: str = "1:1", task_id: str = "IMG") -> Optional[Image.Image]:
265
  """
266
+ Generates an image using Vertex AI's Imagen model via the Vertex AI preview API.
267
 
268
+ It calls the ImageGenerationModel from vertexai.preview.vision_models with the pretrained model "imagen-3.0-generate-002" and returns a PIL Image.
 
 
 
269
  """
270
  logger.info(f"πŸ–ΌοΈ [{task_id}] Requesting image: '{prompt[:70]}...' (Aspect: {aspect_ratio})")
271
  try:
272
+ # Load the pretrained Imagen model
273
+ generation_model = ImageGenerationModel.from_pretrained(IMAGE_MODEL_ID)
274
+ # Generate the image (here we generate one image)
275
+ images = generation_model.generate_images(
276
+ prompt=prompt,
277
+ number_of_images=1,
278
+ aspect_ratio=aspect_ratio,
279
+ negative_prompt="",
280
+ person_generation="",
281
+ safety_filter_level="",
282
+ add_watermark=True,
283
+ )
284
+ # Return the generated PIL image (using the internal _pil_image attribute)
285
+ image = images[0]._pil_image
 
 
286
  logger.info(f"βœ… [{task_id}] Image generated successfully.")
287
  return image
288
  except Exception as e:
289
  logger.exception(f"❌ [{task_id}] Image generation failed: {e}")
290
+ st.error(f"Image generation for {task_id} failed: {e}", icon="πŸ–ΌοΈ")
291
  return None
292
 
293
  # --- Streamlit UI Elements ---
 
320
  os.makedirs(temp_dir, exist_ok=True)
321
  logger.info(f"Created temp dir: {temp_dir}")
322
  except OSError as e:
323
+ st.error(f"🚨 Failed to create temp dir {temp_dir}: {e}", icon="πŸ“‚")
324
  st.stop()
325
  final_video_paths, generation_errors = {}, {}
326