File size: 6,129 Bytes
a51a15b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import Dict, Type, Any, List, Optional, Callable
from agentpress.tool import Tool, SchemaType, ToolSchema
from utils.logger import logger


class ToolRegistry:
    """Registry for managing and accessing tools.
    
    Maintains a collection of tool instances and their schemas, allowing for
    selective registration of tool functions and easy access to tool capabilities.
    
    Attributes:
        tools (Dict[str, Dict[str, Any]]): OpenAPI-style tools keyed by function
            name; each entry holds the tool ``"instance"`` and its ``"schema"``.
        xml_tools (Dict[str, Dict[str, Any]]): XML-style tools keyed by XML tag
            name; each entry holds the tool ``"instance"``, the bound
            ``"method"`` name and the ``"schema"``.
        
    Methods:
        register_tool: Register a tool with optional function filtering
        get_available_functions: Map function names to bound tool methods
        get_tool: Get a specific tool by name
        get_xml_tool: Get a tool by XML tag name
        get_openapi_schemas: Get OpenAPI schemas for function calling
        get_xml_examples: Get examples of XML tool usage
    """
    
    def __init__(self) -> None:
        """Initialize a new, empty ToolRegistry instance."""
        self.tools: Dict[str, Dict[str, Any]] = {}
        self.xml_tools: Dict[str, Dict[str, Any]] = {}
        logger.debug("Initialized new ToolRegistry instance")
    
    def register_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs) -> None:
        """Register a tool with optional function filtering.
        
        Instantiates ``tool_class(**kwargs)`` once and registers the selected
        functions' schemas against that single shared instance.
        
        Args:
            tool_class: The tool class to register
            function_names: Optional list of specific functions to register.
                ``None`` registers every function; an empty list registers none.
            **kwargs: Additional arguments passed to tool initialization
            
        Notes:
            - If function_names is None, all functions are registered
            - Handles both OpenAPI and XML schema registration
            - Requested names the tool does not provide are logged as a warning
            - A function name (OpenAPI) or tag name (XML) that is already
              registered is overwritten by the new registration
        """
        logger.debug(f"Registering tool class: {tool_class.__name__}")
        tool_instance = tool_class(**kwargs)
        schemas = tool_instance.get_schemas()
        
        logger.debug(f"Available schemas for {tool_class.__name__}: {list(schemas.keys())}")
        
        # Build the filter once so membership checks are O(1). Keep the
        # None / empty distinction: None means "everything".
        wanted = None if function_names is None else set(function_names)
        if wanted is not None:
            unknown = wanted.difference(schemas)
            if unknown:
                # Without this, a misspelled function name is silently ignored.
                logger.warning(
                    f"Requested functions not provided by {tool_class.__name__}: {sorted(unknown)}"
                )
        
        registered_openapi = 0
        registered_xml = 0
        
        for func_name, schema_list in schemas.items():
            if wanted is not None and func_name not in wanted:
                continue
            for schema in schema_list:
                if schema.schema_type == SchemaType.OPENAPI:
                    self.tools[func_name] = {
                        "instance": tool_instance,
                        "schema": schema
                    }
                    registered_openapi += 1
                    logger.debug(f"Registered OpenAPI function {func_name} from {tool_class.__name__}")
                
                # XML schemas are indexed by tag so incoming XML can be routed
                # back to the Python method that implements it.
                if schema.schema_type == SchemaType.XML and schema.xml_schema:
                    self.xml_tools[schema.xml_schema.tag_name] = {
                        "instance": tool_instance,
                        "method": func_name,
                        "schema": schema
                    }
                    registered_xml += 1
                    logger.debug(f"Registered XML tag {schema.xml_schema.tag_name} -> {func_name} from {tool_class.__name__}")
        
        logger.debug(f"Tool registration complete for {tool_class.__name__}: {registered_openapi} OpenAPI functions, {registered_xml} XML tags")

    def get_available_functions(self) -> Dict[str, Callable]:
        """Get all available tool functions.
        
        OpenAPI functions are collected first, then XML-backed methods; if both
        expose the same method name the XML entry's bound method is kept.
        
        Returns:
            Dict mapping function names to their bound implementations
        """
        available_functions: Dict[str, Callable] = {}
        
        # OpenAPI entries are keyed by the method name itself.
        for function_name, tool_info in self.tools.items():
            available_functions[function_name] = getattr(tool_info['instance'], function_name)
            
        # XML entries are keyed by tag; the method name is stored in the entry.
        for tool_info in self.xml_tools.values():
            method_name = tool_info['method']
            available_functions[method_name] = getattr(tool_info['instance'], method_name)
            
        logger.debug(f"Retrieved {len(available_functions)} available functions")
        return available_functions

    def get_tool(self, tool_name: str) -> Dict[str, Any]:
        """Get a specific OpenAPI tool by function name.
        
        Args:
            tool_name: Name of the tool function
            
        Returns:
            Dict containing tool instance and schema, or empty dict if not found
            (a warning is logged in that case)
        """
        tool = self.tools.get(tool_name, {})
        if not tool:
            logger.warning(f"Tool not found: {tool_name}")
        return tool

    def get_xml_tool(self, tag_name: str) -> Dict[str, Any]:
        """Get tool info by XML tag name.
        
        Args:
            tag_name: XML tag name for the tool
            
        Returns:
            Dict containing tool instance, method name, and schema, or empty
            dict if no tool is registered for the tag (a warning is logged)
        """
        tool = self.xml_tools.get(tag_name, {})
        if not tool:
            logger.warning(f"XML tool not found for tag: {tag_name}")
        return tool

    def get_openapi_schemas(self) -> List[Dict[str, Any]]:
        """Get OpenAPI schemas for function calling.
        
        Returns:
            List of OpenAPI-compatible schema definitions, in registration order
        """
        # register_tool only stores OpenAPI schemas in self.tools; the type
        # filter is kept as a defensive guard against direct mutation.
        schemas = [
            tool_info['schema'].schema 
            for tool_info in self.tools.values()
            if tool_info['schema'].schema_type == SchemaType.OPENAPI
        ]
        logger.debug(f"Retrieved {len(schemas)} OpenAPI schemas")
        return schemas

    def get_xml_examples(self) -> Dict[str, str]:
        """Get all XML tag examples.
        
        Tags whose schema has no example are omitted.
        
        Returns:
            Dict mapping tag names to their example usage
        """
        examples = {}
        for tool_info in self.xml_tools.values():
            schema = tool_info['schema']
            if schema.xml_schema and schema.xml_schema.example:
                examples[schema.xml_schema.tag_name] = schema.xml_schema.example
        logger.debug(f"Retrieved {len(examples)} XML examples")
        return examples