Spaces:
Sleeping
Sleeping
File size: 7,633 Bytes
7a3c0ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
from aworld.logs.util import logger
from aworld.trace.instrumentation.http_util import (
collect_attributes_from_request,
url_disabled,
get_excluded_urls,
parse_excluded_urls,
HTTP_RESPONSE_STATUS_CODE,
HTTP_FLAVOR
)
from aworld.metrics.context_manager import MetricContext
from aworld.metrics.template import MetricTemplate
from aworld.metrics.metric import MetricType
from aworld.trace.instrumentation import Instrumentor
from aworld.trace.propagator.carrier import DictCarrier
import functools
from timeit import default_timer
from requests import sessions
from requests.models import PreparedRequest, Response
from requests.structures import CaseInsensitiveDict
from typing import Collection, Any, Callable
from aworld.trace.base import TraceProvider, TraceContext, Tracer, SpanType, get_tracer_provider
from aworld.trace.propagator import get_global_trace_propagator
def _wrapped_send(
    tracer: Tracer = None,
    excluded_urls=None,
    request_hook: Callable = None,
    response_hook: Callable = None,
    duration_histogram: MetricTemplate = None
):
    """Build a traced replacement for ``requests.sessions.Session.send``.

    The returned function has the signature ``(session, request, **kwargs)``
    and is meant to be bound to a session instance (e.g. with
    ``functools.partial``). For each request whose URL is not excluded it
    opens a CLIENT span named after the upper-cased HTTP method, injects the
    trace context into the outgoing headers, delegates to the original
    ``Session.send`` and records response attributes, any exception and the
    elapsed send time.

    Args:
        tracer: Tracer used to start the client span.
        excluded_urls: Matcher passed to ``url_disabled``; when falsy no URL
            is excluded.
        request_hook: Optional ``hook(span, request)`` called before sending.
        response_hook: Optional ``hook(span, request, response)`` called when
            a ``Response`` object is available.
        duration_histogram: Optional histogram template; when set and metrics
            are initialized, the elapsed send time (``default_timer``
            seconds) is recorded into it.

    Returns:
        The wrapper function. It returns whatever the original ``send``
        returned and re-raises any ``Exception`` the original raised.
    """
    # Captured once, when the wrapper is built, so the wrapper always
    # delegates to the class-level ``send`` implementation.
    original_send = sessions.Session.send

    @functools.wraps(original_send)
    def instrumented_send(
        self: sessions.Session, request: PreparedRequest, **kwargs: Any
    ):
        if excluded_urls and url_disabled(request.url, excluded_urls):
            return original_send(self, request, **kwargs)

        def get_or_create_headers():
            # The propagator needs a mutable mapping to inject into.
            if request.headers is None:
                request.headers = CaseInsensitiveDict()
            return request.headers

        method = request.method
        if method is None:
            method = "HTTP"
        span_name = method.upper()
        span_attributes = collect_attributes_from_request(request)

        with tracer.start_as_current_span(
            span_name, span_type=SpanType.CLIENT, attributes=span_attributes
        ) as span:
            exception = None
            if callable(request_hook):
                request_hook(span, request)

            headers = get_or_create_headers()
            trace_context = TraceContext(
                trace_id=span.get_trace_id(),
                span_id=span.get_span_id(),
            )
            propagator = get_global_trace_propagator()
            if propagator:
                propagator.inject(trace_context, DictCarrier(headers))

            start_time = default_timer()
            try:
                # Only header *names* are logged: values such as
                # Authorization or Cookie may carry credentials.
                logger.debug("Sending headers: %s", list(headers.keys()))
                result = original_send(
                    self, request, **kwargs
                )  # *** PROCEED
            except Exception as exc:  # pylint: disable=W0703
                exception = exc
                # requests attaches the response (if any) to its exceptions.
                result = getattr(exc, "response", None)
            finally:
                elapsed_time = max(default_timer() - start_time, 0)

            if isinstance(result, Response):
                # NOTE(review): on success the histogram below is labelled
                # with these response attributes only, while without a
                # Response it keeps the request attributes - confirm the
                # differing label sets are intended.
                span_attributes = {
                    HTTP_RESPONSE_STATUS_CODE: result.status_code,
                }
                if result.raw is not None:
                    version = getattr(result.raw, "version", None)
                    if version:
                        # Only HTTP/1 is supported by requests
                        span_attributes[HTTP_FLAVOR] = (
                            "1.1" if version == 11 else "1.0"
                        )
                span.set_attributes(span_attributes)
                if callable(response_hook):
                    response_hook(span, request, result)

            if exception is not None:
                span.record_exception(exception)

            if (duration_histogram is not None
                    and MetricContext.metric_initialized()):
                MetricContext.histogram_record(
                    duration_histogram,
                    elapsed_time,
                    span_attributes
                )

        if exception is not None:
            # Re-raise after the span is closed; the original traceback is
            # kept on the exception object.
            raise exception
        return result

    return instrumented_send
class _InstrumentedSession(sessions.Session):
    """
    An instrumented requests.Session class.

    Each instance replaces its ``send`` with a traced wrapper produced by
    ``_wrapped_send``. The class attributes below hold the configuration set
    by ``RequestsInstrumentor._instrument`` before this class is installed
    as ``requests.sessions.Session``.
    """

    # Excluded-URL matcher forwarded to the send wrapper (instrumentor-wide).
    _excluded_urls = None
    # Provider used to create this session's tracer.
    _tracer_provider: TraceProvider = None
    # Optional hook(span, request) run before each send.
    _request_hook = None
    # Optional hook(span, request, response) run after each send.
    _response_hook = None

    def __init__(self, *args, **kwargs):
        # requests.Session.__init__ accepts no extra keyword arguments, so a
        # per-session override has to be removed before delegating.
        excluded_urls = kwargs.pop("excluded_urls", None)
        if excluded_urls is None:
            # Fall back to the configuration installed by the instrumentor.
            excluded_urls = type(self)._excluded_urls
        super().__init__(*args, **kwargs)

        tracer_provider = type(self)._tracer_provider or get_tracer_provider()
        tracer = tracer_provider.get_tracer(
            "aworld.trace.instrumentation.requests")
        duration_histogram = MetricTemplate(
            type=MetricType.HISTOGRAM,
            name="client_request_duration_histogram",
            unit="s",
            description="Duration of HTTP client requests."
        )
        # Hooks are read through the class, not the instance: a plain
        # function stored as a class attribute would otherwise be bound to
        # this session and receive it as an unexpected first argument.
        cls = type(self)
        self.send = functools.partial(_wrapped_send(
            tracer=tracer,
            excluded_urls=excluded_urls,
            request_hook=cls._request_hook,
            response_hook=cls._response_hook,
            duration_histogram=duration_histogram
        ), self)
class RequestsInstrumentor(Instrumentor):
    """
    An instrumentor for the requests module.

    Instrumenting replaces ``requests.sessions.Session`` with
    ``_InstrumentedSession``; uninstrumenting puts the previous class back.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        return ["requests"]

    def _instrument(self, **kwargs):
        """
        Instruments the requests module.

        Keyword Args:
            request_hook: Optional ``hook(span, request)``.
            response_hook: Optional ``hook(span, request, response)``.
            tracer_provider: Provider used to create per-session tracers.
            excluded_urls: Comma-separated URLs to skip. When None, the list
                comes from ``get_excluded_urls("REQUESTS")``.
        """
        logger.info("requests _instrument entered.")
        original_session = sessions.Session
        # Never record our own subclass as the original; otherwise a repeated
        # instrument() call would make _uninstrument() a no-op.
        if original_session is _InstrumentedSession:
            original_session = _InstrumentedSession.__base__
        self._original_session = original_session

        request_hook = kwargs.get("request_hook")
        response_hook = kwargs.get("response_hook")
        # staticmethod stops the hooks from being bound to the session when
        # read through an instance. Non-callables reset the hook so stale
        # hooks from an earlier instrument() call do not linger.
        _InstrumentedSession._request_hook = (
            staticmethod(request_hook) if callable(request_hook) else None
        )
        _InstrumentedSession._response_hook = (
            staticmethod(response_hook) if callable(response_hook) else None
        )

        _InstrumentedSession._tracer_provider = kwargs.get("tracer_provider")

        excluded_urls = kwargs.get("excluded_urls")
        _InstrumentedSession._excluded_urls = (
            get_excluded_urls("REQUESTS")
            if excluded_urls is None
            else parse_excluded_urls(excluded_urls)
        )
        sessions.Session = _InstrumentedSession
        logger.info("requests _instrument exited.")

    def _uninstrument(self, **kwargs):
        """
        Uninstruments the requests module.

        Restores the Session class recorded by ``_instrument``. If this
        instance never instrumented, falls back to the class
        ``_InstrumentedSession`` was derived from.
        """
        sessions.Session = getattr(
            self, "_original_session", _InstrumentedSession.__base__
        )
def instrument_requests(excluded_urls: str = None,
                        request_hook: Callable = None,
                        response_hook: Callable = None,
                        tracer_provider: TraceProvider = None,
                        **kwargs: Any,
                        ):
    """
    Instruments the requests module.

    Args:
        excluded_urls: A comma separated list of URLs to exclude from tracing.
            When None, the instrumentor reads the exclusion list from its
            environment-based default instead.
        request_hook: A function that will be called before a request is sent.
            The function will be called with the span and the request.
        response_hook: A function that will be called after a response is
            received. The function will be called with the span, the request
            and the response.
        tracer_provider: The tracer provider to use. If not provided, the
            global tracer provider will be used.
        kwargs: Additional keyword arguments forwarded to
            ``RequestsInstrumentor().instrument``.
    """
    all_kwargs = {
        "excluded_urls": excluded_urls,
        "request_hook": request_hook,
        "response_hook": response_hook,
        "tracer_provider": tracer_provider or get_tracer_provider(),
        **kwargs
    }
    RequestsInstrumentor().instrument(**all_kwargs)
    logger.info("Requests instrumented.")
|