Duibonduil's picture
Upload 17 files
d7949de verified
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from dotenv import load_dotenv
from smolagents import CodeAgent, InferenceClientModel, LiteLLMModel, Model, OpenAIServerModel, Tool, TransformersModel
from smolagents.default_tools import TOOL_MAPPING
leopard_prompt = "How many seconds would it take for a leopard at full speed to run through Pont des Arts?"
def parse_arguments():
parser = argparse.ArgumentParser(description="Run a CodeAgent with all specified parameters")
parser.add_argument(
"prompt",
type=str,
nargs="?", # Makes it optional
default=leopard_prompt,
help="The prompt to run with the agent",
)
parser.add_argument(
"--model-type",
type=str,
default="InferenceClientModel",
help="The model type to use (e.g., InferenceClientModel, OpenAIServerModel, LiteLLMModel, TransformersModel)",
)
parser.add_argument(
"--model-id",
type=str,
default="Qwen/Qwen2.5-Coder-32B-Instruct",
help="The model ID to use for the specified model type",
)
parser.add_argument(
"--imports",
nargs="*", # accepts zero or more arguments
default=[],
help="Space-separated list of imports to authorize (e.g., 'numpy pandas')",
)
parser.add_argument(
"--tools",
nargs="*",
default=["web_search"],
help="Space-separated list of tools that the agent can use (e.g., 'tool1 tool2 tool3')",
)
parser.add_argument(
"--verbosity-level",
type=int,
default=1,
help="The verbosity level, as an int in [0, 1, 2].",
)
group = parser.add_argument_group("api options", "Options for API-based model types")
group.add_argument(
"--provider",
type=str,
default=None,
help="The inference provider to use for the model",
)
group.add_argument(
"--api-base",
type=str,
help="The base URL for the model",
)
group.add_argument(
"--api-key",
type=str,
help="The API key for the model",
)
return parser.parse_args()
def load_model(
model_type: str,
model_id: str,
api_base: str | None = None,
api_key: str | None = None,
provider: str | None = None,
) -> Model:
if model_type == "OpenAIServerModel":
return OpenAIServerModel(
api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
api_base=api_base or "https://api.fireworks.ai/inference/v1",
model_id=model_id,
)
elif model_type == "LiteLLMModel":
return LiteLLMModel(
model_id=model_id,
api_key=api_key,
api_base=api_base,
)
elif model_type == "TransformersModel":
return TransformersModel(model_id=model_id, device_map="auto")
elif model_type == "InferenceClientModel":
return InferenceClientModel(
model_id=model_id,
token=api_key or os.getenv("HF_API_KEY"),
provider=provider,
)
else:
raise ValueError(f"Unsupported model type: {model_type}")
def _tool_from_space(space_id):
    """Wrap the Hugging Face Space ``space_id`` (e.g. "owner/space-name") as a Tool.

    ``Tool.from_space`` requires an explicit ``name`` and ``description``, which the CLI
    does not collect, so both are derived from the Space id. The name is reduced to a
    valid, non-keyword Python identifier so it can be used as a tool/function name.
    """
    import keyword

    raw_name = space_id.rsplit("/", 1)[-1].lower()
    name = "".join(ch if ch.isalnum() else "_" for ch in raw_name)
    # Guard against names that start with a digit, are empty, or collide with keywords.
    if not name.isidentifier() or keyword.iskeyword(name):
        name = f"space_{name}"
    return Tool.from_space(
        space_id,
        name=name,
        description=f"Tool backed by the Hugging Face Space '{space_id}'.",
    )


def run_smolagent(
    prompt: str,
    tools: list[str],
    model_type: str,
    model_id: str,
    api_base: str | None = None,
    api_key: str | None = None,
    imports: list[str] | None = None,
    provider: str | None = None,
    verbosity_level: int = 1,
) -> None:
    """Build a CodeAgent from the given options and run it on ``prompt``.

    Variables from a ``.env`` file are loaded first (``load_dotenv``) so that the
    environment-variable fallbacks used by ``load_model`` can pick them up.

    Args:
        prompt: Task given to the agent.
        tools: Tool specifiers. A value containing "/" is treated as a Hugging Face
            Space id. Any other value must be a key of ``TOOL_MAPPING``.
        model_type, model_id, api_base, api_key, provider: Forwarded to ``load_model``.
        imports: Extra modules the agent may import, passed to CodeAgent as
            ``additional_authorized_imports``.
        verbosity_level: Verbosity forwarded to CodeAgent. The default 1 matches the
            CLI's ``--verbosity-level`` default.

    Raises:
        ValueError: If a tool name is neither a Space id nor a known default tool, or
            if ``load_model`` rejects ``model_type``.
    """
    load_dotenv()

    model = load_model(model_type, model_id, api_base=api_base, api_key=api_key, provider=provider)

    available_tools = []
    for tool_name in tools:
        if "/" in tool_name:
            available_tools.append(_tool_from_space(tool_name))
        elif tool_name in TOOL_MAPPING:
            available_tools.append(TOOL_MAPPING[tool_name]())
        else:
            raise ValueError(f"Tool {tool_name} is not recognized either as a default tool or a Space.")

    print(f"Running agent with these tools: {tools}")
    agent = CodeAgent(
        tools=available_tools,
        model=model,
        additional_authorized_imports=imports,
        verbosity_level=verbosity_level,
    )

    agent.run(prompt)


def main() -> None:
    """CLI entry point: parse command-line arguments and run the agent with them."""
    args = parse_arguments()
    run_smolagent(
        args.prompt,
        args.tools,
        args.model_type,
        args.model_id,
        provider=args.provider,
        api_base=args.api_base,
        api_key=args.api_key,
        imports=args.imports,
        verbosity_level=args.verbosity_level,
    )


if __name__ == "__main__":
    main()