import copy
import sys
import time
from datetime import datetime
from unittest import mock

from dotenv import load_dotenv

from litellm.types.utils import StandardCallbackDynamicParams

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import pytest

import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, headers
from litellm.litellm_core_utils.duration_parser import (
    _extract_from_regex,
    duration_in_seconds,
    get_last_day_of_month,
)
from litellm.utils import (
    check_valid_key,
    create_pretrained_tokenizer,
    create_tokenizer,
    function_to_dict,
    get_llm_provider,
    get_max_tokens,
    get_supported_openai_params,
    get_token_count,
    get_valid_models,
    trim_messages,
    validate_environment,
)
from unittest.mock import AsyncMock, MagicMock, patch

# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count
# functions are all in a module named 'message_utils'


def reset_mock_cache():
    from litellm.utils import _model_cache

    _model_cache.flush_cache()
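

# A quick sketch of the trim_messages API the tests below exercise (inferred
# from these tests, not an exhaustive spec): by default it returns the trimmed
# message list; with return_response_tokens=True it returns a tuple whose first
# element is the trimmed list (see test_trimming_with_tool_calls).
#
#   trimmed = trim_messages(messages, model="gpt-4", max_tokens=8)
#   trimmed, response_tokens = trim_messages(
#       messages, model="gpt-4", max_tokens=8, return_response_tokens=True
#   )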


# Test 1: Check trimming of normal message
def test_basic_trimming():
    litellm._turn_on_debug()
    messages = [
        {
            "role": "user",
            "content": "This is a long message that definitely exceeds the token limit.",
        }
    ]
    trimmed_messages = trim_messages(messages, model="claude-2", max_tokens=8)
    print("trimmed messages")
    print(trimmed_messages)
    # print(get_token_count(messages=trimmed_messages, model="claude-2"))
    assert (get_token_count(messages=trimmed_messages, model="claude-2")) <= 8


# test_basic_trimming()


def test_basic_trimming_no_max_tokens_specified():
    messages = [
        {
            "role": "user",
            "content": "This is a long message that is definitely under the token limit.",
        }
    ]
    trimmed_messages = trim_messages(messages, model="gpt-4")
    print("trimmed messages for gpt-4")
    print(trimmed_messages)
    # print(get_token_count(messages=trimmed_messages, model="claude-2"))
    assert (
        get_token_count(messages=trimmed_messages, model="gpt-4")
    ) <= litellm.model_cost["gpt-4"]["max_tokens"]


# test_basic_trimming_no_max_tokens_specified()


def test_multiple_messages_trimming():
    messages = [
        {
            "role": "user",
            "content": "This is a long message that will exceed the token limit.",
        },
        {
            "role": "user",
            "content": "This is another long message that will also exceed the limit.",
        },
    ]
    trimmed_messages = trim_messages(
        messages=messages, model="gpt-3.5-turbo", max_tokens=20
    )
    # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
    assert (get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20


# test_multiple_messages_trimming()


def test_multiple_messages_no_trimming():
    messages = [
        {
            "role": "user",
            "content": "This is a long message that will exceed the token limit.",
        },
        {
            "role": "user",
            "content": "This is another long message that will also exceed the limit.",
        },
    ]
    trimmed_messages = trim_messages(
        messages=messages, model="gpt-3.5-turbo", max_tokens=100
    )
    print("Trimmed messages")
    print(trimmed_messages)
    assert messages == trimmed_messages


# test_multiple_messages_no_trimming()


def test_large_trimming_multiple_messages():
    messages = [
        {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},
        {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},
        {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},
        {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},
        {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},
    ]
    trimmed_messages = trim_messages(messages, max_tokens=20, model="gpt-4-0613")
    print("trimmed messages")
    print(trimmed_messages)
    assert (get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 20


# test_large_trimming_multiple_messages()


def test_large_trimming_single_message():
    messages = [
        {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}
    ]
    trimmed_messages = trim_messages(messages, max_tokens=5, model="gpt-4-0613")
    assert (get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 5
    assert (get_token_count(messages=trimmed_messages, model="gpt-4-0613")) > 0


def test_trimming_with_system_message_within_max_tokens():
    # This message is 33 tokens long
    messages = [
        {"role": "system", "content": "This is a short system message"},
        {
            "role": "user",
            "content": "This is a medium normal message, let's say litellm is awesome.",
        },
    ]
    trimmed_messages = trim_messages(
        messages, max_tokens=30, model="gpt-4-0613"
    )  # The system message should fit within the token limit
    assert len(trimmed_messages) == 2
    assert trimmed_messages[0]["content"] == "This is a short system message"


def test_trimming_with_system_message_exceeding_max_tokens():
    # This message is 33 tokens long. The system message is 13 tokens long.
    messages = [
        {"role": "system", "content": "This is a short system message"},
        {
            "role": "user",
            "content": "This is a medium normal message, let's say litellm is awesome.",
        },
    ]
    trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
    assert len(trimmed_messages) == 1
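

# Taken together, the two tests above pin down the (assumed) trimming order:
# user content is shortened or dropped first, and the system message is kept
# whenever it fits within max_tokens; only when the budget is too small for
# both does the result collapse to a single message.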


def test_trimming_with_tool_calls():
    from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message

    messages = [
        {
            "role": "user",
            "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
        },
        Message(
            content=None,
            role="assistant",
            tool_calls=[
                ChatCompletionMessageToolCall(
                    function=Function(
                        arguments='{"location": "San Francisco, CA", "unit": "celsius"}',
                        name="get_current_weather",
                    ),
                    id="call_G11shFcS024xEKjiAOSt6Tc9",
                    type="function",
                ),
                ChatCompletionMessageToolCall(
                    function=Function(
                        arguments='{"location": "Tokyo, Japan", "unit": "celsius"}',
                        name="get_current_weather",
                    ),
                    id="call_e0ss43Bg7H8Z9KGdMGWyZ9Mj",
                    type="function",
                ),
                ChatCompletionMessageToolCall(
                    function=Function(
                        arguments='{"location": "Paris, France", "unit": "celsius"}',
                        name="get_current_weather",
                    ),
                    id="call_nRjLXkWTJU2a4l9PZAf5as6g",
                    type="function",
                ),
            ],
            function_call=None,
        ),
        {
            "tool_call_id": "call_G11shFcS024xEKjiAOSt6Tc9",
            "role": "tool",
            "name": "get_current_weather",
            "content": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}',
        },
        {
            "tool_call_id": "call_e0ss43Bg7H8Z9KGdMGWyZ9Mj",
            "role": "tool",
            "name": "get_current_weather",
            "content": '{"location": "Tokyo", "temperature": "10", "unit": "celsius"}',
        },
        {
            "tool_call_id": "call_nRjLXkWTJU2a4l9PZAf5as6g",
            "role": "tool",
            "name": "get_current_weather",
            "content": '{"location": "Paris", "temperature": "22", "unit": "celsius"}',
        },
    ]
    result = trim_messages(messages=messages, max_tokens=1, return_response_tokens=True)

    print(result)

    assert len(result[0]) == 3  # final 3 messages are tool calls


def test_trimming_should_not_change_original_messages():
    messages = [
        {"role": "system", "content": "This is a short system message"},
        {
            "role": "user",
            "content": "This is a medium normal message, let's say litellm is awesome.",
        },
    ]
    messages_copy = copy.deepcopy(messages)
    trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
    assert messages == messages_copy


# assumed parametrization; any models with a "max_input_tokens" entry in
# litellm.model_cost work here
@pytest.mark.parametrize("model", ["gpt-4-1106-preview", "claude-3-sonnet-20240229"])
def test_trimming_with_model_cost_max_input_tokens(model):
    messages = [
        {"role": "system", "content": "This is a normal system message"},
        {
            "role": "user",
            "content": "This is a sentence" * 100000,
        },
    ]
    trimmed_messages = trim_messages(messages, model=model)
    assert (
        get_token_count(trimmed_messages, model=model)
        < litellm.model_cost[model]["max_input_tokens"]
    )


def test_aget_valid_models():
    old_environ = os.environ
    os.environ = {"OPENAI_API_KEY": "temp"}  # mock: set only the openai key in environ

    valid_models = get_valid_models()
    print(valid_models)

    # list of openai supported llms on litellm
    expected_models = (
        litellm.open_ai_chat_completion_models + litellm.open_ai_text_completion_models
    )

    assert valid_models == expected_models

    # restore the original environment
    os.environ = old_environ

    # GEMINI
    expected_models = litellm.gemini_models
    old_environ = os.environ
    os.environ = {"GEMINI_API_KEY": "temp"}  # mock: set only the gemini key in environ

    valid_models = get_valid_models()
    print(valid_models)

    assert valid_models == expected_models

    # restore the original environment
    os.environ = old_environ


# assumed parametrization; any provider that registers a ProviderModelInfo
# config (so get_provider_model_info is not None) fits here
@pytest.mark.parametrize("custom_llm_provider", ["fireworks_ai"])
def test_get_valid_models_with_custom_llm_provider(custom_llm_provider):
    from litellm.utils import ProviderConfigManager
    from litellm.types.utils import LlmProviders

    provider_config = ProviderConfigManager.get_provider_model_info(
        model=None,
        provider=LlmProviders(custom_llm_provider),
    )
    assert provider_config is not None

    valid_models = get_valid_models(
        check_provider_endpoint=True, custom_llm_provider=custom_llm_provider
    )
    print(valid_models)
    assert len(valid_models) > 0
    assert provider_config.get_models() == valid_models


# test_get_valid_models()


def test_bad_key():
    key = "bad-key"
    response = check_valid_key(model="gpt-3.5-turbo", api_key=key)
    print(response, key)
    assert response is False


def test_good_key():
    key = os.environ["OPENAI_API_KEY"]
    response = check_valid_key(model="gpt-3.5-turbo", api_key=key)
    assert response is True


# test validate environment


def test_validate_environment_empty_model():
    api_key = validate_environment()
    if api_key is None:
        raise Exception()


def test_validate_environment_api_key():
    response_obj = validate_environment(model="gpt-3.5-turbo", api_key="sk-my-test-key")
    assert (
        response_obj["keys_in_environment"] is True
    ), f"Missing keys={response_obj['missing_keys']}"


def test_validate_environment_api_base_dynamic():
    for provider in ["ollama", "ollama_chat"]:
        kv = validate_environment(provider + "/mistral", api_base="https://example.com")
        assert kv["keys_in_environment"]
        assert kv["missing_keys"] == []


# assumed env setup: OLLAMA_API_BASE must be present for this test to pass
@mock.patch.dict(os.environ, {"OLLAMA_API_BASE": "http://localhost:11434"}, clear=True)
def test_validate_environment_ollama():
    for provider in ["ollama", "ollama_chat"]:
        kv = validate_environment(provider + "/mistral")
        assert kv["keys_in_environment"]
        assert kv["missing_keys"] == []


# assumed env setup: with no OLLAMA_API_BASE set, the key should be reported missing
@mock.patch.dict(os.environ, {}, clear=True)
def test_validate_environment_ollama_failed():
    for provider in ["ollama", "ollama_chat"]:
        kv = validate_environment(provider + "/mistral")
        assert not kv["keys_in_environment"]
        assert kv["missing_keys"] == ["OLLAMA_API_BASE"]


def test_function_to_dict():
    print("testing function to dict for get current weather")

    def get_current_weather(location: str, unit: str):
        """Get the current weather in a given location

        Parameters
        ----------
        location : str
            The city and state, e.g. San Francisco, CA
        unit : {'celsius', 'fahrenheit'}
            Temperature unit

        Returns
        -------
        str
            a sentence indicating the weather
        """
        if location == "Boston, MA":
            return "The weather is 12F"

    function_json = litellm.utils.function_to_dict(get_current_weather)
    print(function_json)

    expected_output = {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "description": "Temperature unit",
                    "enum": "['fahrenheit', 'celsius']",
                },
            },
            "required": ["location", "unit"],
        },
    }
    print(expected_output)

    assert function_json["name"] == expected_output["name"]
    assert function_json["description"] == expected_output["description"]
    assert function_json["parameters"]["type"] == expected_output["parameters"]["type"]
    assert (
        function_json["parameters"]["properties"]["location"]
        == expected_output["parameters"]["properties"]["location"]
    )

    # the enum ordering can vary between runs, which is why we don't assert on unit:
    # {'type': 'string', 'description': 'Temperature unit', 'enum': "['fahrenheit', 'celsius']"}
    # {'type': 'string', 'description': 'Temperature unit', 'enum': "['celsius', 'fahrenheit']"}
    assert (
        function_json["parameters"]["required"]
        == expected_output["parameters"]["required"]
    )
    print("passed")


# test_function_to_dict()


# assumed parametrization; any (model, expected_bool) pairs from litellm.model_cost work
@pytest.mark.parametrize(
    "model, expected_bool",
    [("gpt-4o-mini", True), ("gpt-3.5-turbo-instruct", False)],
)
def test_supports_function_calling(model, expected_bool):
    try:
        assert litellm.supports_function_calling(model=model) == expected_bool
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# assumed parametrization
@pytest.mark.parametrize(
    "model, expected_bool",
    [("openai/gpt-4o-search-preview", True), ("gpt-4o-mini", False)],
)
def test_supports_web_search(model, expected_bool):
    try:
        assert litellm.supports_web_search(model=model) == expected_bool
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# assumed parametrization
@pytest.mark.parametrize(
    "model, expected_bool",
    [("openai/o3-mini", True), ("gpt-4o-mini", False)],
)
def test_supports_reasoning(model, expected_bool):
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")
    try:
        assert litellm.supports_reasoning(model=model) == expected_bool
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_get_max_token_unit_test():
    """
    More complete testing in `test_completion_cost.py`
    """
    model = "bedrock/anthropic.claude-3-haiku-20240307-v1:0"

    max_tokens = get_max_tokens(
        model
    )  # Returns a number instead of throwing an Exception

    assert isinstance(max_tokens, int)


def test_get_supported_openai_params() -> None:
    # Mapped provider
    assert isinstance(get_supported_openai_params("gpt-4"), list)

    # Unmapped provider
    assert get_supported_openai_params("nonexistent") is None


def test_get_chat_completion_prompt():
    """
    Unit test to ensure get_chat_completion_prompt updates messages in logging object.
    """
    from litellm.litellm_core_utils.litellm_logging import Logging

    litellm_logging_obj = Logging(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id="1234",
        start_time=datetime.now(),
        function_id="1234",
    )

    updated_message = "hello world"
    litellm_logging_obj.get_chat_completion_prompt(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": updated_message}],
        non_default_params={},
        prompt_id="1234",
        prompt_variables=None,
    )

    assert litellm_logging_obj.messages == [
        {"role": "user", "content": updated_message}
    ]


def test_redact_msgs_from_logs():
    """
    Tests that turn_off_message_logging does not modify the response_obj.

    On the proxy some users were seeing the redaction impact client side responses.
    """
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.litellm_core_utils.redact_messages import (
        redact_message_input_output_from_logging,
    )

    litellm.turn_off_message_logging = True

    response_obj = litellm.ModelResponse(
        choices=[
            {
                "finish_reason": "stop",
                "index": 0,
                "message": {
                    "content": "I'm LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational manner.",
                    "role": "assistant",
                },
            }
        ]
    )

    litellm_logging_obj = Logging(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id="1234",
        start_time=datetime.now(),
        function_id="1234",
    )

    _redacted_response_obj = redact_message_input_output_from_logging(
        result=response_obj,
        model_call_details=litellm_logging_obj.model_call_details,
    )

    # Assert the response_obj content is NOT modified
    assert (
        response_obj.choices[0].message.content
        == "I'm LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational manner."
    )

    litellm.turn_off_message_logging = False
    print("Test passed")


def test_redact_msgs_from_logs_with_dynamic_params():
    """
    Tests redaction behavior based on the standard_callback_dynamic_params setting.
    In all cases below, litellm.turn_off_message_logging is True.

    1. standard_callback_dynamic_params.turn_off_message_logging is False: no redaction occurs. The user has opted out of redaction for this request.
    2. standard_callback_dynamic_params.turn_off_message_logging is True: redaction occurs. The user has opted in to redaction.
    3. standard_callback_dynamic_params.turn_off_message_logging is not set and litellm.turn_off_message_logging is True: redaction occurs.
    """
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.litellm_core_utils.redact_messages import (
        redact_message_input_output_from_logging,
    )

    litellm.turn_off_message_logging = True
    test_content = "I'm LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational manner."
    response_obj = litellm.ModelResponse(
        choices=[
            {
                "finish_reason": "stop",
                "index": 0,
                "message": {
                    "content": test_content,
                    "role": "assistant",
                },
            }
        ]
    )

    litellm_logging_obj = Logging(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id="1234",
        start_time=datetime.now(),
        function_id="1234",
    )

    # Test Case 1: turn_off_message_logging=False -> user opted out of redaction
    standard_callback_dynamic_params = StandardCallbackDynamicParams(
        turn_off_message_logging=False
    )
    litellm_logging_obj.model_call_details["standard_callback_dynamic_params"] = (
        standard_callback_dynamic_params
    )
    _redacted_response_obj = redact_message_input_output_from_logging(
        result=response_obj,
        model_call_details=litellm_logging_obj.model_call_details,
    )
    # Assert no redaction occurred
    assert _redacted_response_obj.choices[0].message.content == test_content

    # Test Case 2: turn_off_message_logging=True -> user opted in to redaction
    standard_callback_dynamic_params = StandardCallbackDynamicParams(
        turn_off_message_logging=True
    )
    litellm_logging_obj.model_call_details["standard_callback_dynamic_params"] = (
        standard_callback_dynamic_params
    )
    _redacted_response_obj = redact_message_input_output_from_logging(
        result=response_obj,
        model_call_details=litellm_logging_obj.model_call_details,
    )
    # Assert redaction occurred
    assert _redacted_response_obj.choices[0].message.content == "redacted-by-litellm"

    # Test Case 3: standard_callback_dynamic_params does not set turn_off_message_logging;
    # since litellm.turn_off_message_logging is True, redaction should occur
    standard_callback_dynamic_params = StandardCallbackDynamicParams()
    litellm_logging_obj.model_call_details["standard_callback_dynamic_params"] = (
        standard_callback_dynamic_params
    )
    _redacted_response_obj = redact_message_input_output_from_logging(
        result=response_obj,
        model_call_details=litellm_logging_obj.model_call_details,
    )
    # Assert redaction occurred
    assert _redacted_response_obj.choices[0].message.content == "redacted-by-litellm"

    # Reset settings
    litellm.turn_off_message_logging = False
    print("Test passed")


# assumed parametrization: each duration encodes the value 7 with a different unit suffix
@pytest.mark.parametrize(
    "duration, unit",
    [("7s", "s"), ("7m", "m"), ("7h", "h"), ("7d", "d"), ("7mo", "mo")],
)
def test_extract_from_regex(duration, unit):
    value, _unit = _extract_from_regex(duration=duration)

    assert value == 7
    assert _unit == unit


def test_duration_in_seconds():
    """
    Test if duration int is correctly calculated for different str
    """
    import time

    now = time.time()
    current_time = datetime.fromtimestamp(now)

    if current_time.month == 12:
        target_year = current_time.year + 1
        target_month = 1
    else:
        target_year = current_time.year
        target_month = current_time.month + 1

    # Determine the day to set for next month
    target_day = current_time.day
    last_day_of_target_month = get_last_day_of_month(target_year, target_month)
    if target_day > last_day_of_target_month:
        target_day = last_day_of_target_month

    next_month = datetime(
        year=target_year,
        month=target_month,
        day=target_day,
        hour=current_time.hour,
        minute=current_time.minute,
        second=current_time.second,
        microsecond=current_time.microsecond,
    )

    # Calculate the duration until the same day of the next month
    duration_until_next_month = next_month - current_time
    expected_duration = int(duration_until_next_month.total_seconds())

    value = duration_in_seconds(duration="1mo")

    assert value - expected_duration < 2


def test_duration_in_seconds_basic():
    assert duration_in_seconds(duration="3s") == 3
    assert duration_in_seconds(duration="3m") == 180
    assert duration_in_seconds(duration="3h") == 10800
    assert duration_in_seconds(duration="3d") == 259200
    assert duration_in_seconds(duration="3w") == 1814400


def test_get_llm_provider_ft_models():
    """
    All ft: prefixed models should map to OpenAI:
    gpt-3.5-turbo-0125 (recommended),
    gpt-3.5-turbo-1106,
    gpt-3.5-turbo,
    gpt-4-0613 (experimental),
    gpt-4o-2024-05-13,
    babbage-002, davinci-002
    """
    model, custom_llm_provider, _, _ = get_llm_provider(model="ft:gpt-3.5-turbo-0125")
    assert custom_llm_provider == "openai"

    model, custom_llm_provider, _, _ = get_llm_provider(model="ft:gpt-3.5-turbo-1106")
    assert custom_llm_provider == "openai"

    model, custom_llm_provider, _, _ = get_llm_provider(model="ft:gpt-3.5-turbo")
    assert custom_llm_provider == "openai"

    model, custom_llm_provider, _, _ = get_llm_provider(model="ft:gpt-4-0613")
    assert custom_llm_provider == "openai"

    model, custom_llm_provider, _, _ = get_llm_provider(model="ft:gpt-4o-2024-05-13")
    assert custom_llm_provider == "openai"


# assumed parametrization covering the three branches asserted below
@pytest.mark.parametrize(
    "langfuse_trace_id, langfuse_existing_trace_id",
    [
        (None, None),
        ("my-unique-trace-id", None),
        (None, "my-existing-trace-id"),
    ],
)
def test_logging_trace_id(langfuse_trace_id, langfuse_existing_trace_id):
    """
    - Unit test for `_get_trace_id` function in Logging obj
    """
    from litellm.litellm_core_utils.litellm_logging import Logging

    litellm.success_callback = ["langfuse"]
    litellm_call_id = "my-unique-call-id"
    litellm_logging_obj = Logging(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id=litellm_call_id,
        start_time=datetime.now(),
        function_id="1234",
    )

    metadata = {}
    if langfuse_trace_id is not None:
        metadata["trace_id"] = langfuse_trace_id
    if langfuse_existing_trace_id is not None:
        metadata["existing_trace_id"] = langfuse_existing_trace_id

    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey how's it going?"}],
        mock_response="Hey!",
        litellm_logging_obj=litellm_logging_obj,
        metadata=metadata,
    )

    time.sleep(3)
    assert litellm_logging_obj._get_trace_id(service_name="langfuse") is not None

    ## if existing_trace_id exists
    if langfuse_existing_trace_id is not None:
        assert (
            litellm_logging_obj._get_trace_id(service_name="langfuse")
            == langfuse_existing_trace_id
        )
    ## if trace_id exists
    elif langfuse_trace_id is not None:
        assert (
            litellm_logging_obj._get_trace_id(service_name="langfuse")
            == langfuse_trace_id
        )
    ## otherwise fall back to the litellm_call_id
    else:
        assert (
            litellm_logging_obj._get_trace_id(service_name="langfuse")
            == litellm_call_id
        )


def test_convert_model_response_object():
    """
    Unit test to ensure model response object correctly handles openrouter errors.
    """
    args = {
        "response_object": {
            "id": None,
            "choices": None,
            "created": None,
            "model": None,
            "object": None,
            "service_tier": None,
            "system_fingerprint": None,
            "usage": None,
            "error": {
                "message": '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}',
                "code": 400,
            },
        },
        "model_response_object": litellm.ModelResponse(
            id="chatcmpl-b88ce43a-7bfc-437c-b8cc-e90d59372cfb",
            choices=[
                litellm.Choices(
                    finish_reason="stop",
                    index=0,
                    message=litellm.Message(content="default", role="assistant"),
                )
            ],
            created=1719376241,
            model="openrouter/anthropic/claude-3.5-sonnet",
            object="chat.completion",
            system_fingerprint=None,
            usage=litellm.Usage(),
        ),
        "response_type": "completion",
        "stream": False,
        "start_time": None,
        "end_time": None,
        "hidden_params": None,
    }

    try:
        litellm.convert_to_model_response_object(**args)
        pytest.fail("Expected this to fail")
    except Exception as e:
        assert hasattr(e, "status_code")
        assert e.status_code == 400
        assert hasattr(e, "message")
        assert (
            e.message
            == '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
        )


# assumed parametrization: <think> ... </think> tags are what the parser splits on;
# content without them passes through with no reasoning extracted
@pytest.mark.parametrize(
    "content, expected_reasoning, expected_content",
    [
        ("hi", None, "hi"),
        (
            "<think>I am thinking here</think>The sky is blue",
            "I am thinking here",
            "The sky is blue",
        ),
    ],
)
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
    assert litellm.utils._parse_content_for_reasoning(content) == (
        expected_reasoning,
        expected_content,
    )


# assumed parametrization per the docstring: gemini-1.5-pro on Google AI Studio /
# Vertex AI supports response schema; a generic OpenAI chat model does not
@pytest.mark.parametrize(
    "model, expected_bool",
    [
        ("gemini/gemini-1.5-pro", True),
        ("vertex_ai/gemini-1.5-pro", True),
        ("gpt-3.5-turbo", False),
    ],
)
def test_supports_response_schema(model, expected_bool):
    """
    Unit tests for the 'supports_response_schema' helper function.

    Should be true for gemini-1.5-pro on google ai studio / vertex ai AND predibase models.
    Should be false otherwise.
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    from litellm.utils import supports_response_schema

    response = supports_response_schema(model=model, custom_llm_provider=None)

    assert expected_bool == response


# assumed parametrization
@pytest.mark.parametrize(
    "model, expected_bool",
    [("gpt-4o", True), ("gpt-3.5-turbo-instruct", False)],
)
def test_supports_function_calling_v2(model, expected_bool):
    """
    Unit test for the 'supports_function_calling' helper function.
    """
    from litellm.utils import supports_function_calling

    response = supports_function_calling(model=model, custom_llm_provider=None)
    assert expected_bool == response


# assumed parametrization
@pytest.mark.parametrize(
    "model, expected_bool",
    [("gpt-4o", True), ("gpt-3.5-turbo", False)],
)
def test_supports_vision(model, expected_bool):
    """
    Unit test for the 'supports_vision' helper function.
    """
    from litellm.utils import supports_vision

    response = supports_vision(model=model, custom_llm_provider=None)
    assert expected_bool == response


def test_usage_object_null_tokens():
    """
    Unit test.

    Asserts the Usage object always returns an int.
    Fixes https://github.com/BerriAI/litellm/issues/5096
    """
    usage_obj = litellm.Usage(prompt_tokens=2, completion_tokens=None, total_tokens=2)

    assert usage_obj.completion_tokens == 0


def test_is_base64_encoded():
    import base64

    import requests

    litellm.set_verbose = True

    url = "https://dummyimage.com/100/100/fff&text=Test+image"
    response = requests.get(url)
    file_data = response.content

    encoded_file = base64.b64encode(file_data).decode("utf-8")
    base64_image = f"data:image/png;base64,{encoded_file}"

    from litellm.utils import is_base64_encoded

    assert is_base64_encoded(s=base64_image) is True
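

# The data URL shape being validated: "data:<mime-type>;base64,<payload>".
# is_base64_encoded accepts that full form, while a bare word like "Dog" is
# rejected (see test_is_base64_encoded_2 further down).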


# assumed setup: httpx.AsyncClient is mocked so we can assert on its kwargs, and the
# SSL env vars are pinned to the cert/verify paths asserted below
@mock.patch.dict(
    os.environ,
    {"SSL_VERIFY": "/certificate.pem", "SSL_CERTIFICATE": "/client.pem"},
    clear=True,
)
@mock.patch("httpx.AsyncClient")
def test_async_http_handler(mock_async_client):
    import httpx

    timeout = 120
    event_hooks = {"request": [lambda r: r]}
    concurrent_limit = 2

    # Mock the transport creation to return a specific transport
    with mock.patch.object(
        AsyncHTTPHandler, "_create_async_transport"
    ) as mock_create_transport:
        mock_transport = mock.MagicMock()
        mock_create_transport.return_value = mock_transport

        AsyncHTTPHandler(timeout, event_hooks, concurrent_limit)
        mock_async_client.assert_called_with(
            cert="/client.pem",
            transport=mock_transport,
            event_hooks=event_hooks,
            headers=headers,
            limits=httpx.Limits(
                max_connections=concurrent_limit,
                max_keepalive_connections=concurrent_limit,
            ),
            timeout=timeout,
            verify="/certificate.pem",
        )


# assumed setup: httpx.AsyncClient is mocked so we can inspect the kwargs it receives
@mock.patch("httpx.AsyncClient")
def test_async_http_handler_force_ipv4(mock_async_client):
    """
    Test AsyncHTTPHandler when litellm.force_ipv4 is True.

    This is a prod test - we need to ensure that httpx always uses ipv4 when
    litellm.force_ipv4 is True.
    """
    import httpx
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    # Set force_ipv4 to True
    litellm.force_ipv4 = True
    litellm.disable_aiohttp_transport = True

    try:
        timeout = 120
        event_hooks = {"request": [lambda r: r]}
        concurrent_limit = 2

        AsyncHTTPHandler(timeout, event_hooks, concurrent_limit)

        # Get the call arguments
        call_args = mock_async_client.call_args[1]

        ############# IMPORTANT ASSERTION #################
        # Assert transport exists and is configured correctly for using ipv4
        assert isinstance(call_args["transport"], httpx.AsyncHTTPTransport)
        print(call_args["transport"])
        assert call_args["transport"]._pool._local_address == "0.0.0.0"
        ####################################

        # Assert other parameters match
        assert call_args["event_hooks"] == event_hooks
        assert call_args["headers"] == headers
        assert isinstance(call_args["limits"], httpx.Limits)
        assert call_args["limits"].max_connections == concurrent_limit
        assert call_args["limits"].max_keepalive_connections == concurrent_limit
        assert call_args["timeout"] == timeout
        assert call_args["verify"] is True
        assert call_args["cert"] is None
    finally:
        # Reset force_ipv4 to default
        litellm.force_ipv4 = False


# assumed parametrization: the audio-preview model accepts audio input, the base mini model does not
@pytest.mark.parametrize(
    "model, expected_bool",
    [("gpt-4o-audio-preview", True), ("gpt-4o-mini", False)],
)
def test_supports_audio_input(model, expected_bool):
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")
    from litellm.utils import supports_audio_input, supports_audio_output

    supports_audio = supports_audio_input(model=model)
    assert supports_audio == expected_bool


def test_is_base64_encoded_2():
    from litellm.utils import is_base64_encoded

    assert (
        is_base64_encoded(
            s="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/x+AAwMCAO+ip1sAAAAASUVORK5CYII="
        )
        is True
    )

    assert is_base64_encoded(s="Dog") is False


# assumed parametrization: plain-string and typed-text content are valid; an
# unknown content-block type is not
@pytest.mark.parametrize(
    "messages, expected_bool",
    [
        ([{"role": "user", "content": "hi"}], True),
        ([{"role": "user", "content": [{"type": "text", "text": "hi"}]}], True),
        ([{"role": "user", "content": [{"type": "invalid_type", "text": "hi"}]}], False),
    ],
)
def test_validate_chat_completion_user_messages(messages, expected_bool):
    from litellm.utils import validate_chat_completion_user_messages

    if expected_bool:
        ## Valid message
        validate_chat_completion_user_messages(messages=messages)
    else:
        ## Invalid message
        with pytest.raises(Exception):
            validate_chat_completion_user_messages(messages=messages)


# assumed parametrization: OpenAI-style string and dict forms are valid; a
# malformed dict is not
@pytest.mark.parametrize(
    "tool_choice, expected_bool",
    [
        ("auto", True),
        ("none", True),
        ({"type": "function", "function": {"name": "get_weather"}}, True),
        ({"invalid": "value"}, False),
    ],
)
def test_validate_chat_completion_tool_choice(tool_choice, expected_bool):
    from litellm.utils import validate_chat_completion_tool_choice

    if expected_bool:
        validate_chat_completion_tool_choice(tool_choice=tool_choice)
    else:
        with pytest.raises(Exception):
            validate_chat_completion_tool_choice(tool_choice=tool_choice)


def test_models_by_provider():
    """
    Make sure all providers from the model map are in the valid providers list
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")
    from litellm import models_by_provider

    providers = set()
    for k, v in litellm.model_cost.items():
        if "_" in v["litellm_provider"] and "-" in v["litellm_provider"]:
            continue
        elif k == "sample_spec":
            continue
        elif (
            v["litellm_provider"] == "sagemaker"
            or v["litellm_provider"] == "bedrock_converse"
        ):
            continue
        else:
            providers.add(v["litellm_provider"])

    for provider in providers:
        assert provider in models_by_provider.keys()


# assumed parametrization: the end-user id is read from the proxy request body's
# "user" field, and the disable flag suppresses it
@pytest.mark.parametrize(
    "litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
    [
        ({}, False, None),
        ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
        ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
    ],
)
def test_get_end_user_id_for_cost_tracking(
    litellm_params, disable_end_user_cost_tracking, expected_end_user_id
):
    from litellm.utils import get_end_user_id_for_cost_tracking

    litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking
    assert (
        get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
        == expected_end_user_id
    )


# assumed parametrization: with the prometheus-only flag, the id is still returned
# for the "prometheus" service_type
@pytest.mark.parametrize(
    "litellm_params, enable_end_user_cost_tracking_prometheus_only, expected_end_user_id",
    [
        ({"proxy_server_request": {"body": {"user": "123"}}}, None, "123"),
        ({"proxy_server_request": {"body": {"user": "123"}}}, True, "123"),
    ],
)
def test_get_end_user_id_for_cost_tracking_prometheus_only(
    litellm_params, enable_end_user_cost_tracking_prometheus_only, expected_end_user_id
):
    from litellm.utils import get_end_user_id_for_cost_tracking

    litellm.enable_end_user_cost_tracking_prometheus_only = (
        enable_end_user_cost_tracking_prometheus_only
    )
    assert (
        get_end_user_id_for_cost_tracking(
            litellm_params=litellm_params, service_type="prometheus"
        )
        == expected_end_user_id
    )


def test_is_prompt_caching_enabled_error_handling():
    """
    Assert that `is_prompt_caching_valid_prompt` safely handles errors in `token_counter`.
    """
    with patch(
        "litellm.utils.token_counter",
        side_effect=Exception(
            "Mocked error, This should not raise an error. Instead is_prompt_caching_valid_prompt should return False."
        ),
    ):
        result = litellm.utils.is_prompt_caching_valid_prompt(
            messages=[{"role": "user", "content": "test"}],
            tools=None,
            custom_llm_provider="anthropic",
            model="anthropic/claude-3-5-sonnet-20240620",
        )

        assert result is False  # Should return False when an error occurs


def test_is_prompt_caching_enabled_return_default_image_dimensions():
    """
    Assert that `is_prompt_caching_valid_prompt` calls token_counter with
    use_default_image_token_count=True when processing messages containing images.

    IMPORTANT: Ensures the token counter does not make a GET request to the image url.
    """
    with patch("litellm.utils.token_counter") as mock_token_counter:
        litellm.utils.is_prompt_caching_valid_prompt(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What is in this image?"},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "https://www.gstatic.com/webp/gallery/1.webp",
                                "detail": "high",
                            },
                        },
                    ],
                }
            ],
            tools=None,
            custom_llm_provider="openai",
            model="gpt-4o-mini",
        )

        # Assert token_counter was called with use_default_image_token_count=True
        args_to_mock_token_counter = mock_token_counter.call_args[1]
        print("args_to_mock", args_to_mock_token_counter)
        assert args_to_mock_token_counter["use_default_image_token_count"] is True


def test_token_counter_with_image_url_with_detail_high():
    """
    Assert that token_counter does not make a GET request to the image url when
    `use_default_image_token_count=True`.

    PROD TEST: this is important - fetching the image can impact latency very badly.
    """
    from litellm.constants import DEFAULT_IMAGE_TOKEN_COUNT
    from litellm._logging import verbose_logger
    import logging

    verbose_logger.setLevel(logging.DEBUG)

    _tokens = litellm.utils.token_counter(
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://www.gstatic.com/webp/gallery/1.webp",
                            "detail": "high",
                        },
                    },
                ],
            }
        ],
        model="gpt-4o-mini",
        use_default_image_token_count=True,
    )
    print("tokens", _tokens)
    assert _tokens == DEFAULT_IMAGE_TOKEN_COUNT + 7


def test_fireworks_ai_document_inlining():
    """
    With document inlining, all fireworks ai models now support:
    - pdf input
    - vision
    """
    from litellm.utils import supports_pdf_input, supports_vision

    litellm._turn_on_debug()

    assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
    assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True


def test_logprobs_type():
    from litellm.types.utils import Logprobs

    logprobs = {
        "text_offset": None,
        "token_logprobs": None,
        "tokens": None,
        "top_logprobs": None,
    }
    logprobs = Logprobs(**logprobs)
    assert logprobs.text_offset is None
    assert logprobs.token_logprobs is None
    assert logprobs.tokens is None
    assert logprobs.top_logprobs is None


def test_get_valid_models_openai_proxy(monkeypatch):
    from litellm.utils import get_valid_models
    import litellm

    litellm._turn_on_debug()

    monkeypatch.setenv("LITELLM_PROXY_API_KEY", "sk-1234")
    monkeypatch.setenv("LITELLM_PROXY_API_BASE", "https://litellm-api.up.railway.app/")
    monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", raising=False)
    monkeypatch.delenv("FIREWORKS_AI_API_KEY", raising=False)

    mock_response_data = {
        "object": "list",
        "data": [
            {
                "id": "gpt-4o",
                "object": "model",
                "created": 1686935002,
                "owned_by": "organization-owner",
            },
        ],
    }

    # Create a mock response object
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = mock_response_data

    with patch.object(
        litellm.module_level_client, "get", return_value=mock_response
    ) as mock_get:
        valid_models = get_valid_models(check_provider_endpoint=True)
        assert "litellm_proxy/gpt-4o" in valid_models


def test_get_valid_models_fireworks_ai(monkeypatch):
    from litellm.utils import get_valid_models
    import litellm

    litellm._turn_on_debug()

    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "1234")
    monkeypatch.setattr(litellm, "provider_list", ["fireworks_ai"])

    mock_response_data = {
        "models": [
            {
                "name": "accounts/fireworks/models/llama-3.1-8b-instruct",
                "displayName": "<string>",
                "description": "<string>",
                "createTime": "2023-11-07T05:31:56Z",
                "createdBy": "<string>",
                "state": "STATE_UNSPECIFIED",
                "status": {"code": "OK", "message": "<string>"},
                "kind": "KIND_UNSPECIFIED",
                "githubUrl": "<string>",
                "huggingFaceUrl": "<string>",
                "baseModelDetails": {
                    "worldSize": 123,
                    "checkpointFormat": "CHECKPOINT_FORMAT_UNSPECIFIED",
                    "parameterCount": "<string>",
                    "moe": True,
                    "tunable": True,
                },
                "peftDetails": {
                    "baseModel": "<string>",
                    "r": 123,
                    "targetModules": ["<string>"],
                },
                "teftDetails": {},
                "public": True,
                "conversationConfig": {
                    "style": "<string>",
                    "system": "<string>",
                    "template": "<string>",
                },
                "contextLength": 123,
                "supportsImageInput": True,
                "supportsTools": True,
                "importedFrom": "<string>",
                "fineTuningJob": "<string>",
                "defaultDraftModel": "<string>",
                "defaultDraftTokenCount": 123,
                "precisions": ["PRECISION_UNSPECIFIED"],
                "deployedModelRefs": [
                    {
                        "name": "<string>",
                        "deployment": "<string>",
                        "state": "STATE_UNSPECIFIED",
                        "default": True,
                        "public": True,
                    }
                ],
                "cluster": "<string>",
                "deprecationDate": {"year": 123, "month": 123, "day": 123},
            }
        ],
        "nextPageToken": "<string>",
        "totalSize": 123,
    }

    # Create a mock response object
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = mock_response_data

    with patch.object(
        litellm.module_level_client, "get", return_value=mock_response
    ) as mock_get:
        valid_models = get_valid_models(check_provider_endpoint=True)
        print("valid_models", valid_models)
        mock_get.assert_called_once()
        assert (
            "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
            in valid_models
        )


def test_get_valid_models_default(monkeypatch):
    """
    Ensure the default model list is used when retrieving models from the provider API fails.

    Prevents a regression for existing usage.
    """
    from litellm.utils import get_valid_models
    import litellm

    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
    valid_models = get_valid_models()
    assert len(valid_models) > 0


def test_supports_vision_gemini():
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")
    from litellm.utils import supports_vision

    assert supports_vision("gemini-1.5-pro") is True


def test_pick_cheapest_chat_model_from_llm_provider():
    from litellm.litellm_core_utils.llm_request_utils import (
        pick_cheapest_chat_models_from_llm_provider,
    )

    assert len(pick_cheapest_chat_models_from_llm_provider("openai", n=3)) == 3
    assert len(pick_cheapest_chat_models_from_llm_provider("unknown", n=1)) == 0


def test_get_potential_model_names():
    from litellm.utils import _get_potential_model_names

    assert _get_potential_model_names(
        model="bedrock/ap-northeast-1/anthropic.claude-instant-v1",
        custom_llm_provider="bedrock",
    )


# assumed parametrization; the value round-trips through the helper unchanged
@pytest.mark.parametrize("num_retries", [None, 2])
def test_get_num_retries(num_retries):
    from litellm.utils import _get_wrapper_num_retries

    assert _get_wrapper_num_retries(
        kwargs={"num_retries": num_retries}, exception=Exception("test")
    ) == (
        num_retries,
        {
            "num_retries": num_retries,
        },
    )


def test_add_custom_logger_callback_to_specific_event(monkeypatch):
    from litellm.utils import _add_custom_logger_callback_to_specific_event

    monkeypatch.setattr(litellm, "success_callback", [])
    monkeypatch.setattr(litellm, "failure_callback", [])

    _add_custom_logger_callback_to_specific_event("langfuse", "success")

    assert len(litellm.success_callback) == 1
    assert len(litellm.failure_callback) == 0


def test_add_custom_logger_callback_to_specific_event_e2e(monkeypatch):
    monkeypatch.setattr(litellm, "success_callback", [])
    monkeypatch.setattr(litellm, "failure_callback", [])
    monkeypatch.setattr(litellm, "callbacks", [])

    litellm.success_callback = ["humanloop"]
    curr_len_success_callback = len(litellm.success_callback)
    curr_len_failure_callback = len(litellm.failure_callback)

    litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="Testing langfuse",
    )

    assert len(litellm.success_callback) == curr_len_success_callback
    assert len(litellm.failure_callback) == curr_len_failure_callback


def test_custom_logger_exists_in_callbacks_individual_functions(monkeypatch):
    """
    Test the _custom_logger_class_exists_in_success_callbacks and
    _custom_logger_class_exists_in_failure_callbacks helper functions.

    Tests whether the logger is found in the different callback lists.
    """
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.utils import (
        _custom_logger_class_exists_in_failure_callbacks,
        _custom_logger_class_exists_in_success_callbacks,
    )

    # Create a mock CustomLogger class
    class MockCustomLogger(CustomLogger):
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            pass

        def log_failure_event(self, kwargs, response_obj, start_time, end_time):
            pass

    # Reset all callback lists
    for list_name in [
        "callbacks",
        "_async_success_callback",
        "_async_failure_callback",
        "success_callback",
        "failure_callback",
    ]:
        monkeypatch.setattr(litellm, list_name, [])

    mock_logger = MockCustomLogger()

    # Test 1: No logger exists in any callback list
    assert _custom_logger_class_exists_in_success_callbacks(mock_logger) is False
    assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) is False

    # Test 2: Logger exists in success_callback
    litellm.success_callback.append(mock_logger)
    assert _custom_logger_class_exists_in_success_callbacks(mock_logger) is True
    assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) is False

    # Reset callbacks
    litellm.success_callback = []

    # Test 3: Logger exists in _async_success_callback
    litellm._async_success_callback.append(mock_logger)
    assert _custom_logger_class_exists_in_success_callbacks(mock_logger) is True
    assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) is False

    # Reset callbacks
    litellm._async_success_callback = []

    # Test 4: Logger exists in failure_callback
    litellm.failure_callback.append(mock_logger)
    assert _custom_logger_class_exists_in_success_callbacks(mock_logger) is False
    assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) is True

    # Reset callbacks
    litellm.failure_callback = []

    # Test 5: Logger exists in _async_failure_callback
    litellm._async_failure_callback.append(mock_logger)
    assert _custom_logger_class_exists_in_success_callbacks(mock_logger) is False
    assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) is True

    # Test 6: Logger exists in both success and failure callbacks
    litellm.success_callback.append(mock_logger)
    litellm.failure_callback.append(mock_logger)
    assert _custom_logger_class_exists_in_success_callbacks(mock_logger) is True
    assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) is True

    # Test 7: Different instance of the same logger class
    mock_logger_2 = MockCustomLogger()
    assert _custom_logger_class_exists_in_success_callbacks(mock_logger_2) is True
    assert _custom_logger_class_exists_in_failure_callbacks(mock_logger_2) is True


@pytest.mark.asyncio
async def test_add_custom_logger_callback_to_specific_event_with_duplicates(
    monkeypatch,
):
    """
    Test that when a callback exists in both success_callback and _async_success_callback,
    it's not added again
    """
    from litellm.integrations.langfuse.langfuse_prompt_management import (
        LangfusePromptManagement,
    )

    # Reset all callback lists
    monkeypatch.setattr(litellm, "callbacks", [])
    monkeypatch.setattr(litellm, "_async_success_callback", [])
    monkeypatch.setattr(litellm, "_async_failure_callback", [])
    monkeypatch.setattr(litellm, "success_callback", [])
    monkeypatch.setattr(litellm, "failure_callback", [])

    # Add logger to both success_callback and _async_success_callback
    langfuse_logger = LangfusePromptManagement()
    litellm.success_callback.append(langfuse_logger)
    litellm._async_success_callback.append(langfuse_logger)

    # Get initial lengths
    initial_success_callback_len = len(litellm.success_callback)
    initial_async_success_callback_len = len(litellm._async_success_callback)

    # Make a completion call
    await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="Testing duplicate callbacks",
    )

    # Assert no new callbacks were added
    assert len(litellm.success_callback) == initial_success_callback_len
    assert len(litellm._async_success_callback) == initial_async_success_callback_len


@pytest.mark.asyncio
async def test_add_custom_logger_callback_to_specific_event_with_duplicates_success_callback(
    monkeypatch,
):
    """
    Test that when a callback already exists in success_callback,
    it's not added again
    """
    from litellm.integrations.langfuse.langfuse_prompt_management import (
        LangfusePromptManagement,
    )

    # Reset all callback lists
    monkeypatch.setattr(litellm, "callbacks", [])
    monkeypatch.setattr(litellm, "_async_success_callback", [])
    monkeypatch.setattr(litellm, "_async_failure_callback", [])
    monkeypatch.setattr(litellm, "success_callback", [])
    monkeypatch.setattr(litellm, "failure_callback", [])

    # Add logger to success_callback
    langfuse_logger = LangfusePromptManagement()
    litellm.success_callback.append(langfuse_logger)

    # Get initial lengths
    initial_success_callback_len = len(litellm.success_callback)
    initial_async_success_callback_len = len(litellm._async_success_callback)

    # Make a completion call
    await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="Testing duplicate callbacks",
    )

    # Assert no new callbacks were added
    assert len(litellm.success_callback) == initial_success_callback_len
    assert len(litellm._async_success_callback) == initial_async_success_callback_len


@pytest.mark.asyncio
async def test_add_custom_logger_callback_to_specific_event_with_duplicates_callbacks(
    monkeypatch,
):
    """
    Test that when a callback exists in litellm.callbacks,
    repeated completion calls don't add duplicates
    """
    from litellm.integrations.langfuse.langfuse_prompt_management import (
        LangfusePromptManagement,
    )

    # Reset all callback lists
    monkeypatch.setattr(litellm, "callbacks", [])
    monkeypatch.setattr(litellm, "_async_success_callback", [])
    monkeypatch.setattr(litellm, "_async_failure_callback", [])
    monkeypatch.setattr(litellm, "success_callback", [])
    monkeypatch.setattr(litellm, "failure_callback", [])

    # Add logger to litellm.callbacks
    langfuse_logger = LangfusePromptManagement()
    litellm.callbacks.append(langfuse_logger)

    # Make a completion call
    await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="Testing duplicate callbacks",
    )

    # Capture callback counts after the first call; subsequent calls must not add more
    initial_callbacks_len = len(litellm.callbacks)
    initial_async_success_callback_len = len(litellm._async_success_callback)
    initial_success_callback_len = len(litellm.success_callback)

    print(
        f"Num callbacks before: litellm.callbacks: {len(litellm.callbacks)}, litellm._async_success_callback: {len(litellm._async_success_callback)}, litellm.success_callback: {len(litellm.success_callback)}"
    )

    for _ in range(10):
        await litellm.acompletion(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Hello, world!"}],
            mock_response="Testing duplicate callbacks",
        )

    assert len(litellm.callbacks) == initial_callbacks_len
    assert len(litellm._async_success_callback) == initial_async_success_callback_len
    assert len(litellm.success_callback) == initial_success_callback_len

    print(
        f"Num callbacks after 10 mock calls: litellm.callbacks: {len(litellm.callbacks)}, litellm._async_success_callback: {len(litellm._async_success_callback)}, litellm.success_callback: {len(litellm.success_callback)}"
    )


def test_add_custom_logger_callback_to_specific_event_e2e_failure(monkeypatch):
    from litellm.integrations.openmeter import OpenMeterLogger

    monkeypatch.setattr(litellm, "success_callback", [])
    monkeypatch.setattr(litellm, "failure_callback", [])
    monkeypatch.setattr(litellm, "callbacks", [])
    monkeypatch.setenv("OPENMETER_API_KEY", "wedlwe")
    monkeypatch.setenv("OPENMETER_API_URL", "https://openmeter.dev")

    litellm.failure_callback = ["openmeter"]

    curr_len_success_callback = len(litellm.success_callback)
    curr_len_failure_callback = len(litellm.failure_callback)

    litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="Testing openmeter",
    )

    assert len(litellm.success_callback) == curr_len_success_callback
    assert len(litellm.failure_callback) == curr_len_failure_callback

    assert any(
        isinstance(callback, OpenMeterLogger) for callback in litellm.failure_callback
    )


@pytest.mark.asyncio
async def test_wrapper_kwargs_passthrough():
    from litellm.utils import client
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as LiteLLMLoggingObject,
    )

    # Create mock original function
    mock_original = AsyncMock()

    # Apply decorator (the @client wrapper is what injects litellm_logging_obj)
    @client
    async def test_function(**kwargs):
        return await mock_original(**kwargs)

    # Test kwargs
    test_kwargs = {"base_model": "gpt-4o-mini"}

    # Call decorated function
    await test_function(**test_kwargs)

    mock_original.assert_called_once()

    # get litellm logging object
    litellm_logging_obj: LiteLLMLoggingObject = mock_original.call_args.kwargs.get(
        "litellm_logging_obj"
    )
    assert litellm_logging_obj is not None
    print(
        f"litellm_logging_obj.model_call_details: {litellm_logging_obj.model_call_details}"
    )

    # get base model
    assert (
        litellm_logging_obj.model_call_details["litellm_params"]["base_model"]
        == "gpt-4o-mini"
    )


def test_dict_to_response_format_helper():
    from litellm.llms.base_llm.base_utils import _dict_to_response_format_helper

    args = {
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "schema": {
                    "$defs": {
                        "CalendarEvent": {
                            "properties": {
                                "name": {"title": "Name", "type": "string"},
                                "date": {"title": "Date", "type": "string"},
                                "participants": {
                                    "items": {"type": "string"},
                                    "title": "Participants",
                                    "type": "array",
                                },
                            },
                            "required": ["name", "date", "participants"],
                            "title": "CalendarEvent",
                            "type": "object",
                            "additionalProperties": False,
                        }
                    },
                    "properties": {
                        "events": {
                            "items": {"$ref": "#/$defs/CalendarEvent"},
                            "title": "Events",
                            "type": "array",
                        }
                    },
                    "required": ["events"],
                    "title": "EventsList",
                    "type": "object",
                    "additionalProperties": False,
                },
                "name": "EventsList",
                "strict": True,
            },
        },
        "ref_template": "/$defs/{model}",
    }
    _dict_to_response_format_helper(**args)


def test_validate_user_messages_invalid_content_type():
    from litellm.utils import validate_chat_completion_user_messages

    messages = [{"content": [{"type": "invalid_type", "text": "Hello"}]}]

    with pytest.raises(Exception) as e:
        validate_chat_completion_user_messages(messages)

    assert "Invalid message" in str(e)
    print(e)


from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import get_applied_guardrails
from unittest.mock import Mock


def test_get_applied_guardrails(test_case):
    # Setup
    litellm.callbacks = test_case["callbacks"]

    # Execute
    result = get_applied_guardrails(test_case["kwargs"])

    # Assert
    assert sorted(result) == sorted(test_case["expected"])


def test_should_use_cohere_v1_client(endpoint, params, expected_bool):
    assert litellm.utils.should_use_cohere_v1_client(endpoint, params) == expected_bool


def test_add_openai_metadata():
    from litellm.utils import add_openai_metadata

    metadata = {
        "user_api_key_end_user_id": "123",
        "hidden_params": {"api_key": "123"},
        "litellm_parent_otel_span": MagicMock(),
        "none-val": None,
        "int-val": 1,
        "dict-val": {"a": 1, "b": 2},
    }
    result = add_openai_metadata(metadata)
    assert result == {
        "user_api_key_end_user_id": "123",
    }
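

# Per the assertion above, add_openai_metadata forwards only flat, string-valued
# entries that are safe to send in OpenAI's metadata field; None values, ints,
# nested dicts, and non-serializable objects (the MagicMock span) are dropped.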


def test_message_object():
    from litellm.types.utils import Message

    message = Message(content="Hello, world!", role="user")
    assert message.content == "Hello, world!"
    assert message.role == "user"
    assert not hasattr(message, "audio")
    assert not hasattr(message, "thinking_blocks")
    assert not hasattr(message, "reasoning_content")


def test_delta_object():
    from litellm.types.utils import Delta

    delta = Delta(content="Hello, world!", role="user")
    assert delta.content == "Hello, world!"
    assert delta.role == "user"
    assert not hasattr(delta, "thinking_blocks")
    assert not hasattr(delta, "reasoning_content")


def test_get_provider_audio_transcription_config():
    from litellm.utils import ProviderConfigManager
    from litellm.types.utils import LlmProviders

    # smoke test: resolving the audio transcription config should not raise for
    # any provider (the config may be None for providers without one)
    for provider in LlmProviders:
        config = ProviderConfigManager.get_provider_audio_transcription_config(
            model="whisper-1", provider=provider
        )


# assumed parametrization based on the test name
@pytest.mark.parametrize(
    "model, expected_bool",
    [("anthropic/claude-3-7-sonnet-20250219", True)],
)
def test_claude_3_7_sonnet_supports_pdf_input(model, expected_bool):
    from litellm.utils import supports_pdf_input

    assert supports_pdf_input(model) == expected_bool


def test_get_valid_models_from_provider():
    """
    Test that get_valid_models returns the correct models for a given provider
    """
    from litellm.utils import get_valid_models

    valid_models = get_valid_models(custom_llm_provider="openai")
    assert len(valid_models) > 0
    assert "gpt-4o-mini" in valid_models

    print("Valid models: ", valid_models)
    valid_models.remove("gpt-4o-mini")
    assert "gpt-4o-mini" not in valid_models

    # mutating the returned list must not poison the cached copy; a fresh call
    # should still include the removed model
    valid_models = get_valid_models(custom_llm_provider="openai")
    assert len(valid_models) > 0
    assert "gpt-4o-mini" in valid_models


def test_get_valid_models_from_provider_cache_invalidation(monkeypatch):
    """
    Test that the cached model info is invalidated when the provider's env key is removed
    """
    from litellm.utils import _model_cache

    monkeypatch.setenv("OPENAI_API_KEY", "123")
    _model_cache.set_cached_model_info(
        "openai", litellm_params=None, available_models=["gpt-4o-mini"]
    )
    monkeypatch.delenv("OPENAI_API_KEY")

    assert _model_cache.get_cached_model_info("openai") is None


def test_get_valid_models_from_dynamic_api_key():
    """
    Test that get_valid_models respects a dynamically passed api key
    """
    from litellm.utils import get_valid_models
    from litellm.types.router import CredentialLiteLLMParams

    creds = CredentialLiteLLMParams(api_key="123")
    valid_models = get_valid_models(
        custom_llm_provider="anthropic",
        litellm_params=creds,
        check_provider_endpoint=True,
    )
    assert len(valid_models) == 0

    creds = CredentialLiteLLMParams(api_key=os.getenv("ANTHROPIC_API_KEY"))
    valid_models = get_valid_models(
        custom_llm_provider="anthropic",
        litellm_params=creds,
        check_provider_endpoint=True,
    )
    assert len(valid_models) > 0
    assert "anthropic/claude-3-7-sonnet-20250219" in valid_models