|
import os |
|
from typing import List, Dict, Any, Optional |
|
import vertexai |
|
from vertexai.generative_models import GenerativeModel, Part |
|
from google.auth import default |
|
from google.auth.transport import requests |
|
|
|
|
|
|
|
|
|
class VertexAIWrapper:
    """Wrapper for Vertex AI to support Gemini models."""

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False
    ):
        """Initialize the Vertex AI wrapper.

        Args:
            model_name: Name of the model to use (e.g. "gemini-1.5-pro")
            temperature: Temperature for generation between 0 and 1
            print_cost: Whether to print the cost of the completion
            verbose: Whether to print verbose output
            use_langfuse: Whether to enable Langfuse logging

        Raises:
            ValueError: If GOOGLE_CLOUD_PROJECT is not set in the environment.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        # Fix: this flag was previously accepted and silently discarded;
        # store it so callers can inspect whether Langfuse logging was requested.
        self.use_langfuse = use_langfuse

        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        if not project_id:
            raise ValueError("No GOOGLE_CLOUD_PROJECT found in environment variables")

        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(model_name)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """Process messages and return completion.

        Args:
            messages: List of message dicts, each with a "type" key
                ("text", "image", or "video") and a "content" key
                (text, a URI string, or raw media bytes).
            metadata: Optional metadata dictionary. NOTE(review): currently
                unused by this implementation; accepted for interface
                compatibility.

        Returns:
            Generated text response from the model.

        Raises:
            ValueError: If a message type is not supported.
        """
        parts = []

        for msg in messages:
            if msg["type"] == "text":
                parts.append(Part.from_text(msg["content"]))
            elif msg["type"] in ("image", "video"):
                mime_type = "video/mp4" if msg["type"] == "video" else "image/jpeg"
                if isinstance(msg["content"], str):
                    # String content is treated as a URI to the media.
                    parts.append(Part.from_uri(
                        msg["content"],
                        mime_type=mime_type
                    ))
                else:
                    # Non-string content is treated as raw media bytes.
                    parts.append(Part.from_data(
                        msg["content"],
                        mime_type=mime_type
                    ))
            else:
                # Bug fix: the docstring promised a ValueError for unsupported
                # types, but unknown messages were previously dropped silently.
                raise ValueError(f"Unsupported message type: {msg['type']}")

        response = self.model.generate_content(
            parts,
            generation_config={
                "temperature": self.temperature,
                "top_p": 0.95,
            }
        )

        return response.text