File size: 5,023 Bytes
bf4ab00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import re
from typing import List, Optional, Tuple

from aworld.logs.util import logger
from aworld.trace.base import Propagator, Carrier, TraceContext


class W3CTraceContextPropagator(Propagator):
    """
    Propagator that extracts and injects trace context using the W3C Trace
    Context ``traceparent`` and ``tracestate`` headers.

    Example carrier contents:
        carrier = {
            "traceparent": "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01",
            "tracestate": "congo=t61rcWkgMzE",
        }

    On extraction each header is looked up under its plain name first and
    then under an ``HTTP_<UPPERCASE NAME>`` key (e.g. ``HTTP_TRACEPARENT``).
    Baggage is not read or written by this propagator.
    """
    # tracestate list-member key: either a simple key, or a multi-tenant
    # "<tenant>@<system>" key.
    _STATE_KEY_FORMAT = (
        r"[a-z][_0-9a-z\-\*\/]{0,255}|"
        r"[a-z0-9][_0-9a-z\-\*\/]{0,240}@[a-z][_0-9a-z\-\*\/]{0,13}"
    )
    # tracestate list-member value: printable ASCII except ',' and '=',
    # 1-256 chars, and not ending in a space.
    _STATE_VALUE_FORMAT = (
        r"[\x20-\x2b\x2d-\x3c\x3e-\x7e]{0,255}[\x21-\x2b\x2d-\x3c\x3e-\x7e]"
    )
    # List members are separated by ',' with optional surrounding spaces/tabs.
    _state_delimiter_pattern = re.compile(r"[ \t]*,[ \t]*")
    _state_member_pattern = re.compile(
        f"({_STATE_KEY_FORMAT})(=)({_STATE_VALUE_FORMAT})[ \t]*")

    _TRACEPARENT_HEADER_NAME = "traceparent"
    _TRACESTATE_HEADER_NAME = "tracestate"
    # Groups: 1=version, 2=trace-id, 3=parent(span)-id, 4=trace-flags,
    # 5=optional trailing data (only permitted for versions other than 00).
    _TRACEPARENT_HEADER_FORMAT = (
        "^[ \t]*([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})"
        + "(-.*)?[ \t]*$"
    )
    _TRACEPARENT_HEADER_FORMAT_RE = re.compile(_TRACEPARENT_HEADER_FORMAT)

    def extract(self, carrier: Carrier) -> Optional[TraceContext]:
        """
        Extract trace context from carrier.

        Args:
            carrier: The carrier to read the ``traceparent`` and
                ``tracestate`` headers from.
        Returns:
            A TraceContext built from the headers, or None when the
            ``traceparent`` header is missing or invalid: it does not match
            the expected format, has an all-zero trace id or span id, uses
            version ``ff``, or is version ``00`` with trailing data.
        """
        header = self._get_header(carrier, self._TRACEPARENT_HEADER_NAME)
        if header is None:
            return None

        match = self._TRACEPARENT_HEADER_FORMAT_RE.search(header)
        if not match:
            return None

        version, trace_id, span_id, trace_flags = match.group(1, 2, 3, 4)

        logger.info(
            "extract trace_id: %s, span_id: %s, trace_flags: %s, version: %s",
            trace_id, span_id, trace_flags, version)

        # All-zero ids are explicitly invalid.
        if trace_id == "0" * 32 or span_id == "0" * 16:
            return None
        # Version 00 defines the exact header layout; trailing data is invalid.
        if version == "00" and match.group(5):
            return None
        # Version ff is reserved as invalid.
        if version == "ff":
            return None

        state_header = self._get_header(carrier, self._TRACESTATE_HEADER_NAME)
        return TraceContext(
            trace_id=trace_id,
            span_id=span_id,
            trace_flags=trace_flags,
            version=version,
            attributes=self._extract_state_from_header(state_header)
        )

    def inject(self, trace_context: TraceContext, carrier: Carrier) -> None:
        """
        Inject trace context into carrier.

        Args:
            trace_context: The trace context to inject.
            carrier: The carrier to write headers into via ``carrier.set``.

        Nothing is written when the trace id or span id is empty or all
        zeros. Integer ids are rendered as zero-padded lowercase hex. The
        ``tracestate`` header is written only when the context has at least
        one attribute; each attribute becomes a ``key=value`` list member.
        """
        # Snapshot the attributes up front; tolerate a context whose
        # attributes are None instead of raising AttributeError.
        attributes = dict(trace_context.attributes or {})
        version: str = trace_context.version
        trace_flags: str = trace_context.trace_flags
        trace_id = trace_context.trace_id
        span_id = trace_context.span_id
        logger.info(
            "inject trace_id: %s, span_id: %s, trace_flags: %s, version: %s",
            trace_id, span_id, trace_flags, version)
        if (not trace_id or trace_id == "0" * 32
                or not span_id or span_id == "0" * 16):
            return

        if isinstance(trace_id, int):
            trace_id = format(trace_id, "032x")
        if isinstance(span_id, int):
            span_id = format(span_id, "016x")
        traceparent_string = f"{version}-{trace_id}-{span_id}-{trace_flags}"
        carrier.set(self._TRACEPARENT_HEADER_NAME, traceparent_string)
        tracestate_string = ",".join(
            f"{key}={value}" for key, value in attributes.items())
        if tracestate_string:
            carrier.set(self._TRACESTATE_HEADER_NAME, tracestate_string)

    @staticmethod
    def _get_header(carrier: Carrier, name: str):
        """Return header ``name`` from carrier, falling back to ``HTTP_<NAME>``."""
        return carrier.get(name) or carrier.get('HTTP_' + name.upper())

    def _extract_state_from_header(self, header: Optional[str]) -> dict:
        """
        Parse a ``tracestate`` header into an ordered dict of list members.

        Args:
            header: The raw ``tracestate`` header value, or None.
        Returns:
            A dict mapping each member key to its value. It is empty when the
            header is None, or when any member is malformed or a key is
            repeated. An unparseable tracestate is discarded as a whole
            rather than partially propagated.
        """
        if header is None:
            return {}
        state = {}
        # Strip optional surrounding whitespace so a leading space does not
        # make the first member fail to match.
        members: List[str] = self._state_delimiter_pattern.split(
            header.strip(" \t"))
        for member in members:
            # empty members are valid, but no need to process further.
            if not member:
                continue
            match = self._state_member_pattern.fullmatch(member)
            if not match:
                logger.warning(
                    "Member doesn't match the w3c identifiers format %s",
                    member)
                return {}
            groups: Tuple[str, ...] = match.groups()
            key, _eq, value = groups
            # duplicate keys are not legal in header; discard the whole state.
            if key in state:
                return {}
            state[key] = value
        return state