from unittest.mock import patch

import pytest

from smolagents.cli import load_model
from smolagents.local_python_executor import LocalPythonExecutor
from smolagents.models import InferenceClientModel, LiteLLMModel, OpenAIServerModel, TransformersModel


@pytest.fixture
def set_env_vars(monkeypatch):
    """Set the API-key environment variables that the model-loading tests assert on."""
    for name, value in (
        ("FIREWORKS_API_KEY", "test_fireworks_api_key"),
        ("HF_TOKEN", "test_hf_api_key"),
    ):
        monkeypatch.setenv(name, value)


def test_load_model_openai_server_model(set_env_vars):
    """Loading "OpenAIServerModel" yields that class and calls the patched OpenAI client once."""
    with patch("openai.OpenAI") as openai_client_cls:
        loaded = load_model("OpenAIServerModel", "test_model_id")

    assert isinstance(loaded, OpenAIServerModel)
    assert loaded.model_id == "test_model_id"

    # The client must be constructed exactly once, with the expected endpoint
    # and the key provided through the environment fixture.
    assert openai_client_cls.call_count == 1
    client_kwargs = openai_client_cls.call_args.kwargs
    assert client_kwargs["base_url"] == "https://api.fireworks.ai/inference/v1"
    assert client_kwargs["api_key"] == "test_fireworks_api_key"


def test_load_model_litellm_model():
    """Loading "LiteLLMModel" keeps the explicitly supplied key, base URL and model id."""
    loaded = load_model(
        "LiteLLMModel",
        "test_model_id",
        api_key="test_api_key",
        api_base="https://api.test.com",
    )

    assert isinstance(loaded, LiteLLMModel)
    assert loaded.api_key == "test_api_key"
    assert loaded.api_base == "https://api.test.com"
    assert loaded.model_id == "test_model_id"


def test_load_model_transformers_model():
    """Loading "TransformersModel" succeeds when the image-text-to-text loader raises ValueError."""
    # The image-text-to-text loader is patched to raise, while the causal-LM
    # loader and the tokenizer loader are replaced by plain mocks.
    image_text_loader = patch(
        "transformers.AutoModelForImageTextToText.from_pretrained",
        side_effect=ValueError("Unrecognized configuration class"),
    )
    causal_lm_loader = patch("transformers.AutoModelForCausalLM.from_pretrained")
    tokenizer_loader = patch("transformers.AutoTokenizer.from_pretrained")

    with image_text_loader, causal_lm_loader, tokenizer_loader:
        loaded = load_model("TransformersModel", "test_model_id")

    assert isinstance(loaded, TransformersModel)
    assert loaded.model_id == "test_model_id"


def test_load_model_hf_api_model(set_env_vars):
    """Loading "InferenceClientModel" calls the patched InferenceClient once with the HF token."""
    with patch("huggingface_hub.InferenceClient") as inference_client_cls:
        loaded = load_model("InferenceClientModel", "test_model_id")

    assert isinstance(loaded, InferenceClientModel)
    assert loaded.model_id == "test_model_id"
    assert inference_client_cls.call_count == 1
    assert inference_client_cls.call_args.kwargs["token"] == "test_hf_api_key"


def test_load_model_invalid_model_type():
    """An unknown model type name makes `load_model` raise ValueError with a matching message."""
    with pytest.raises(ValueError, match="Unsupported model type: InvalidModel"):
        load_model("InvalidModel", "test_model_id")


def test_cli_main(capsys):
    """`run_smolagent` calls load_model, CodeAgent and agent.run once each and prints the tool list."""
    with (
        patch("smolagents.cli.load_model", return_value="mock_model") as load_model_mock,
        patch("smolagents.cli.CodeAgent") as agent_cls,
    ):
        from smolagents.cli import run_smolagent

        run_smolagent("test_prompt", [], "InferenceClientModel", "test_model_id", provider="hf-inference")

    # load_model: one call, carrying the model type/id and the connection options.
    assert len(load_model_mock.call_args_list) == 1
    load_call = load_model_mock.call_args
    assert load_call.args == ("InferenceClientModel", "test_model_id")
    assert load_call.kwargs == {"api_base": None, "api_key": None, "provider": "hf-inference"}

    # CodeAgent: built once, keyword arguments only.
    assert len(agent_cls.call_args_list) == 1
    agent_call = agent_cls.call_args
    assert agent_call.args == ()
    assert agent_call.kwargs == {
        "tools": [],
        "model": "mock_model",
        "additional_authorized_imports": None,
    }

    # agent.run: invoked once with the prompt as its only positional argument.
    run_mock = agent_cls.return_value.run
    assert len(run_mock.call_args_list) == 1
    assert run_mock.call_args.args == ("test_prompt",)

    # Captured stdout must contain the tool listing.
    stdout_text = capsys.readouterr().out
    assert "Running agent with these tools: []" in stdout_text


def test_vision_web_browser_main():
    """`run_webagent` calls load_model, CodeAgent, the python executor and agent.run once each."""
    with (
        patch("smolagents.vision_web_browser.helium"),
        patch("smolagents.vision_web_browser.load_model", return_value="mock_model") as load_model_mock,
        patch("smolagents.vision_web_browser.CodeAgent") as agent_cls,
    ):
        from smolagents.vision_web_browser import helium_instructions, run_webagent

        run_webagent("test_prompt", "InferenceClientModel", "test_model_id", provider="hf-inference")

    # load_model: one call with the model type and id.
    assert len(load_model_mock.call_args_list) == 1
    assert load_model_mock.call_args.args == ("InferenceClientModel", "test_model_id")

    # CodeAgent: built once, keyword arguments only.
    assert len(agent_cls.call_args_list) == 1
    agent_call = agent_cls.call_args
    assert agent_call.args == ()
    agent_kwargs = agent_call.kwargs
    assert len(agent_kwargs["tools"]) == 4
    assert agent_kwargs["model"] == "mock_model"
    assert agent_kwargs["additional_authorized_imports"] == ["helium"]

    # agent.python_executor: called once with the helium star-import statement.
    executor_mock = agent_cls.return_value.python_executor
    assert len(executor_mock.call_args_list) == 1
    assert executor_mock.call_args.args == ("from helium import *",)
    # The same statement is also run through a real executor that authorizes helium.
    real_executor = LocalPythonExecutor(["helium"])
    assert real_executor("from helium import *") == (None, "", False)

    # agent.run: invoked once with the prompt followed by the helium instructions.
    run_mock = agent_cls.return_value.run
    assert len(run_mock.call_args_list) == 1
    assert run_mock.call_args.args == ("test_prompt" + helium_instructions,)