# t2m/mllm_tools/vertex_ai.py
# thanhkt's picture
# Upload 26 files
# 8fb7841 verified
import os
from typing import List, Dict, Any, Optional
import vertexai
from vertexai.generative_models import GenerativeModel, Part
from google.auth import default
from google.auth.transport import requests
# TODO: check if this is the correct way to use Vertex AI
# TODO: add langfuse support
class VertexAIWrapper:
    """Wrapper for Vertex AI to support Gemini models.

    Instances are callable: pass a list of message dictionaries and receive
    the generated text of the model's response.
    """

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False
    ):
        """Initialize the Vertex AI wrapper.

        The project is read from the ``GOOGLE_CLOUD_PROJECT`` environment
        variable and the location from ``GOOGLE_CLOUD_LOCATION`` (defaulting
        to ``"us-central1"``).

        Args:
            model_name: Name of the model to use (e.g. "gemini-1.5-pro")
            temperature: Sampling temperature placed in the generation config
            print_cost: Whether to print the cost of the completion
                (stored on the instance; not acted upon by this class)
            verbose: Whether to print verbose output
                (stored on the instance; not acted upon by this class)
            use_langfuse: Whether to enable Langfuse logging
                (stored on the instance; Langfuse support is still a TODO)

        Raises:
            ValueError: If GOOGLE_CLOUD_PROJECT is not set in the environment
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        # Keep the flag so it is not silently lost once Langfuse is wired in.
        self.use_langfuse = use_langfuse

        # Initialize Vertex AI
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        if not project_id:
            raise ValueError("No GOOGLE_CLOUD_PROJECT found in environment variables")

        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(model_name)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """Process messages and return completion.

        Args:
            messages: List of message dictionaries. Each one needs a "type"
                ("text", "image" or "video") and a "content" entry. Media
                messages may also carry an optional "mime_type" entry that
                overrides the default ("image/jpeg" / "video/mp4").
            metadata: Optional metadata dictionary. Accepted for interface
                compatibility; it is currently not forwarded to the model.

        Returns:
            Generated text response from the model

        Raises:
            ValueError: If message type is not supported
        """
        # Every message is validated and converted before the model is called,
        # so an unsupported type never results in a partially-built request.
        parts = [self._build_part(msg) for msg in messages]

        response = self.model.generate_content(
            parts,
            generation_config={
                "temperature": self.temperature,
                "top_p": 0.95,
            }
        )
        return response.text

    def _build_part(self, msg: Dict[str, Any]) -> Any:
        """Convert a single message dictionary into a Vertex AI ``Part``."""
        msg_type = msg["type"]
        content = msg["content"]

        if msg_type == "text":
            return Part.from_text(content)

        if msg_type == "video":
            default_mime = "video/mp4"
        elif msg_type == "image":
            default_mime = "image/jpeg"
        else:
            # Previously such messages were dropped without any notice, which
            # contradicted the documented ValueError contract.
            raise ValueError(f"Unsupported message type: {msg_type!r}")

        mime_type = msg.get("mime_type") or default_mime

        # Local file: a path object, or a string that names an existing file.
        # Its bytes are sent inline, because a plain local path is not a URI.
        is_local_file = isinstance(content, os.PathLike) or (
            isinstance(content, str) and os.path.isfile(content)
        )
        if is_local_file:
            with open(content, "rb") as media_file:
                return Part.from_data(media_file.read(), mime_type=mime_type)

        if isinstance(content, str):
            # Any other string is treated as a URI (e.g. a GCS URI).
            return Part.from_uri(content, mime_type=mime_type)

        # Anything else is passed through as raw data (expected to be bytes).
        return Part.from_data(content, mime_type=mime_type)