# app.py
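# Gradio Space that generates short videos with Google's Veo model on Vertex AI:
# it submits a job to the predictLongRunning endpoint and polls fetchPredictOperation
# until a video (or an error) comes back.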
import os
import gradio as gr
import requests
import json
import time
import base64
import google.auth
import google.auth.transport.requests
from huggingface_hub import login
# --- 1. Configuration and Authentication ---
GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID")
GCP_LOCATION = os.environ.get("GCP_LOCATION")
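# Both values are read from the Space's environment/secrets; GCP_LOCATION is a
# Vertex AI region such as "us-central1".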
# --- Authentication and Sanity Checks Block ---
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    print("Hugging Face token found. Logging in.")
    login(token=hf_token)
else:
    print("WARNING: Hugging Face token ('HF_TOKEN') not found.")
creds_json_str = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON")
if not all([GCP_PROJECT_ID, GCP_LOCATION, creds_json_str]):
    missing_secrets = [s for s, v in {
        "GCP_PROJECT_ID": GCP_PROJECT_ID,
        "GCP_LOCATION": GCP_LOCATION,
        "GOOGLE_APPLICATION_CREDENTIALS_JSON": creds_json_str
    }.items() if not v]
    error_message = f"FATAL: Missing required secrets: {', '.join(missing_secrets)}."
    print(error_message)

    # Fallback: surface the configuration error in the UI instead of crashing at import time.
    def generate_video(prompt):
        raise gr.Error(error_message)
else:
    print("All required secrets are loaded.")

    MODEL_ID = "veo-3.0-generate-preview"
    API_ENDPOINT = f"{GCP_LOCATION}-aiplatform.googleapis.com"
    PREDICT_URL = f"https://{API_ENDPOINT}/v1/projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/publishers/google/models/{MODEL_ID}:predictLongRunning"
    FETCH_URL = f"https://{API_ENDPOINT}/v1/projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/publishers/google/models/{MODEL_ID}:fetchPredictOperation"

    # Write the service-account JSON to disk so google.auth can load it.
    with open("gcp_creds.json", "w") as f:
        f.write(creds_json_str)
    SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
    credentials, _ = google.auth.load_credentials_from_file("gcp_creds.json", scopes=SCOPES)

    def get_access_token():
        """Return a fresh OAuth access token for the loaded service account."""
        auth_req = google.auth.transport.requests.Request()
        credentials.refresh(auth_req)
        return credentials.token
    # --- 2. Core Video Generation Logic ---
    def generate_video(prompt: str):
        if not prompt:
            raise gr.Error("Prompt cannot be empty.")
        yield "Status: Submitting job...", None
        try:
            headers = {"Authorization": f"Bearer {get_access_token()}", "Content-Type": "application/json"}
            payload = {
                "instances": [{"prompt": prompt}],
                "parameters": {
                    "aspectRatio": "16:9",
                    "sampleCount": 1,
                    "durationSeconds": 8,
                    "personGeneration": "allow_all",
                    "addWatermark": True,
                    "includeRaiReason": True,
                    "generateAudio": True,
                },
            }
            response = requests.post(PREDICT_URL, headers=headers, json=payload)
            response.raise_for_status()
            operation_name = response.json()["name"]
            print(f"Successfully submitted job. Operation Name: {operation_name}")

            MAX_POLL_ATTEMPTS = 60
            for i in range(MAX_POLL_ATTEMPTS):
                yield f"Status: Polling (Attempt {i+1}/{MAX_POLL_ATTEMPTS})...", None
                # Refresh the token on every poll in case it expires during a long job.
                headers["Authorization"] = f"Bearer {get_access_token()}"
                fetch_payload = {"operationName": operation_name}
                poll_response = requests.post(FETCH_URL, headers=headers, json=fetch_payload)
                poll_response.raise_for_status()
                poll_result = poll_response.json()
                if poll_result.get("done"):
                    print("Job finished.")
                    print(f"Full response payload: {json.dumps(poll_result, indent=2)}")  # For debugging
                    response_data = poll_result.get("response", {})
                    if "videos" in response_data and response_data["videos"]:
                        video_base64 = response_data["videos"][0]["bytesBase64Encoded"]
                        video_bytes = base64.b64decode(video_base64)
                        with open("generated_video.mp4", "wb") as f:
                            f.write(video_bytes)
                        yield "Status: Done!", "generated_video.mp4"
                        return
                    else:
                        # <<< START: IMPROVED ERROR HANDLING >>>
                        error_message = "Video generation failed."
                        # Check for a specific error message in the operation response
                        if "error" in poll_result:
                            error_details = poll_result["error"].get("message", "No details provided.")
                            error_message += f"\nAPI Error: {error_details}"
                        # Check for a specific RAI reason
                        elif "raiResult" in response_data:
                            rai_reason = response_data.get("raiMediaFilteredReason", "Unknown reason.")
                            error_message += f"\nReason: Content was blocked by safety filters ({rai_reason})."
                        else:
                            error_message += "\nReason: The API did not return a video or a specific error."
                        raise gr.Error(error_message)
                        # <<< END: IMPROVED ERROR HANDLING >>>
                time.sleep(10)
            raise gr.Error("Operation timed out.")
        except Exception as e:
            print(f"An error occurred: {e}")
            raise gr.Error(str(e))
# --- 3. Gradio User Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 Vertex AI VEO Video Generator")
    gr.Markdown("Generate short videos from a text prompt using Google's VEO model.")
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", placeholder="A majestic lion...", lines=3)
            submit_button = gr.Button("Generate Video", variant="primary")
        with gr.Column(scale=1):
            status_output = gr.Markdown("Status: Ready")
            video_output = gr.Video(label="Generated Video", interactive=False)
    gr.Examples(["A high-speed drone shot flying through a futuristic city with flying vehicles."], inputs=prompt_input)
    submit_button.click(fn=generate_video, inputs=prompt_input, outputs=[status_output, video_output])

demo.launch()
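# Note: when running outside a Hugging Face Space, export GCP_PROJECT_ID, GCP_LOCATION,
# GOOGLE_APPLICATION_CREDENTIALS_JSON (the service-account key as a JSON string) and,
# optionally, HF_TOKEN before launching this script.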