import re
from typing import List, Optional, Tuple

from aworld.logs.util import logger
from aworld.trace.base import Propagator, Carrier, TraceContext


class W3CTraceContextPropagator(Propagator):
    """
    Propagator that extracts and injects trace context using the W3C
    Trace Context headers ``traceparent`` and ``tracestate``.

    Example carrier content handled by this class:

    carrier = {
        "traceparent": "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01",
        "tracestate": "congo=t61rcWkgMzE"
    }

    Header names are also looked up in their WSGI-style form
    (``HTTP_TRACEPARENT`` / ``HTTP_TRACESTATE``). Other headers such as
    ``baggage`` are not read or written by this class.
    """

    # A tracestate key is either a simple key or a "tenant@system" key.
    _STATE_KEY_FORMAT = (
        r"[a-z][_0-9a-z\-\*\/]{0,255}|"
        r"[a-z0-9][_0-9a-z\-\*\/]{0,240}@[a-z][_0-9a-z\-\*\/]{0,13}"
    )
    # Printable ASCII excluding "," and "="; the last character must not be a space.
    _STATE_VALUE_FORMAT = (
        r"[\x20-\x2b\x2d-\x3c\x3e-\x7e]{0,255}[\x21-\x2b\x2d-\x3c\x3e-\x7e]"
    )

    # Members are comma separated, with optional whitespace around the comma.
    _state_delimiter_pattern = re.compile(r"[ \t]*,[ \t]*")
    _state_member_pattern = re.compile(
        f"({_STATE_KEY_FORMAT})(=)({_STATE_VALUE_FORMAT})[ \t]*")

    _TRACEPARENT_HEADER_NAME = "traceparent"
    _TRACESTATE_HEADER_NAME = "tracestate"
    # Groups: version, trace-id, span-id, trace-flags, optional trailing data.
    _TRACEPARENT_HEADER_FORMAT = (
        "^[ \t]*([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})"
        + "(-.*)?[ \t]*$"
    )
    _TRACEPARENT_HEADER_FORMAT_RE = re.compile(_TRACEPARENT_HEADER_FORMAT)

    def _get_header(self, carrier: Carrier, name: str) -> Optional[str]:
        """Return header ``name`` from the carrier, falling back to ``HTTP_<NAME>``."""
        return carrier.get(name) or carrier.get('HTTP_' + name.upper())

    def extract(self, carrier: Carrier) -> Optional[TraceContext]:
        """
        Extract trace context from carrier.

        Args:
            carrier: The carrier to extract trace context from.

        Returns:
            A TraceContext built from the ``traceparent`` / ``tracestate``
            headers, or None when ``traceparent`` is missing, malformed,
            carries an all-zero trace id or span id, uses version ``ff``,
            or uses version ``00`` with trailing data.
        """
        header = self._get_header(carrier, self._TRACEPARENT_HEADER_NAME)
        if not header:
            return None

        match = self._TRACEPARENT_HEADER_FORMAT_RE.search(header)
        if not match:
            return None

        version: str = match.group(1)
        trace_id: str = match.group(2)
        span_id: str = match.group(3)
        trace_flags: str = match.group(4)
        logger.info(
            f"extract trace_id: {trace_id}, span_id: {span_id}, trace_flags: {trace_flags}, version: {version}")

        # All-zero identifiers are rejected as invalid.
        if trace_id == "0" * 32 or span_id == "0" * 16:
            return None

        # Version 00 does not allow anything after the trace-flags field.
        if version == "00" and match.group(5):
            return None

        # Version ff is rejected.
        if version == "ff":
            return None

        state_header = self._get_header(carrier, self._TRACESTATE_HEADER_NAME)
        return TraceContext(
            trace_id=trace_id,
            span_id=span_id,
            trace_flags=trace_flags,
            version=version,
            attributes=self._extract_state_from_header(state_header),
        )

    def inject(self, trace_context: TraceContext, carrier: Carrier) -> None:
        """
        Inject trace context into carrier.

        Nothing is written when the trace id or span id is missing or all
        zeros. ``tracestate`` is only written when the context has attributes.

        Args:
            trace_context: The trace context to inject.
            carrier: The carrier to inject trace context into.
        """
        # Snapshot the attributes; a context without attributes yields no tracestate.
        attributes = trace_context.attributes
        attribute_copy = attributes.copy() if attributes else {}

        version: str = trace_context.version
        trace_flags: str = trace_context.trace_flags
        trace_id = trace_context.trace_id
        span_id = trace_context.span_id
        logger.info(
            f"inject trace_id: {trace_id}, span_id: {span_id}, trace_flags: {trace_flags}, version: {version}")

        if (not trace_id or trace_id == "0" * 32
                or not span_id or span_id == "0" * 16):
            return

        # Integer ids are rendered as fixed-width lowercase hex.
        if isinstance(trace_id, int):
            trace_id = format(trace_id, "032x")
        if isinstance(span_id, int):
            span_id = format(span_id, "016x")

        traceparent_string = f"{version}-{trace_id}-{span_id}-{trace_flags}"
        carrier.set(self._TRACEPARENT_HEADER_NAME, traceparent_string)

        tracestate_string = ",".join(
            f"{key}={value}" for key, value in attribute_copy.items())
        if tracestate_string:
            carrier.set(self._TRACESTATE_HEADER_NAME, tracestate_string)

    def _extract_state_from_header(self, header: Optional[str]) -> dict:
        """
        Extract state from header.

        Args:
            header: The raw ``tracestate`` header value, or None.

        Returns:
            A dict mapping each parsed key to its value. Parsing stops at the
            first malformed member or duplicate key, and the members parsed
            up to that point are returned.
        """
        if not header:
            return {}

        state: dict = {}
        members: List[str] = self._state_delimiter_pattern.split(header)
        for member in members:
            # empty members are valid, but no need to process further.
            if not member:
                continue
            match = self._state_member_pattern.fullmatch(member)
            if not match:
                logger.warning(
                    f"Member doesn't match the w3c identifiers format {member}")
                # NOTE(review): the W3C spec says an unparsable tracestate
                # should be discarded entirely; the partial result is kept
                # here to preserve existing behavior - confirm intent.
                return state
            groups: Tuple[str, ...] = match.groups()
            key, _eq, value = groups
            # duplicate keys are not legal in header
            if key in state:
                return state
            state[key] = value
        return state