# t2m/mllm_tools/litellm.py
import base64
import io
import mimetypes
import os
import re
from typing import Any, Dict, List, Optional, Union

import litellm
from dotenv import load_dotenv
from litellm import completion, completion_cost
from PIL import Image

# Pull provider API keys (e.g. OPENAI_API_KEY) from a local .env file, if present
load_dotenv()


class LiteLLMWrapper:
    """Wrapper around LiteLLM that supports multiple model providers, cost tracking, and optional Langfuse logging."""

def __init__(
self,
model_name: str = "gpt-4-vision-preview",
temperature: float = 0.7,
print_cost: bool = False,
verbose: bool = False,
use_langfuse: bool = True,
):
"""
Initialize the LiteLLM wrapper
Args:
model_name: Name of the model to use (e.g. "azure/gpt-4", "vertex_ai/gemini-pro")
temperature: Temperature for completion
print_cost: Whether to print the cost of the completion
verbose: Whether to print verbose output
use_langfuse: Whether to enable Langfuse logging
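
        Example (illustrative):
            wrapper = LiteLLMWrapper(model_name="gpt-4o", print_cost=True)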
"""
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.accumulated_cost = 0
        self.reasoning_effort = None  # set to "medium" for OpenAI o-series models in __call__
if self.verbose:
os.environ['LITELLM_LOG'] = 'DEBUG'
        # Enable Langfuse logging callbacks only if requested (Langfuse reads
        # LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY from the environment)
        if use_langfuse:
            litellm.success_callback = ["langfuse"]
            litellm.failure_callback = ["langfuse"]

    def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
"""
Encode local file or PIL Image to base64 string
Args:
file_path: Path to local file or PIL Image object
Returns:
Base64 encoded file string
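
        Example (illustrative; assumes img is a PIL Image):
            data_url = "data:image/png;base64," + wrapper._encode_file(img)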
"""
if isinstance(file_path, Image.Image):
buffered = io.BytesIO()
file_path.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
else:
with open(file_path, "rb") as file:
return base64.b64encode(file.read()).decode("utf-8")

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file based on its extension
        Args:
            file_path: Path to the file
        Returns:
            MIME type as a string (e.g., "image/jpeg", "audio/mpeg")
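
        Example (illustrative):
            wrapper._get_mime_type("clip.mp4")   # -> "video/mp4"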
"""
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type is None:
raise ValueError(f"Unsupported file type: {file_path}")
return mime_type

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
"""
Process messages and return completion
Args:
messages: List of message dictionaries with 'type' and 'content' keys
metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking
Returns:
Generated text response
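
        Example (illustrative):
            response = wrapper([
                {"type": "text", "content": "Describe this picture."},
                {"type": "image", "content": "path/to/picture.png"},
            ])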
"""
if metadata is None:
print("No metadata provided, using empty metadata")
metadata = {}
metadata["trace_name"] = f"litellm-completion-{self.model_name}"
# Convert messages to LiteLLM format
formatted_messages = []
for msg in messages:
if msg["type"] == "text":
formatted_messages.append({
"role": "user",
"content": [{"type": "text", "text": msg["content"]}]
})
elif msg["type"] in ["image", "audio", "video"]:
# Check if content is a local file path or PIL Image
if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
try:
if isinstance(msg["content"], Image.Image):
mime_type = "image/png"
else:
mime_type = self._get_mime_type(msg["content"])
base64_data = self._encode_file(msg["content"])
data_url = f"data:{mime_type};base64,{base64_data}"
except ValueError as e:
print(f"Error processing file {msg['content']}: {e}")
continue
else:
data_url = msg["content"]
# Append the formatted message based on the model
if "gemini" in self.model_name:
formatted_messages.append({
"role": "user",
"content": [
{
"type": "image_url",
"image_url": data_url
}
]
})
elif "gpt" in self.model_name:
# GPT and other models expect a different format
if msg["type"] == "image":
# Default format for images and videos in GPT
formatted_messages.append({
"role": "user",
"content": [
{
"type": f"image_url",
f"{msg['type']}_url": {
"url": data_url,
"detail": "high"
}
}
]
})
else:
raise ValueError("For GPT, only text and image inferencing are supported")
                else:
                    raise ValueError("Only Gemini and GPT models are currently supported for multimodal inputs")
        try:
            # OpenAI o-series reasoning models (e.g. o1, o3-mini) do not accept a
            # temperature setting, so drop it and request medium reasoning effort
            if re.match(r"^o\d+.*$", self.model_name) or re.match(r"^openai/o.*$", self.model_name):
                self.temperature = None
                self.reasoning_effort = "medium"
response = completion(
model=self.model_name,
messages=formatted_messages,
temperature=self.temperature,
reasoning_effort=self.reasoning_effort,
metadata=metadata,
max_retries=99
)
else:
response = completion(
model=self.model_name,
messages=formatted_messages,
temperature=self.temperature,
metadata=metadata,
max_retries=99
)
            if self.print_cost:
                # completion_cost derives the dollar cost from the response usage
                cost = completion_cost(completion_response=response)
                self.accumulated_cost += cost
                print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
            content = response.choices[0].message.content
            if content is None:
                print(f"Got null response from model. Full response: {response}")
            return content
        except Exception as e:
            # Errors are returned as plain text so callers can surface them
            print(f"Error in model completion: {e}")
            return str(e)


if __name__ == "__main__":
pass
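    # Minimal usage sketch (commented out; assumes the relevant provider API key,
    # e.g. OPENAI_API_KEY, is available via the environment or a .env file, and
    # that "path/to/picture.png" exists -- both are illustrative):
    #
    # wrapper = LiteLLMWrapper(model_name="gpt-4o", print_cost=True, use_langfuse=False)
    # answer = wrapper([
    #     {"type": "text", "content": "Describe this picture."},
    #     {"type": "image", "content": "path/to/picture.png"},
    # ])
    # print(answer)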