File size: 10,078 Bytes
8fb7841 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 |
import os
import re
from typing import List, Dict, Any, Optional, Union
import io
import base64
from PIL import Image
import mimetypes
from litellm import completion, completion_cost
from dotenv import load_dotenv
load_dotenv()
class OpenRouterWrapper:
"""
OpenRouter wrapper using LiteLLM for various language models.
Compatible with the existing wrapper interface.
"""
def __init__(
self,
model_name: str = "openrouter/deepseek/deepseek-chat-v3-0324:free",
temperature: float = 0.7,
print_cost: bool = False,
verbose: bool = False,
use_langfuse: bool = True,
site_url: str = "",
app_name: str = "Theory2Manim"
):
"""
Initialize OpenRouter wrapper.
Args:
model_name: OpenRouter model name (with openrouter/ prefix)
temperature: Temperature for completion
print_cost: Whether to print the cost of the completion
verbose: Whether to print verbose output
use_langfuse: Whether to enable Langfuse logging
site_url: Optional site URL for tracking
app_name: Optional app name for tracking
"""
self.model_name = model_name
self.temperature = temperature
self.print_cost = print_cost
self.verbose = verbose
self.accumulated_cost = 0
# Setup OpenRouter environment variables
api_key = os.getenv("OPENROUTER_API_KEY")
if not api_key:
raise ValueError("No OPENROUTER_API_KEY found. Please set the environment variable.")
os.environ["OPENROUTER_API_KEY"] = api_key
os.environ["OPENROUTER_API_BASE"] = "https://openrouter.ai/api/v1"
if site_url or os.getenv("OR_SITE_URL"):
os.environ["OR_SITE_URL"] = site_url or os.getenv("OR_SITE_URL", "")
if app_name:
os.environ["OR_APP_NAME"] = app_name
if self.verbose:
os.environ['LITELLM_LOG'] = 'DEBUG'
# Set langfuse callback only if enabled
if use_langfuse:
import litellm
litellm.success_callback = ["langfuse"]
litellm.failure_callback = ["langfuse"]
def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
"""
Encode local file or PIL Image to base64 string
Args:
file_path: Path to local file or PIL Image object
Returns:
Base64 encoded file string
"""
if isinstance(file_path, Image.Image):
buffered = io.BytesIO()
file_path.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
else:
with open(file_path, "rb") as file:
return base64.b64encode(file.read()).decode("utf-8")
def _get_mime_type(self, file_path: str) -> str:
"""
Get the MIME type of a file based on its extension
Args:
file_path: Path to the file
Returns:
MIME type as a string (e.g., "image/jpeg", "audio/mp3")
"""
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type is None:
raise ValueError(f"Unsupported file type: {file_path}")
return mime_type
def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
"""
Process messages and return completion
Args:
messages: List of message dictionaries with 'type' and 'content' keys
metadata: Optional metadata to pass to completion
Returns:
Generated text response
"""
if metadata is None:
metadata = {}
metadata["trace_name"] = f"openrouter-completion-{self.model_name}"
# Convert messages to LiteLLM format
formatted_messages = []
for msg in messages:
if msg["type"] == "text":
formatted_messages.append({
"role": "user",
"content": [{"type": "text", "text": msg["content"]}]
})
elif msg["type"] in ["image", "audio", "video"]:
# Check if content is a local file path or PIL Image
if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
try:
if isinstance(msg["content"], Image.Image):
mime_type = "image/png"
else:
mime_type = self._get_mime_type(msg["content"])
base64_data = self._encode_file(msg["content"])
data_url = f"data:{mime_type};base64,{base64_data}"
except ValueError as e:
print(f"Error processing file {msg['content']}: {e}")
continue
else:
data_url = msg["content"]
# Format for vision models
if msg["type"] == "image":
formatted_messages.append({
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": data_url,
"detail": "high"
}
}
]
})
else:
# For audio/video, treat as text for now
formatted_messages.append({
"role": "user",
"content": [{"type": "text", "text": f"[{msg['type'].upper()}]: {msg['content']}"}]
})
try:
response = completion(
model=self.model_name,
messages=formatted_messages,
temperature=self.temperature,
metadata=metadata,
max_retries=99
)
if self.print_cost:
# Calculate and print cost
cost = completion_cost(completion_response=response)
self.accumulated_cost += cost
print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
content = response.choices[0].message.content
if content is None:
print(f"Got null response from model. Full response: {response}")
return "Error: Received null response from model"
# Check if the response contains error messages about unmapped models
if "This model isn't mapped yet" in content or "model isn't mapped" in content.lower():
error_msg = f"Error: Model {self.model_name} is not supported by LiteLLM. Please use a supported model."
print(error_msg)
return error_msg
return content
except Exception as e:
print(f"Error in OpenRouter completion: {e}")
return f"Error: {str(e)}"
class OpenRouterClient:
    """
    Legacy OpenRouter client kept for backward compatibility.

    Configures OpenRouter environment variables on construction and
    forwards completion requests to LiteLLM.
    """

    def __init__(self, api_key: str, site_url: str = "", app_name: str = "Theory2Manim"):
        """
        Initialize OpenRouter client.
        Args:
            api_key: OpenRouter API key
            site_url: Optional site URL for tracking
            app_name: Optional app name for tracking
        """
        env = os.environ
        env["OPENROUTER_API_KEY"] = api_key
        env["OPENROUTER_API_BASE"] = "https://openrouter.ai/api/v1"
        # Optional attribution headers recognized by OpenRouter.
        if site_url:
            env["OR_SITE_URL"] = site_url
        if app_name:
            env["OR_APP_NAME"] = app_name

    def complete(
        self,
        messages: List[Dict[str, str]],
        model: str = "openrouter/openai/gpt-3.5-turbo",
        transforms: Optional[List[str]] = None,
        route: Optional[str] = None,
        **kwargs
    ) -> Any:
        """
        Generate completion using OpenRouter model.
        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model name (with openrouter/ prefix)
            transforms: Optional transforms to apply
            route: Optional route specification
            **kwargs: Additional parameters for completion
        Returns:
            Completion response
        """
        # Base arguments first, then caller-supplied extras (which may
        # intentionally override model/messages, matching legacy behavior).
        call_args = {"model": model, "messages": messages}
        call_args.update(kwargs)
        if transforms:
            call_args["transforms"] = transforms
        if route:
            call_args["route"] = route
        return completion(**call_args)
# Convenience functions for common models
def ds_r1(messages: List[Dict[str, str]], **kwargs) -> Any:
    """Complete *messages* with DeepSeek R1 (free tier) via OpenRouter."""
    client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
    # "openrouter/" prefix is required for LiteLLM to route the request to
    # OpenRouter, consistent with the default model names in this module.
    return client.complete(messages, "openrouter/deepseek/deepseek-r1:free", **kwargs)
def ds_v3(messages: List[Dict[str, str]], **kwargs) -> Any:
    """Complete *messages* with DeepSeek Chat V3 (free tier) via OpenRouter."""
    client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
    # "openrouter/" prefix is required for LiteLLM to route the request to
    # OpenRouter, consistent with the default model names in this module.
    return client.complete(messages, "openrouter/deepseek/deepseek-chat-v3-0324:free", **kwargs)
def qwen3(messages: List[Dict[str, str]], **kwargs) -> Any:
    """Complete *messages* with Qwen3 235B (free tier) via OpenRouter."""
    client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
    # "openrouter/" prefix is required for LiteLLM to route the request to
    # OpenRouter, consistent with the default model names in this module.
    return client.complete(messages, "openrouter/qwen/qwen3-235b-a22b:free", **kwargs)
|