File size: 6,103 Bytes
c1c563f
 
 
 
 
 
 
 
 
 
 
 
 
 
987e891
 
c1c563f
987e891
c1c563f
 
 
 
 
 
156c06c
c1c563f
 
987e891
 
156c06c
 
 
 
 
 
987e891
c1c563f
987e891
c1c563f
156c06c
987e891
 
 
 
 
156c06c
c1c563f
 
 
 
 
 
 
 
987e891
c1c563f
 
 
156c06c
c1c563f
156c06c
 
c1c563f
156c06c
c1c563f
 
156c06c
c1c563f
156c06c
 
c1c563f
 
 
 
 
156c06c
 
bc7d519
17573c5
 
bc7d519
156c06c
 
bc7d519
 
156c06c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1c563f
987e891
c1c563f
987e891
bc7d519
c1c563f
987e891
c1c563f
 
987e891
c1c563f
 
987e891
c1c563f
 
 
 
156c06c
 
c1c563f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# app.py

import os
import gradio as gr
import requests
import json
import time
import base64
import google.auth
import google.auth.transport.requests
from huggingface_hub import login

# --- 1. Configuration and Authentication ---

# GCP project/region for the Vertex AI endpoint; both come from Space secrets.
GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID")
GCP_LOCATION = os.environ.get("GCP_LOCATION")

# --- Authentication and Sanity Checks Block ---

# Hugging Face login is optional: missing token only triggers a warning,
# the app itself keeps working.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    print("Hugging Face token found. Logging in.")
    login(token=hf_token)
else:
    print("WARNING: Hugging Face token ('HF_TOKEN') not found.")

# Full service-account JSON passed in via an env var (not a file path).
creds_json_str = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON")

if not all([GCP_PROJECT_ID, GCP_LOCATION, creds_json_str]):
    # Collect the names of whichever secrets are missing for a precise message.
    missing_secrets = [s for s, v in {
        "GCP_PROJECT_ID": GCP_PROJECT_ID,
        "GCP_LOCATION": GCP_LOCATION,
        "GOOGLE_APPLICATION_CREDENTIALS_JSON": creds_json_str
    }.items() if not v]
    error_message = f"FATAL: Missing required secrets: {', '.join(missing_secrets)}."
    print(error_message)
    # Fallback stub keeps the UI below importable; clicking Generate surfaces
    # the configuration error instead of crashing at import time.
    def generate_video(prompt):
        raise gr.Error(error_message)
else:
    print("All required secrets are loaded.")
    MODEL_ID = "veo-3.0-generate-preview"
    API_ENDPOINT = f"{GCP_LOCATION}-aiplatform.googleapis.com"
    # predictLongRunning starts the async job; fetchPredictOperation polls it.
    PREDICT_URL = f"https://{API_ENDPOINT}/v1/projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/publishers/google/models/{MODEL_ID}:predictLongRunning"
    FETCH_URL = f"https://{API_ENDPOINT}/v1/projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/publishers/google/models/{MODEL_ID}:fetchPredictOperation"

    # google.auth loads credentials from a file path, so the JSON secret is
    # materialized to disk first. NOTE(review): this leaves service-account
    # keys in the working directory with default permissions — consider
    # loading from the dict directly, or tightening file mode.
    with open("gcp_creds.json", "w") as f: f.write(creds_json_str)
    SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
    credentials, _ = google.auth.load_credentials_from_file("gcp_creds.json", scopes=SCOPES)

    def get_access_token():
        """Refresh the cached credentials and return a current OAuth2 bearer token."""
        auth_req = google.auth.transport.requests.Request()
        credentials.refresh(auth_req)
        return credentials.token
    # --- 2. Core Video Generation Logic ---
    def generate_video(prompt: str):
        """Submit a VEO generation job and poll Vertex AI until the video is ready.

        Generator consumed by Gradio's streaming click handler: yields
        ``(status_markdown, video_filepath)`` tuples, with the filepath ``None``
        until the job completes successfully.

        Args:
            prompt: Text description of the video to generate.

        Raises:
            gr.Error: On empty prompt, HTTP/API failure, safety filtering,
                missing video payload, or polling timeout.
        """
        if not prompt:
            raise gr.Error("Prompt cannot be empty.")
        yield "Status: Submitting job...", None
        try:
            headers = {"Authorization": f"Bearer {get_access_token()}", "Content-Type": "application/json"}
            payload = {"instances": [{"prompt": prompt}], "parameters": {"aspectRatio": "16:9", "sampleCount": 1, "durationSeconds": 8, "personGeneration": "allow_all", "addWatermark": True, "includeRaiReason": True, "generateAudio": True}}
            # timeout= prevents an unresponsive endpoint from hanging the worker
            # forever (the original calls had no timeout at all).
            response = requests.post(PREDICT_URL, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            operation_name = response.json()["name"]
            print(f"Successfully submitted job. Operation Name: {operation_name}")
            MAX_POLL_ATTEMPTS = 60  # 60 attempts x 10 s sleep ~= 10 minutes max wait
            for i in range(MAX_POLL_ATTEMPTS):
                yield f"Status: Polling (Attempt {i+1}/{MAX_POLL_ATTEMPTS})...", None
                # Re-fetch the bearer token each poll; a long job can outlive
                # the token minted at submission time.
                headers["Authorization"] = f"Bearer {get_access_token()}"
                fetch_payload = {"operationName": operation_name}
                poll_response = requests.post(FETCH_URL, headers=headers, json=fetch_payload, timeout=60)
                poll_response.raise_for_status()
                poll_result = poll_response.json()
                if poll_result.get("done"):
                    print("Job finished.")
                    print(f"Full response payload: {json.dumps(poll_result, indent=2)}") # For debugging
                    response_data = poll_result.get("response", {})
                    if "videos" in response_data and response_data["videos"]:
                        # Video arrives inline as base64; persist to disk so the
                        # Gradio Video component can serve it.
                        video_base64 = response_data["videos"][0]["bytesBase64Encoded"]
                        video_bytes = base64.b64decode(video_base64)
                        with open("generated_video.mp4", "wb") as f:
                            f.write(video_bytes)
                        yield "Status: Done!", "generated_video.mp4"
                        return
                    # No video returned: surface the most specific failure
                    # reason available, in order of specificity.
                    error_message = "Video generation failed."
                    if "error" in poll_result:
                        error_details = poll_result["error"].get("message", "No details provided.")
                        error_message += f"\nAPI Error: {error_details}"
                    elif "raiResult" in response_data:
                        rai_reason = response_data.get("raiMediaFilteredReason", "Unknown reason.")
                        error_message += f"\nReason: Content was blocked by safety filters ({rai_reason})."
                    else:
                        error_message += "\nReason: The API did not return a video or a specific error."
                    raise gr.Error(error_message)
                time.sleep(10)
            raise gr.Error("Operation timed out.")
        except gr.Error:
            # Already a user-facing error; re-raise as-is instead of letting the
            # blanket handler below double-wrap it.
            raise
        except Exception as e:
            print(f"An error occurred: {e}")
            raise gr.Error(str(e))

# --- 3. Gradio User Interface ---
# Two-column layout: prompt + button on the left, status + video on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 Vertex AI VEO Video Generator")
    gr.Markdown("Generate short videos from a text prompt using Google's VEO model.")
    with gr.Row():
        with gr.Column(scale=1):
            # User input controls.
            prompt_box = gr.Textbox(label="Prompt", placeholder="A majestic lion...", lines=3)
            generate_btn = gr.Button("Generate Video", variant="primary")
        with gr.Column(scale=1):
            # Live status line plus the rendered result.
            status_md = gr.Markdown("Status: Ready")
            video_player = gr.Video(label="Generated Video", interactive=False)
    gr.Examples(
        ["A high-speed drone shot flying through a futuristic city with flying vehicles."],
        inputs=prompt_box,
    )
    # generate_video is a generator, so Gradio streams each yielded
    # (status, video) pair into the two output components.
    generate_btn.click(
        fn=generate_video,
        inputs=prompt_box,
        outputs=[status_md, video_player],
    )

demo.launch()