Duibonduil commited on
Commit
9c9d7c5
·
verified ·
1 Parent(s): 3d06c91

Upload trace_tool.py

Browse files
Files changed (1) hide show
  1. examples/tools/trace/trace_tool.py +155 -0
examples/tools/trace/trace_tool.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import aworld.trace as trace
2
+ import aworld.trace.instrumentation.semconv as semconv
3
+ from aworld.trace.server import get_trace_server
4
+ from aworld.trace.server.util import build_trace_tree
5
+ from aworld.core.tool.base import Tool, AgentInput, ToolFactory
6
+ from examples.tools.tool_action import GetTraceAction
7
+ from aworld.tools.utils import build_observation
8
+ from aworld.config.conf import ToolConfig
9
+ from aworld.core.common import Observation, ActionModel, ActionResult
10
+ from typing import Tuple, Dict, Any, List
11
+ from aworld.logs.util import logger
12
+
13
+
14
+ @ToolFactory.register(name="trace",
15
+ desc="Get the trace of the current execution.",
16
+ supported_action=GetTraceAction,
17
+ conf_file_name=f'trace_tool.yaml')
18
+ class TraceTool(Tool):
19
+ def __init__(self,
20
+ conf: ToolConfig,
21
+ **kwargs) -> None:
22
+ """
23
+ Initialize the TraceTool
24
+ Args:
25
+ conf: tool config
26
+ **kwargs: -
27
+ Return:
28
+ None
29
+ """
30
+ super(TraceTool, self).__init__(conf, **kwargs)
31
+ self.type = "function"
32
+ self.get_trace_url = self.conf.get('get_trace_url')
33
+
34
+ def reset(self,
35
+ *,
36
+ seed: int | None = None,
37
+ options: Dict[str, str] | None = None) -> Tuple[AgentInput, dict[str, Any]]:
38
+ """
39
+ Reset the executor
40
+ Args:
41
+ seed: -
42
+ options: -
43
+ Returns:
44
+ AgentInput, dict[str, Any]: -
45
+ """
46
+ self._finished = False
47
+ return build_observation(observer=self.name(),
48
+ ability=GetTraceAction.GET_TRACE.value.name), {}
49
+
50
+ def close(self) -> None:
51
+ """
52
+ Close the executor
53
+ Returns:
54
+ None
55
+ """
56
+ self._finished = True
57
+
58
+ def do_step(self,
59
+ actions: List[ActionModel],
60
+ **kwargs) -> Tuple[Observation, float, bool, bool, dict[str, Any]]:
61
+ reward = 0
62
+ fail_error = ""
63
+ observation = build_observation(observer=self.name(),
64
+ ability=GetTraceAction.GET_TRACE.value.name)
65
+ results = []
66
+ try:
67
+ if not actions:
68
+ return (observation, reward,
69
+ kwargs.get("terminated",
70
+ False), kwargs.get("truncated", False), {
71
+ "exception": "actions is empty"
72
+ })
73
+ for action in actions:
74
+ trace_id = action.params.get("trace_id", "")
75
+ if not trace_id:
76
+ current_span = trace.get_current_span()
77
+ if current_span:
78
+ trace_id = current_span.get_trace_id()
79
+ if not trace_id:
80
+ logger.warning(f"{action} no trace_id to fetch.")
81
+ observation.action_result.append(
82
+ ActionResult(is_done=True,
83
+ success=False,
84
+ content="",
85
+ error="no trace_id to fetch",
86
+ keep=False))
87
+ continue
88
+ try:
89
+ trace_data = self.fetch_trace_data(trace_id)
90
+ logger.info(f"trace_data={trace_data}")
91
+ error = ""
92
+ except Exception as e:
93
+ error = str(e)
94
+ results.append(trace_data)
95
+ observation.action_result.append(
96
+ ActionResult(is_done=True,
97
+ success=False if error else True,
98
+ content=f"{trace_data}",
99
+ error=f"{error}",
100
+ keep=False))
101
+
102
+ observation.content = f"{results}"
103
+ reward = 1
104
+ except Exception as e:
105
+ fail_error = str(e)
106
+ finally:
107
+ self._finished = True
108
+
109
+ info = {"exception": fail_error}
110
+ info.update(kwargs)
111
+ return (observation, reward, kwargs.get("terminated", False),
112
+ kwargs.get("truncated", False), info)
113
+
114
+ def fetch_trace_data(self, trace_id=None):
115
+ '''
116
+ fetch trace data from trace server.
117
+ return trace data, like:
118
+ {
119
+ 'trace_id': trace_id,
120
+ 'root_span': [],
121
+ }
122
+ '''
123
+ try:
124
+ if trace_id:
125
+ trace_server = get_trace_server()
126
+ if not trace_server:
127
+ logger.error("No memory trace server has been set.")
128
+ else:
129
+ trace_storage = trace_server.get_storage()
130
+ spans = trace_storage.get_all_spans(trace_id)
131
+ if spans:
132
+ return self.proccess_trace(build_trace_tree(spans))
133
+ return {"trace_id": trace_id, "root_span": []}
134
+ except Exception as e:
135
+ logger.error(f"Error fetching trace data: {e}")
136
+ return {"trace_id": trace_id, "root_span": []}
137
+
138
+ def proccess_trace(self, trace_data):
139
+ root_spans = trace_data.get("root_span")
140
+ for span in root_spans:
141
+ self.choose_attribute(span)
142
+ return trace_data
143
+
144
+ def choose_attribute(self, span):
145
+ include_attr = [semconv.GEN_AI_USAGE_INPUT_TOKENS,
146
+ semconv.GEN_AI_USAGE_OUTPUT_TOKENS, semconv.GEN_AI_USAGE_TOTAL_TOKENS]
147
+ result_attributes = {}
148
+ origin_attributes = span.get("attributes") or {}
149
+ for key, value in origin_attributes.items():
150
+ if key in include_attr:
151
+ result_attributes[key] = value
152
+ span["attributes"] = result_attributes
153
+ if span.get("children"):
154
+ for child in span.get("children"):
155
+ self.choose_attribute(child)