File size: 6,129 Bytes
a51a15b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from typing import Dict, Type, Any, List, Optional, Callable
from agentpress.tool import Tool, SchemaType, ToolSchema
from utils.logger import logger
class ToolRegistry:
    """Registry for managing and accessing tools.

    Maintains a collection of tool instances and their schemas, allowing for
    selective registration of tool functions and easy access to tool capabilities.

    Attributes:
        tools (Dict[str, Dict[str, Any]]): OpenAPI-style tools keyed by function
            name; each entry holds the tool ``instance`` and its ``schema``.
        xml_tools (Dict[str, Dict[str, Any]]): XML-style tools keyed by XML tag
            name; each entry holds the tool ``instance``, the ``method`` name to
            call on it, and its ``schema``.

    Methods:
        register_tool: Register a tool with optional function filtering
        get_available_functions: Map function names to bound tool methods
        get_tool: Get a specific tool by name
        get_xml_tool: Get a tool by XML tag name
        get_openapi_schemas: Get OpenAPI schemas for function calling
        get_xml_examples: Get examples of XML tool usage
    """

    def __init__(self):
        """Initialize a new, empty ToolRegistry instance."""
        self.tools: Dict[str, Dict[str, Any]] = {}
        self.xml_tools: Dict[str, Dict[str, Any]] = {}
        logger.debug("Initialized new ToolRegistry instance")

    def register_tool(
        self,
        tool_class: Type[Tool],
        function_names: Optional[List[str]] = None,
        **kwargs,
    ):
        """Register a tool with optional function filtering.

        Instantiates ``tool_class`` with ``**kwargs`` and records the schemas
        returned by the instance's ``get_schemas()`` for the selected functions.

        Args:
            tool_class: The tool class to register.
            function_names: Optional list of specific functions to register.
                If None, all functions exposed by the tool are registered.
                Names the tool does not expose are skipped, and a warning
                listing them is logged.
            **kwargs: Additional arguments passed to tool initialization.

        Notes:
            - OpenAPI schemas are stored in ``self.tools`` keyed by function name.
            - XML schemas that carry an ``xml_schema`` are stored in
              ``self.xml_tools`` keyed by their XML tag name.
            - An existing entry under the same key is overwritten.
        """
        logger.debug(f"Registering tool class: {tool_class.__name__}")
        tool_instance = tool_class(**kwargs)
        schemas = tool_instance.get_schemas()
        logger.debug(f"Available schemas for {tool_class.__name__}: {list(schemas.keys())}")

        # Build the filter once as a set so each membership test is O(1).
        wanted = None if function_names is None else set(function_names)
        if wanted is not None:
            # Surface requested names the tool does not provide; otherwise a typo
            # or stale name would make a function silently absent from the registry.
            unknown = sorted(wanted.difference(schemas))
            if unknown:
                logger.warning(
                    f"Requested functions not provided by {tool_class.__name__}: {unknown}"
                )

        registered_openapi = 0
        registered_xml = 0
        for func_name, schema_list in schemas.items():
            if wanted is not None and func_name not in wanted:
                continue
            for schema in schema_list:
                if schema.schema_type == SchemaType.OPENAPI:
                    self.tools[func_name] = {
                        "instance": tool_instance,
                        "schema": schema,
                    }
                    registered_openapi += 1
                    logger.debug(f"Registered OpenAPI function {func_name} from {tool_class.__name__}")
                elif schema.schema_type == SchemaType.XML and schema.xml_schema:
                    self.xml_tools[schema.xml_schema.tag_name] = {
                        "instance": tool_instance,
                        "method": func_name,
                        "schema": schema,
                    }
                    registered_xml += 1
                    logger.debug(f"Registered XML tag {schema.xml_schema.tag_name} -> {func_name} from {tool_class.__name__}")

        logger.debug(f"Tool registration complete for {tool_class.__name__}: {registered_openapi} OpenAPI functions, {registered_xml} XML tags")

    def get_available_functions(self) -> Dict[str, Callable]:
        """Get all available tool functions.

        Looks up, on each registered tool instance, the attribute named after
        the registered function and returns it.

        Returns:
            Dict mapping function names to their implementations. If an OpenAPI
            entry and an XML entry share a method name, the XML entry (processed
            second) wins.

        Raises:
            AttributeError: If a tool instance has no attribute with the
                registered function/method name.
        """
        available_functions: Dict[str, Callable] = {}

        # OpenAPI entries are keyed by the method name itself.
        for function_name, tool_info in self.tools.items():
            available_functions[function_name] = getattr(tool_info['instance'], function_name)

        # XML entries are keyed by tag name; the method name is stored separately.
        for tool_info in self.xml_tools.values():
            method_name = tool_info['method']
            available_functions[method_name] = getattr(tool_info['instance'], method_name)

        logger.debug(f"Retrieved {len(available_functions)} available functions")
        return available_functions

    def get_tool(self, tool_name: str) -> Dict[str, Any]:
        """Get a specific tool by name.

        Args:
            tool_name: Name of the tool function.

        Returns:
            The registry's own entry dict (not a copy) holding the tool
            ``instance`` and ``schema``, or an empty dict if not found.
        """
        tool = self.tools.get(tool_name, {})
        if not tool:
            logger.warning(f"Tool not found: {tool_name}")
        return tool

    def get_xml_tool(self, tag_name: str) -> Dict[str, Any]:
        """Get tool info by XML tag name.

        Args:
            tag_name: XML tag name for the tool.

        Returns:
            The registry's own entry dict (not a copy) holding the tool
            ``instance``, ``method`` name, and ``schema``, or an empty dict if
            no tool is registered for the tag.
        """
        tool = self.xml_tools.get(tag_name, {})
        if not tool:
            logger.warning(f"XML tool not found for tag: {tag_name}")
        return tool

    def get_openapi_schemas(self) -> List[Dict[str, Any]]:
        """Get OpenAPI schemas for function calling.

        Returns:
            List of OpenAPI-compatible schema definitions (the ``.schema``
            attribute of each registered OpenAPI schema object).
        """
        # register_tool only stores OpenAPI schemas in self.tools; the filter is
        # a defensive check in case entries are added some other way.
        schemas = [
            tool_info['schema'].schema
            for tool_info in self.tools.values()
            if tool_info['schema'].schema_type == SchemaType.OPENAPI
        ]
        logger.debug(f"Retrieved {len(schemas)} OpenAPI schemas")
        return schemas

    def get_xml_examples(self) -> Dict[str, str]:
        """Get all XML tag examples.

        Returns:
            Dict mapping tag names to their example usage; tags whose schema
            has no example are omitted.
        """
        examples = {}
        for tool_info in self.xml_tools.values():
            schema = tool_info['schema']
            if schema.xml_schema and schema.xml_schema.example:
                examples[schema.xml_schema.tag_name] = schema.xml_schema.example
        logger.debug(f"Retrieved {len(examples)} XML examples")
        return examples
|