Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
examples/sandbox/inner_tools/another_function_tools.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
|
4 |
+
import json
|
5 |
+
from typing import Dict, Any, List, Optional
|
6 |
+
|
7 |
+
from pydantic import Field
|
8 |
+
|
9 |
+
from aworld.tools import FunctionTools
|
10 |
+
|
11 |
+
# Create another function tool server with a different name
|
12 |
+
function = FunctionTools("another-server",
|
13 |
+
description="Another function tools server example")
|
14 |
+
|
15 |
+
@function.tool(description="Get weather information for a city")
|
16 |
+
def get_weather(
|
17 |
+
city: str = Field(
|
18 |
+
description="City name to get weather for"
|
19 |
+
),
|
20 |
+
days: int = Field(
|
21 |
+
3,
|
22 |
+
description="Number of days for forecast"
|
23 |
+
)
|
24 |
+
) -> Dict[str, Any]:
|
25 |
+
"""Get weather information for a city (simulated data)"""
|
26 |
+
# Simulated weather data
|
27 |
+
weather_types = ["Sunny", "Cloudy", "Rainy", "Windy", "Snowy"]
|
28 |
+
import random
|
29 |
+
|
30 |
+
forecast = []
|
31 |
+
for i in range(days):
|
32 |
+
forecast.append({
|
33 |
+
"date": f"2023-06-{i+1:02d}",
|
34 |
+
"weather": random.choice(weather_types),
|
35 |
+
"temperature": {
|
36 |
+
"min": random.randint(15, 25),
|
37 |
+
"max": random.randint(26, 35)
|
38 |
+
},
|
39 |
+
"humidity": random.randint(30, 90)
|
40 |
+
})
|
41 |
+
|
42 |
+
return {
|
43 |
+
"city": city,
|
44 |
+
"country": "Sample Country",
|
45 |
+
"forecast": forecast
|
46 |
+
}
|
47 |
+
|
48 |
+
@function.tool(description="Convert currency from one to another")
|
49 |
+
def convert_currency(
|
50 |
+
amount: float = Field(
|
51 |
+
description="Amount to convert"
|
52 |
+
),
|
53 |
+
from_currency: str = Field(
|
54 |
+
description="Source currency code (e.g. USD)"
|
55 |
+
),
|
56 |
+
to_currency: str = Field(
|
57 |
+
description="Target currency code (e.g. EUR)"
|
58 |
+
)
|
59 |
+
) -> Dict[str, Any]:
|
60 |
+
"""Currency conversion (simulated data)"""
|
61 |
+
# Simulated exchange rate data
|
62 |
+
rates = {
|
63 |
+
"USD": 1.0,
|
64 |
+
"EUR": 0.85,
|
65 |
+
"GBP": 0.75,
|
66 |
+
"JPY": 110.0,
|
67 |
+
"CNY": 6.5
|
68 |
+
}
|
69 |
+
|
70 |
+
# Check if currencies are supported
|
71 |
+
if from_currency not in rates:
|
72 |
+
return {"error": f"Currency {from_currency} not supported"}
|
73 |
+
if to_currency not in rates:
|
74 |
+
return {"error": f"Currency {to_currency} not supported"}
|
75 |
+
|
76 |
+
# Calculate conversion
|
77 |
+
usd_amount = amount / rates[from_currency]
|
78 |
+
converted_amount = usd_amount * rates[to_currency]
|
79 |
+
|
80 |
+
return {
|
81 |
+
"from": {
|
82 |
+
"currency": from_currency,
|
83 |
+
"amount": amount
|
84 |
+
},
|
85 |
+
"to": {
|
86 |
+
"currency": to_currency,
|
87 |
+
"amount": round(converted_amount, 2)
|
88 |
+
},
|
89 |
+
"rate": round(rates[to_currency] / rates[from_currency], 4)
|
90 |
+
}
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
# Test tools
|
94 |
+
print("=== Testing get_weather tool ===")
|
95 |
+
weather = function.call_tool("get_weather", {"city": "Beijing"})
|
96 |
+
print(json.dumps(weather, indent=2))
|
97 |
+
|
98 |
+
print("\n=== Testing convert_currency tool ===")
|
99 |
+
conversion = function.call_tool("convert_currency", {
|
100 |
+
"amount": 100,
|
101 |
+
"from_currency": "USD",
|
102 |
+
"to_currency": "EUR"
|
103 |
+
})
|
104 |
+
print(json.dumps(conversion, indent=2))
|
examples/sandbox/inner_tools/aworldsearch_function_tools.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
|
4 |
+
import asyncio
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
import pprint
|
9 |
+
from typing import List, Dict, Any, Optional, Union
|
10 |
+
|
11 |
+
import aiohttp
|
12 |
+
from mcp.types import TextContent
|
13 |
+
from pydantic import Field
|
14 |
+
|
15 |
+
from aworld.tools import FunctionTools
|
16 |
+
|
17 |
+
# Create function tools server
|
18 |
+
function = FunctionTools("aworldsearch_server",
|
19 |
+
description="Search service for AWorld")
|
20 |
+
|
21 |
+
async def search_single(query: str, num: int = 5) -> Optional[Dict[str, Any]]:
|
22 |
+
"""Execute a single search query, returns None on error"""
|
23 |
+
try:
|
24 |
+
url = os.getenv('AWORLD_SEARCH_URL')
|
25 |
+
searchMode = os.getenv('AWORLD_SEARCH_SEARCHMODE')
|
26 |
+
source = os.getenv('AWORLD_SEARCH_SOURCE')
|
27 |
+
domain = os.getenv('AWORLD_SEARCH_DOMAIN')
|
28 |
+
uid = os.getenv('AWORLD_SEARCH_UID')
|
29 |
+
if not url or not searchMode or not source or not domain:
|
30 |
+
logging.warning(f"Query failed: url, searchMode, source, domain parameters incomplete")
|
31 |
+
return None
|
32 |
+
|
33 |
+
headers = {
|
34 |
+
'Content-Type': 'application/json'
|
35 |
+
}
|
36 |
+
data = {
|
37 |
+
"domain": domain,
|
38 |
+
"extParams": {},
|
39 |
+
"page": 0,
|
40 |
+
"pageSize": num,
|
41 |
+
"query": query,
|
42 |
+
"searchMode": searchMode,
|
43 |
+
"source": source,
|
44 |
+
"userId": uid
|
45 |
+
}
|
46 |
+
|
47 |
+
async with aiohttp.ClientSession() as session:
|
48 |
+
try:
|
49 |
+
async with session.post(url, headers=headers, json=data) as response:
|
50 |
+
if response.status != 200:
|
51 |
+
logging.warning(f"Query failed: {query}, status code: {response.status}")
|
52 |
+
return None
|
53 |
+
|
54 |
+
result = await response.json()
|
55 |
+
return result
|
56 |
+
except aiohttp.ClientError:
|
57 |
+
logging.warning(f"Request error: {query}")
|
58 |
+
return None
|
59 |
+
except Exception:
|
60 |
+
logging.warning(f"Query exception: {query}")
|
61 |
+
return None
|
62 |
+
|
63 |
+
|
64 |
+
def filter_valid_docs(result: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
65 |
+
"""Filter valid document results, returns empty list if input is None"""
|
66 |
+
if result is None:
|
67 |
+
return []
|
68 |
+
|
69 |
+
try:
|
70 |
+
valid_docs = []
|
71 |
+
|
72 |
+
# Check success field
|
73 |
+
if not result.get("success"):
|
74 |
+
return valid_docs
|
75 |
+
|
76 |
+
# Check searchDocs field
|
77 |
+
search_docs = result.get("searchDocs", [])
|
78 |
+
if not search_docs:
|
79 |
+
return valid_docs
|
80 |
+
|
81 |
+
# Extract required fields
|
82 |
+
required_fields = ["title", "docAbstract", "url", "doc"]
|
83 |
+
|
84 |
+
for doc in search_docs:
|
85 |
+
# Check if all required fields exist and are non-empty
|
86 |
+
is_valid = True
|
87 |
+
for field in required_fields:
|
88 |
+
if field not in doc or not doc[field]:
|
89 |
+
is_valid = False
|
90 |
+
break
|
91 |
+
|
92 |
+
if is_valid:
|
93 |
+
# Only keep required fields
|
94 |
+
filtered_doc = {field: doc[field] for field in required_fields}
|
95 |
+
valid_docs.append(filtered_doc)
|
96 |
+
|
97 |
+
return valid_docs
|
98 |
+
except Exception:
|
99 |
+
return []
|
100 |
+
|
101 |
+
|
102 |
+
@function.tool(description="Search based on the user's input query list")
|
103 |
+
async def search(
|
104 |
+
query_list: List[str] = Field(
|
105 |
+
description="List format, queries to search for"
|
106 |
+
),
|
107 |
+
num: int = Field(
|
108 |
+
5,
|
109 |
+
description="Maximum number of results per query, default is 5, please keep the total results within 15"
|
110 |
+
)
|
111 |
+
) -> Union[str, TextContent]:
|
112 |
+
"""Execute main search function, supports single query or query list"""
|
113 |
+
try:
|
114 |
+
# Get configuration from environment variables
|
115 |
+
env_total_num = os.getenv('AWORLD_SEARCH_TOTAL_NUM')
|
116 |
+
if env_total_num and env_total_num.isdigit():
|
117 |
+
# Use environment variable to forcibly override the input num parameter
|
118 |
+
num = int(env_total_num)
|
119 |
+
|
120 |
+
# If no query is provided, return empty list
|
121 |
+
if not query_list:
|
122 |
+
# Initialize TextContent with additional parameters
|
123 |
+
return TextContent(
|
124 |
+
type="text",
|
125 |
+
text="", # Empty string instead of None
|
126 |
+
**{"metadata": {}} # Pass as additional field
|
127 |
+
)
|
128 |
+
|
129 |
+
# When query count >=3 or slice_num is set, use the corresponding value
|
130 |
+
slice_num = os.getenv('AWORLD_SEARCH_SLICE_NUM')
|
131 |
+
if slice_num and slice_num.isdigit():
|
132 |
+
actual_num = int(slice_num)
|
133 |
+
else:
|
134 |
+
actual_num = 2 if len(query_list) >= 3 else num
|
135 |
+
|
136 |
+
# Execute all queries in parallel
|
137 |
+
tasks = [search_single(q, actual_num) for q in query_list]
|
138 |
+
raw_results = await asyncio.gather(*tasks)
|
139 |
+
|
140 |
+
# Filter and merge results
|
141 |
+
all_valid_docs = []
|
142 |
+
for result in raw_results:
|
143 |
+
valid_docs = filter_valid_docs(result)
|
144 |
+
all_valid_docs.extend(valid_docs)
|
145 |
+
|
146 |
+
# If no valid results found, return empty list
|
147 |
+
if not all_valid_docs:
|
148 |
+
# Initialize TextContent with additional parameters
|
149 |
+
return TextContent(
|
150 |
+
type="text",
|
151 |
+
text="", # Empty string instead of None
|
152 |
+
**{"metadata": {}} # Pass as additional field
|
153 |
+
)
|
154 |
+
|
155 |
+
# Format results as JSON
|
156 |
+
result_json = json.dumps(all_valid_docs, ensure_ascii=False)
|
157 |
+
|
158 |
+
# Create dictionary structure directly
|
159 |
+
combined_query = ",".join(query_list)
|
160 |
+
|
161 |
+
search_items = []
|
162 |
+
# Use dictionary for URL deduplication
|
163 |
+
url_dict = {}
|
164 |
+
for doc in all_valid_docs:
|
165 |
+
url = doc.get("url", "")
|
166 |
+
if url not in url_dict:
|
167 |
+
url_dict[url] = {
|
168 |
+
"title": doc.get("title", ""),
|
169 |
+
"url": url,
|
170 |
+
"snippet": doc.get("doc", "")[:100] + "..." if len(doc.get("doc", "")) > 100 else doc.get("doc", ""),
|
171 |
+
"content": doc.get("doc", "") # Map doc field to content
|
172 |
+
}
|
173 |
+
|
174 |
+
# Convert dictionary values to list
|
175 |
+
search_items = list(url_dict.values())
|
176 |
+
|
177 |
+
search_output_dict = {
|
178 |
+
"artifact_type": "WEB_PAGES",
|
179 |
+
"artifact_data": {
|
180 |
+
"query": combined_query,
|
181 |
+
"results": search_items
|
182 |
+
}
|
183 |
+
}
|
184 |
+
|
185 |
+
# Log results
|
186 |
+
logging.info(f"Completed {len(query_list)} queries, found {len(all_valid_docs)} valid documents")
|
187 |
+
|
188 |
+
# Initialize TextContent with additional parameters
|
189 |
+
return TextContent(
|
190 |
+
type="text",
|
191 |
+
text=result_json,
|
192 |
+
**{"metadata": search_output_dict} # Pass processed data as metadata
|
193 |
+
)
|
194 |
+
except Exception as e:
|
195 |
+
# Handle errors
|
196 |
+
logging.error(f"Search error: {e}")
|
197 |
+
# Initialize TextContent with additional parameters
|
198 |
+
return TextContent(
|
199 |
+
type="text",
|
200 |
+
text="", # Empty string instead of None
|
201 |
+
**{"metadata": {}} # Pass as additional field
|
202 |
+
)
|
203 |
+
|
204 |
+
# Test code
|
205 |
+
if __name__ == "__main__":
|
206 |
+
import pprint
|
207 |
+
|
208 |
+
# Configure logging
|
209 |
+
logging.basicConfig(
|
210 |
+
level=logging.INFO,
|
211 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
212 |
+
)
|
213 |
+
|
214 |
+
# List all tools
|
215 |
+
print("Tool list:")
|
216 |
+
tools = function.list_tools()
|
217 |
+
print(tools)
|
218 |
+
res = function.call_tool("search", {"query_list": ["Tencent financial report", "Baidu financial report", "Alibaba financial report"],})
|
219 |
+
print(res)
|
220 |
+
# for tool in tools:
|
221 |
+
# print(f"Tool name: {tool.name}")
|
222 |
+
# print(f"Tool description: {tool.description}")
|
223 |
+
# print(f"Parameter schema: {tool.inputSchema}")
|
224 |
+
# if tool.annotations:
|
225 |
+
# print(f"Annotation information:")
|
226 |
+
# print(f" - Title: {tool.annotations.title}")
|
227 |
+
# print()
|