Spaces:
Sleeping
Sleeping
File size: 9,766 Bytes
7a3c0ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
import flask
import weakref
from typing import Any, Callable, Collection
from time import time_ns
from timeit import default_timer
from importlib_metadata import version
from packaging import version as package_version
from aworld.trace.instrumentation import Instrumentor
from aworld.trace.base import Span, TraceProvider, TraceContext, Tracer, SpanType, get_tracer_provider
from aworld.metrics.metric import MetricType
from aworld.metrics.template import MetricTemplate
from aworld.logs.util import logger
from aworld.trace.instrumentation.http_util import (
collect_request_attributes,
url_disabled,
get_excluded_urls,
parse_excluded_urls,
HTTP_ROUTE
)
from aworld.trace.propagator import get_global_trace_propagator
from aworld.metrics.context_manager import MetricContext
from aworld.trace.propagator.carrier import ListTupleCarrier, DictCarrier
# Keys under which per-request instrumentation state is stashed in the WSGI
# environ dict: the request start time (ns), the server span, and a weak
# reference to the request context that created that span.
_ENVIRON_STARTTIME_KEY = "aworld-flask.starttime_key"
_ENVIRON_SPAN_KEY = "aworld-flask.span_key"
_ENVIRON_REQCTX_REF_KEY = "aworld-flask.reqctx_ref_key"

flask_version = version("flask")

if package_version.parse(flask_version) < package_version.parse("2.2.0"):
    def _request_ctx_ref() -> weakref.ReferenceType:
        """Return a weak reference to the active request context.

        Pre-2.2 Flask exposes the context via ``flask._request_ctx_stack``.
        """
        return weakref.ref(flask._request_ctx_stack.top)
else:
    def _request_ctx_ref() -> weakref.ReferenceType:
        """Return a weak reference to the active request context.

        Flask 2.2+ exposes the context through the ``flask.globals.request_ctx``
        proxy; unwrap it so the weakref targets the real context object rather
        than the proxy.
        """
        active_ctx = flask.globals.request_ctx._get_current_object()
        return weakref.ref(active_ctx)
def _rewrapped_app(
    wsgi_app,
    active_requests_counter,
    duration_histogram,
    response_hook=None,
    excluded_urls=None,
):
    """Wrap a Flask WSGI callable with request metrics and response-side tracing.

    The returned WSGI callable:

    * stamps the request start time (``time_ns()``) into the WSGI environ so the
      server span started later in ``before_request`` also covers routing;
    * increments ``active_requests_counter`` on entry and decrements it on exit,
      even when ``wsgi_app`` raises;
    * records the elapsed time of ``wsgi_app`` (seconds, via ``default_timer``)
      in ``duration_histogram``, tagged with the matched route when known;
    * intercepts ``start_response`` to inject the server span's trace context
      into the response headers, set ``http.response.status_code`` on the span
      and call ``response_hook``.

    Metrics are emitted only while ``MetricContext.metric_initialized()`` is true.

    Args:
        wsgi_app: The original WSGI application callable.
        active_requests_counter: Up/down counter template for in-flight requests.
        duration_histogram: Histogram template for request durations.
        response_hook: Optional callable ``(span, status, response_headers)``
            invoked from ``start_response`` for non-excluded requests. ``span``
            may be ``None`` if no server span was started.
        excluded_urls: Parsed excluded-URL matcher. Matching requests skip the
            ``start_response`` tracing logic (metrics are still recorded).

    Returns:
        A WSGI callable with the same signature as ``wsgi_app``.
    """
    def _wrapped_app(wrapped_app_environ, start_response):
        # We want to measure the time for route matching, etc.
        # In theory, we could start the span here and use
        # update_name later but that API is "highly discouraged" so
        # we better avoid it.
        wrapped_app_environ[_ENVIRON_STARTTIME_KEY] = time_ns()
        start = default_timer()
        attributes = collect_request_attributes(wrapped_app_environ)
        if MetricContext.metric_initialized():
            MetricContext.inc(active_requests_counter, 1, attributes)

        # Filled in by _start_response once Flask has matched a route.
        request_route = None

        def _start_response(status, response_headers, *args, **kwargs):
            nonlocal request_route
            if flask.request and (
                excluded_urls is None
                or not url_disabled(flask.request.url, excluded_urls)
            ):
                request_route = flask.request.url_rule
                span: Span = flask.request.environ.get(_ENVIRON_SPAN_KEY)
                propagator = get_global_trace_propagator()
                if propagator and span:
                    # Echo the server span's context back to the caller.
                    trace_context = TraceContext(
                        trace_id=span.get_trace_id(),
                        span_id=span.get_span_id()
                    )
                    propagator.inject(
                        trace_context, ListTupleCarrier(response_headers))
                if span and span.is_recording():
                    # WSGI status is e.g. "200 OK"; parse defensively so a
                    # malformed status (no reason phrase, non-numeric code)
                    # never breaks the response.
                    try:
                        status_code = int(status.split(" ", 1)[0])
                    except ValueError:
                        status_code = -1
                    span.set_attribute(
                        "http.response.status_code", status_code)
                    span.set_attributes(attributes)
                if response_hook is not None:
                    response_hook(span, status, response_headers)
            return start_response(status, response_headers, *args, **kwargs)

        try:
            return wsgi_app(wrapped_app_environ, _start_response)
        finally:
            # Always settle the metrics, otherwise an exception escaping
            # wsgi_app would leave the active-request counter permanently high.
            duration_s = default_timer() - start
            if MetricContext.metric_initialized():
                duration_attributes = dict(attributes)
                if request_route:
                    duration_attributes[HTTP_ROUTE] = str(request_route)
                MetricContext.histogram_record(
                    duration_histogram,
                    duration_s,
                    duration_attributes
                )
                MetricContext.dec(active_requests_counter, 1, attributes)
    return _wrapped_app
def _wrapped_before_request(
    request_hook=None,
    tracer: Tracer = None,
    excluded_urls=None
):
    """Build a Flask ``before_request`` handler that starts the server span.

    The handler skips excluded URLs; otherwise it extracts any incoming trace
    context from the WSGI environ, starts a ``SpanType.SERVER`` span (backdated
    to the start time stamped by ``_rewrapped_app``), runs ``request_hook`` and
    stores the span plus a weak reference to the current request context in
    the environ for ``_teardown_request`` to pick up.

    Args:
        request_hook: Optional callable ``(span, environ)`` invoked after the
            span is started.
        tracer: Tracer used to start the span. Must not be ``None`` when the
            handler runs.
        excluded_urls: Parsed excluded-URL matcher; matching requests are not
            traced.

    Returns:
        The zero-argument handler to register with ``app.before_request``.
    """
    def _before_request():
        if excluded_urls and url_disabled(flask.request.url, excluded_urls):
            return
        flask_request_environ = flask.request.environ
        # Do not dump the whole environ: it contains request headers such as
        # HTTP_AUTHORIZATION / HTTP_COOKIE and would leak credentials to logs.
        logger.debug(
            f"_wrapped_before_request method={flask.request.method} "
            f"path={flask.request.path}")
        attributes = collect_request_attributes(flask_request_environ)
        if flask.request.url_rule:
            # For 404 that result from no route found, etc, we
            # don't have a url_rule.
            attributes[HTTP_ROUTE] = flask.request.url_rule.rule
            span_name = f"HTTP {flask.request.url_rule.rule}"
        else:
            # NOTE(review): the full URL makes span names high-cardinality for
            # unmatched requests; consider a method-based name instead.
            span_name = f"HTTP {flask.request.url}"
        propagator = get_global_trace_propagator()
        trace_context = None
        if propagator:
            # Continue an upstream trace when the request carries one.
            trace_context = propagator.extract(
                DictCarrier(flask_request_environ))
        logger.debug(f"_wrapped_before_request trace_context={trace_context}")
        span = tracer.start_span(
            span_name,
            SpanType.SERVER,
            attributes=attributes,
            start_time=flask_request_environ.get(_ENVIRON_STARTTIME_KEY),
            trace_context=trace_context
        )
        if request_hook:
            request_hook(span, flask_request_environ)
        flask_request_environ[_ENVIRON_SPAN_KEY] = span
        # Remember which request context started the span so teardown can
        # verify it is ending the span in the same context.
        flask_request_environ[_ENVIRON_REQCTX_REF_KEY] = _request_ctx_ref()
    return _before_request
def _wrapped_teardown_request(
    excluded_urls=None,
):
    """Build a Flask ``teardown_request`` handler that ends the server span.

    The span is only ended when one was stored in the environ by
    ``_before_request`` and the current request context is the same one that
    started it. A non-``None`` ``exc`` is recorded on the span before ending.

    Args:
        excluded_urls: Parsed excluded-URL matcher; matching requests are
            ignored.

    Returns:
        The ``(exc)`` handler to register with ``app.teardown_request``.
    """
    def _teardown_request(exc):
        if excluded_urls and url_disabled(flask.request.url, excluded_urls):
            return
        environ = flask.request.environ
        span: Span = environ.get(_ENVIRON_SPAN_KEY)
        started_ctx_ref = environ.get(_ENVIRON_REQCTX_REF_KEY)
        teardown_ctx_ref = _request_ctx_ref()
        if not span or started_ctx_ref != teardown_ctx_ref:
            # Nothing to end, or this teardown runs in a different request
            # context than the one that started the span (presumably a
            # pushed/copied context) - leave the span alone.
            return
        if exc is not None:
            span.record_exception(exc)
        span.end()
    return _teardown_request
class _InstrumentedFlask(flask.Flask):
    """``flask.Flask`` subclass that instruments every app constructed from it.

    ``FlaskInstrumentor._instrument`` assigns this class to ``flask.Flask`` and
    fills in the class-level configuration below. Each new instance wraps its
    ``wsgi_app`` for metrics and response tracing and registers
    ``before_request`` / ``teardown_request`` handlers that start and end the
    server span.
    """

    # Parsed excluded-URL matcher; matching requests are not traced.
    _excluded_urls = None
    # Provider for the tracer; the global provider is used when unset.
    _tracer_provider: TraceProvider = None
    # Optional callable (span, environ) run after the server span starts.
    _request_hook = None
    # Optional callable (span, status, response_headers) run in start_response.
    _response_hook = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fall back to the global provider so the class also works when it was
        # installed without an explicit tracer_provider (previously this raised
        # AttributeError on None).
        tracer_provider = self._tracer_provider or get_tracer_provider()
        tracer = tracer_provider.get_tracer(
            "aworld.trace.instrumentation.flask")
        duration_histogram = MetricTemplate(
            type=MetricType.HISTOGRAM,
            name="flask_request_duration_histogram",
            description="Duration of flask HTTP server requests."
        )
        active_requests_counter = MetricTemplate(
            type=MetricType.UPDOWNCOUNTER,
            name="flask_active_request_counter",
            unit="1",
            description="Number of active HTTP server requests.",
        )
        self.wsgi_app = _rewrapped_app(
            self.wsgi_app,
            active_requests_counter,
            duration_histogram,
            _InstrumentedFlask._response_hook,
            excluded_urls=_InstrumentedFlask._excluded_urls
        )
        _before_request = _wrapped_before_request(
            _InstrumentedFlask._request_hook,
            tracer,
            excluded_urls=_InstrumentedFlask._excluded_urls
        )
        self._before_request = _before_request
        self.before_request(_before_request)
        _teardown_request = _wrapped_teardown_request(
            excluded_urls=_InstrumentedFlask._excluded_urls,
        )
        self.teardown_request(_teardown_request)
class FlaskInstrumentor(Instrumentor):
    """Instrumentor that swaps ``flask.Flask`` for ``_InstrumentedFlask``.

    Only applications constructed through ``flask.Flask`` after
    ``_instrument`` runs are instrumented. Already-created apps, and code that
    bound the original class earlier (e.g. ``from flask import Flask`` before
    instrumentation), are unaffected.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        return ("flask >= 1.0",)

    def _instrument(self, **kwargs: Any):
        """Configure ``_InstrumentedFlask`` from ``kwargs`` and install it.

        Recognised kwargs: ``request_hook``, ``response_hook`` (used only if
        callable), ``tracer_provider`` (defaults to the global provider) and
        ``excluded_urls`` (defaults to ``get_excluded_urls("FLASK")``).
        """
        logger.info("Flask _instrument entered.")
        self._original_flask = flask.Flask
        request_hook = kwargs.get("request_hook")
        response_hook = kwargs.get("response_hook")
        # Always overwrite the hooks so ones configured by a previous
        # instrumentation do not silently survive a re-instrumentation.
        _InstrumentedFlask._request_hook = (
            request_hook if callable(request_hook) else None
        )
        _InstrumentedFlask._response_hook = (
            response_hook if callable(response_hook) else None
        )
        _InstrumentedFlask._tracer_provider = (
            kwargs.get("tracer_provider") or get_tracer_provider()
        )
        excluded_urls = kwargs.get("excluded_urls")
        _InstrumentedFlask._excluded_urls = (
            get_excluded_urls("FLASK")
            if excluded_urls is None
            else parse_excluded_urls(excluded_urls)
        )
        flask.Flask = _InstrumentedFlask
        logger.info("Flask _instrument exited.")

    def _uninstrument(self, **kwargs):
        """Restore the ``flask.Flask`` class saved by ``_instrument``."""
        flask.Flask = self._original_flask
def instrument_flask(excluded_urls: str = None,
                     request_hook: Callable = None,
                     response_hook: Callable = None,
                     tracer_provider: TraceProvider = None,
                     **kwargs: Any,
                     ):
    """
    Instrument Flask applications created after this call.
    Args:
        excluded_urls (str): A comma separated list of URLs to be excluded from
            instrumentation. When None, the value from
            ``get_excluded_urls("FLASK")`` is used.
        request_hook (Callable): Called as ``request_hook(span, environ)`` right
            after the server span is started for a request.
        response_hook (Callable): Called as
            ``response_hook(span, status, response_headers)`` when the response
            is started.
        tracer_provider (TraceProvider): The trace provider to use; defaults to
            the global provider.
        **kwargs: Extra options forwarded to ``FlaskInstrumentor.instrument``.
    """
    provider = tracer_provider or get_tracer_provider()
    FlaskInstrumentor().instrument(
        excluded_urls=excluded_urls,
        request_hook=request_hook,
        response_hook=response_hook,
        tracer_provider=provider,
        **kwargs,
    )
    logger.info("Flask instrumented.")
|