Duibonduil commited on
Commit
bf4ab00
·
verified ·
1 Parent(s): a16a346

Upload 2 files

Browse files
aworld/trace/propagator/carrier.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypeVar
2
+ from aworld.trace.base import Carrier
3
+ from aworld.logs.util import logger
4
+
5
+ T = TypeVar("T")
6
+
7
+
8
+ class ListTupleCarrier(Carrier):
9
+
10
+ def __init__(self, headers: list[tuple[str, T]]):
11
+ self.headers = headers
12
+
13
+ def get(self, key: str) -> T:
14
+ for header, value in self.headers:
15
+ header_str = header.decode(
16
+ 'utf-8') if isinstance(header, bytes) else header
17
+ key_str = key.decode('utf-8') if isinstance(key, bytes) else key
18
+ if header_str.lower() == key_str.lower():
19
+ return value.decode('utf-8') if isinstance(value, bytes) else value
20
+ return None
21
+
22
+ def set(self, key: str, value: T) -> None:
23
+ for i, (header, _) in enumerate(self.headers):
24
+ header_str = header.decode(
25
+ 'utf-8') if isinstance(header, bytes) else header
26
+ key_str = key.decode('utf-8') if isinstance(key, bytes) else key
27
+ if header_str.lower() == key_str.lower():
28
+ self.headers[i] = (header, value)
29
+ return
30
+ self.headers.append((key, value))
31
+
32
+
33
+ class DictCarrier(Carrier):
34
+ def __init__(self, headers: dict[str, T]):
35
+ self.headers = headers
36
+
37
+ def get(self, key: str) -> T:
38
+ return self.headers.get(key)
39
+
40
+ def set(self, key: str, value: T) -> None:
41
+ logger.info(f"set header {key}={value}")
42
+ self.headers[key] = value
aworld/trace/propagator/w3c.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Tuple, List
3
+ from aworld.logs.util import logger
4
+ from aworld.trace.base import Propagator, Carrier, TraceContext
5
+
6
+
7
+ class W3CTraceContextPropagator(Propagator):
8
+ """
9
+ OtelPropagator is a Propagator that extracts and injects using w3c TraceContext's headers.
10
+ carrier = {
11
+ "traceparent": "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01",
12
+ "tracestate": "congo=t61rcWkgMzE",
13
+ "baggage": "key1=value1,key2=value2"
14
+ }
15
+ """
16
+ _STATE_KEY_FORMAT = (
17
+ r"[a-z][_0-9a-z\-\*\/]{0,255}|"
18
+ r"[a-z0-9][_0-9a-z\-\*\/]{0,240}@[a-z][_0-9a-z\-\*\/]{0,13}"
19
+ )
20
+ _STATE_VALUE_FORMAT = (
21
+ r"[\x20-\x2b\x2d-\x3c\x3e-\x7e]{0,255}[\x21-\x2b\x2d-\x3c\x3e-\x7e]"
22
+ )
23
+ _state_delimiter_pattern = re.compile(r"[ \t]*,[ \t]*")
24
+ _state_member_pattern = re.compile(
25
+ f"({_STATE_KEY_FORMAT})(=)({_STATE_VALUE_FORMAT})[ \t]*")
26
+
27
+ _TRACEPARENT_HEADER_NAME = "traceparent"
28
+ _TRACESTATE_HEADER_NAME = "tracestate"
29
+ _TRACEPARENT_HEADER_FORMAT = (
30
+ "^[ \t]*([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})"
31
+ + "(-.*)?[ \t]*$"
32
+ )
33
+ _TRACEPARENT_HEADER_FORMAT_RE = re.compile(_TRACEPARENT_HEADER_FORMAT)
34
+
35
+ def extract(self, carrier: Carrier) -> TraceContext:
36
+ """
37
+ Extract trace context from carrier.
38
+ Args:
39
+ carrier: The carrier to extract trace context from.
40
+ Returns:
41
+ A dict of trace context.
42
+ """
43
+ header = carrier.get(self._TRACEPARENT_HEADER_NAME) or carrier.get(
44
+ 'HTTP_' + self._TRACEPARENT_HEADER_NAME.upper())
45
+
46
+ if header is None:
47
+ return None
48
+
49
+ match = re.search(self._TRACEPARENT_HEADER_FORMAT_RE, header)
50
+ if not match:
51
+ return None
52
+
53
+ version: str = match.group(1)
54
+ trace_id: str = match.group(2)
55
+ span_id: str = match.group(3)
56
+ trace_flags: str = match.group(4)
57
+
58
+ logger.info(
59
+ f"extract trace_id: {trace_id}, span_id: {span_id}, trace_flags: {trace_flags}, version: {version}")
60
+
61
+ if trace_id == "0" * 32 or span_id == "0" * 16:
62
+ return None
63
+ if version == "00":
64
+ if match.group(5): # type: ignore
65
+ return None
66
+ if version == "ff":
67
+ return None
68
+
69
+ state_header = carrier.get(self._TRACESTATE_HEADER_NAME) or carrier.get(
70
+ 'HTTP_' + self._TRACESTATE_HEADER_NAME.upper())
71
+ return TraceContext(
72
+ trace_id=trace_id,
73
+ span_id=span_id,
74
+ trace_flags=trace_flags,
75
+ version=version,
76
+ attributes=(self._extract_state_from_header(state_header))
77
+ )
78
+
79
+ def inject(self, trace_context: TraceContext, carrier: Carrier) -> None:
80
+ """
81
+ Inject trace context into carrier.
82
+ Args:
83
+ context: The trace context to inject.
84
+ carrier: The carrier to inject trace context into.
85
+ """
86
+ attribute_copy = trace_context.attributes.copy()
87
+ version: str = trace_context.version
88
+ trace_flags: str = trace_context.trace_flags
89
+ trace_id = trace_context.trace_id
90
+ span_id = trace_context.span_id
91
+ logger.info(
92
+ f"inject trace_id: {trace_id}, span_id: {span_id}, trace_flags: {trace_flags}, version: {version}")
93
+ if (not trace_id or trace_id == "0" * 32
94
+ or not span_id or span_id == "0" * 16):
95
+ return
96
+
97
+ if isinstance(trace_id, int):
98
+ trace_id = format(trace_id, "032x")
99
+ if isinstance(span_id, int):
100
+ span_id = format(span_id, "016x")
101
+ traceparent_string = f"{version}-{trace_id}-{span_id}-{trace_flags}"
102
+ carrier.set(self._TRACEPARENT_HEADER_NAME, traceparent_string)
103
+ tracestate_string = ",".join(
104
+ f"{key}={value}" for key, value in attribute_copy.items())
105
+ if tracestate_string:
106
+ carrier.set(self._TRACESTATE_HEADER_NAME, tracestate_string)
107
+
108
+ def _extract_state_from_header(self, header: str) -> dict:
109
+ """
110
+ Extract state from header.
111
+ Args:
112
+ header: The header to extract state from.
113
+ Returns:
114
+ A dict of state.
115
+ """
116
+ if header is None:
117
+ return {}
118
+ state = {}
119
+ members: List[str] = re.split(self._state_delimiter_pattern, header)
120
+ for member in members:
121
+ # empty members are valid, but no need to process further.
122
+ if not member:
123
+ continue
124
+ match = self._state_member_pattern.fullmatch(member)
125
+ if not match:
126
+ logger.warning(
127
+ "Member doesn't match the w3c identifiers format {member}")
128
+ return state
129
+ groups: Tuple[str, ...] = match.groups()
130
+ key, _eq, value = groups
131
+ # duplicate keys are not legal in header
132
+ if key in state:
133
+ return state
134
+ state[key] = value
135
+ return state