File size: 3,089 Bytes
8fb7841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
from typing import List, Dict, Any, Optional
import vertexai
from vertexai.generative_models import GenerativeModel, Part
from google.auth import default
from google.auth.transport import requests


# TODO: check if this is the correct way to use Vertex AI
# TODO: add langfuse support
class VertexAIWrapper:
    """Wrapper around Vertex AI generative models (Gemini family).

    Initializes Vertex AI from environment variables, builds a
    ``GenerativeModel``, and exposes a callable interface that converts
    message dicts into ``Part`` objects and returns the generated text.
    """

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False
    ):
        """Initialize the Vertex AI wrapper.

        Args:
            model_name: Name of the model to use (e.g. "gemini-1.5-pro")
            temperature: Temperature for generation between 0 and 1
            print_cost: Whether to print the cost of the completion
            verbose: Whether to print verbose output
            use_langfuse: Whether to enable Langfuse logging (stored but not
                yet acted on — Langfuse integration is still a TODO)

        Raises:
            ValueError: If GOOGLE_CLOUD_PROJECT is not set in the environment.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        # Fix: this flag was previously accepted but silently discarded.
        self.use_langfuse = use_langfuse

        # Project is mandatory; location falls back to a common default region.
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        if not project_id:
            raise ValueError("No GOOGLE_CLOUD_PROJECT found in environment variables")

        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(model_name)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """Process messages and return completion.

        Args:
            messages: List of message dicts; each must have a "type" key
                ("text", "image", or "video") and a "content" key.
            metadata: Optional metadata dictionary (currently unused by this
                wrapper; kept for interface compatibility).

        Returns:
            Generated text response from the model.

        Raises:
            ValueError: If a message type is not supported.
        """
        parts = []

        for msg in messages:
            if msg["type"] == "text":
                parts.append(Part.from_text(msg["content"]))
            elif msg["type"] in ("image", "video"):
                mime_type = "video/mp4" if msg["type"] == "video" else "image/jpeg"
                if isinstance(msg["content"], str):
                    # String content is treated as a URI (e.g. a GCS "gs://" path).
                    parts.append(Part.from_uri(
                        msg["content"],
                        mime_type=mime_type
                    ))
                else:
                    # Non-string content is passed as raw media bytes.
                    # NOTE(review): Part.from_data expects bytes; a file *path*
                    # would be a str and land in the URI branch above.
                    parts.append(Part.from_data(
                        msg["content"],
                        mime_type=mime_type
                    ))
            else:
                # Fix: unsupported types were previously skipped silently,
                # contradicting the documented ValueError contract.
                raise ValueError(f"Unsupported message type: {msg['type']}")

        response = self.model.generate_content(
            parts,
            generation_config={
                "temperature": self.temperature,
                "top_p": 0.95,
            }
        )

        return response.text