|
import time |
|
import sys |
|
import os |
|
from logging import StreamHandler |
|
from abc import ABC |
|
from logging import Logger, NOTSET, LogRecord, Filter, Formatter, Handler |
|
from typing import Optional, Union |
|
|
|
from aworld.trace.base import get_tracer_provider_silent, Tracer, AttributeValueType |
|
|
|
# Log format carrying trace context; instrument_logging applies it to the
# root logger's handlers. The trace_id / span_id fields are populated by
# TraceLoggingFilter.
TRACE_LOG_FORMAT = (
    '%(asctime)s - [%(trace_id)s] - [%(span_id)s] - '
    '%(name)s - %(levelname)s - %(message)s'
)

# Same layout, but the trace id is prefixed with "trace_"; instrument_logging
# applies it to handlers attached directly to the instrumented logger.
SPECIAL_TRACE_LOG_FORMAT = (
    '%(asctime)s - [trace_%(trace_id)s] - [%(span_id)s] - '
    '%(name)s - %(levelname)s - %(message)s'
)
|
|
|
|
|
class LoggerProvider(ABC):
    """Base type for logger providers (factories for loggers).

    No abstract methods are declared here, so the class carries no required
    interface of its own; it is the declared type of the module-wide provider
    returned by ``get_log_provider``.
    """
|
|
|
|
|
# Module-wide log provider: assigned by set_log_provider() and read by
# get_log_provider(). Stays None until a provider has been installed.
_GLOBAL_LOG_PROVIDER: Optional[LoggerProvider]
_GLOBAL_LOG_PROVIDER = None
|
|
|
|
|
def set_log_provider(provider: str = "otlp",
                     backend: str = "logfire",
                     base_url: Optional[str] = None,
                     write_token: Optional[str] = None,
                     **kwargs):
    """Set the global log provider.

    Args:
        provider: Provider implementation to install. Only ``"otlp"`` is
            supported; any other value leaves the current global provider
            unchanged and emits a ``RuntimeWarning`` instead of failing
            silently.
        backend: Forwarded unchanged to ``OTLPLoggerProvider``.
        base_url: Forwarded unchanged to ``OTLPLoggerProvider``.
        write_token: Forwarded unchanged to ``OTLPLoggerProvider``.
        **kwargs: Extra keyword arguments forwarded to ``OTLPLoggerProvider``.
    """
    global _GLOBAL_LOG_PROVIDER

    if provider == "otlp":
        # Imported lazily, presumably so the OTLP backend is only loaded
        # when it is actually selected.
        from .opentelemetry.otlp_log import OTLPLoggerProvider

        _GLOBAL_LOG_PROVIDER = OTLPLoggerProvider(backend=backend,
                                                  base_url=base_url,
                                                  write_token=write_token,
                                                  **kwargs)
        return

    # Previously an unknown provider was ignored without any signal, leaving
    # callers to discover the misconfiguration later via get_log_provider().
    import warnings

    warnings.warn(
        f"Unsupported log provider {provider!r}; the global log provider was not changed.",
        RuntimeWarning,
        stacklevel=2,
    )
|
|
|
|
|
def get_log_provider() -> LoggerProvider:
    """Return the module-wide log provider.

    Raises:
        ValueError: If no provider has been installed yet (see
            ``set_log_provider``).
    """
    # Read-only access to the module global, so no ``global`` statement is needed.
    current = _GLOBAL_LOG_PROVIDER
    if current is not None:
        return current
    raise ValueError("No log provider has been set.")
|
|
|
|
|
def instrument_logging(logger: Logger, level: Union[int, str] = NOTSET) -> None:
    """Instrument *logger* so its output carries trace context and is mirrored as spans.

    Steps:
      * Every root-logger handler without a ``TraceLoggingFilter`` gets one,
        plus a ``Formatter(TRACE_LOG_FORMAT)``.
      * If *logger* has no handlers, a ``StreamHandler`` formatted with
        ``SPECIAL_TRACE_LOG_FORMAT`` (and carrying a ``TraceLoggingFilter``)
        is added; otherwise each of its handlers lacking the filter is given
        that formatter and the filter.
      * Propagation to ancestor loggers is disabled.
      * A ``TraceLogginHandler`` at *level* is attached. If one is already
        attached (the function was called before), its level is updated
        instead of adding a duplicate that would emit every span twice.

    ``TraceLogginHandler`` instances are never re-formatted or given a
    ``TraceLoggingFilter``, so repeated calls leave them unchanged.

    Args:
        logger: Logger to instrument.
        level: Level applied to the attached ``TraceLogginHandler``.
    """
    def _needs_trace_setup(handler: Handler) -> bool:
        # True for ordinary handlers that have not been given a TraceLoggingFilter yet.
        if isinstance(handler, TraceLogginHandler):
            return False
        return not any(isinstance(flt, TraceLoggingFilter) for flt in handler.filters)

    for root_handler in logger.root.handlers:
        if _needs_trace_setup(root_handler):
            root_handler.setFormatter(Formatter(TRACE_LOG_FORMAT))
            root_handler.addFilter(TraceLoggingFilter())

    if not logger.handlers:
        print("No handlers found, adding a StreamHandler. logger=", logger.name)
        stream_handler = StreamHandler()
        stream_handler.setFormatter(Formatter(SPECIAL_TRACE_LOG_FORMAT))
        stream_handler.addFilter(TraceLoggingFilter())
        logger.addHandler(stream_handler)
    else:
        for own_handler in logger.handlers:
            if _needs_trace_setup(own_handler):
                own_handler.setFormatter(Formatter(SPECIAL_TRACE_LOG_FORMAT))
                own_handler.addFilter(TraceLoggingFilter())

    logger.propagate = False

    existing_trace_handlers = [h for h in logger.handlers if isinstance(h, TraceLogginHandler)]
    if existing_trace_handlers:
        for trace_handler in existing_trace_handlers:
            trace_handler.setLevel(level)
    else:
        logger.addHandler(TraceLogginHandler(level))
|
|
|
|
|
class TraceLoggingFilter(Filter):
    """
    A filter that adds trace information to log records.

    It guarantees that ``record.trace_id`` and ``record.span_id`` exist, so
    formatters built from TRACE_LOG_FORMAT / SPECIAL_TRACE_LOG_FORMAT can
    always reference them. It never drops a record.
    """

    def filter(self, record: LogRecord) -> bool:
        """
        Set ``trace_id`` and ``span_id`` on *record* and return True.

        When a tracer provider is available, the ids come from its current
        span, or are None when there is no current span. When no provider is
        available, ids already present on the record (e.g. passed via
        ``extra``) are kept and missing ones default to None. Previously they
        were left unset in that case, and %-style formats referencing them
        failed with "Formatting field not found in record", losing the message.
        """
        trace_id = getattr(record, "trace_id", None)
        span_id = getattr(record, "span_id", None)

        trace = get_tracer_provider_silent()
        if trace:
            span = trace.get_current_span()
            trace_id = span.get_trace_id() if span else None
            span_id = span.get_span_id() if span else None

        record.trace_id = trace_id
        record.span_id = span_id
        return True
|
|
|
|
|
class TraceLogginHandler(Handler):
    """
    A logging handler that mirrors log records as spans.

    For each record handled while the current span (from the provider
    returned by ``get_tracer_provider_silent()``) is recording, a span is
    started and immediately ended on the tracer named ``tracer_name``.

    * Span name: the message template as written at the call site when it
      can be recovered from the source line, otherwise the colour-stripped
      message. Either way it is truncated to 255 characters plus ``'...'``.
    * Attributes: code location, level, logger name and the formatted message.
    * Exceptions: an exception from ``record.exc_info`` is recorded on the span.

    Records handled while no span is recording are ignored by this handler.
    ``record.msg`` is restored before ``emit`` returns, so other handlers see
    the record unchanged.
    """
    # NOTE(review): the class name is misspelled ("Loggin"); it is kept
    # because callers may import it by this name.

    @staticmethod
    def strip_color(text: str) -> str:
        """Return *text* with ANSI colour sequences (``ESC[...m``) removed."""
        import re

        return re.sub(r'\033\[[0-9;]*m', '', text)

    def __init__(self,
                 level: Union[int, str] = NOTSET,
                 tracer_name: str = "aworld.log") -> None:
        """Initialize the handler.

        Args:
            level: Minimum level of records this handler processes.
            tracer_name: Name passed to ``get_tracer`` when the tracer is
                lazily created on the first mirrored record.
        """
        super().__init__(level=level)
        self._tracer_name = tracer_name
        # Created lazily in emit(), once a tracer provider is available.
        self._tracer: Optional[Tracer] = None

    @staticmethod
    def _find_caller_frame():
        """Return the nearest stack frame outside stdlib logging and this module's directory, or None."""
        own_dir = os.path.dirname(__file__)
        frame = sys._getframe()
        while frame:
            filename = frame.f_code.co_filename
            # Compare with forward slashes so the stdlib-logging check also
            # matches Windows paths such as "...\\logging\\__init__.py".
            in_stdlib_logging = 'logging/__init__.py' in filename.replace('\\', '/')
            if not in_stdlib_logging and not filename.startswith(own_dir):
                break
            frame = frame.f_back
        return frame

    @staticmethod
    def _extract_raw_message(frame) -> Optional[str]:
        """Best-effort recovery of the call-site text of a ``logger.<method>(...)`` call.

        Reads the caller's source line and returns the text between the first
        '(' that follows 'logger.' and the next ')', stripped. Returns None
        when the line has no 'logger.' or no '(' after it. The next ')' may
        belong to a nested call, so the result is only a heuristic.
        """
        import linecache

        try:
            line = linecache.getline(frame.f_code.co_filename, frame.f_lineno)
            if 'logger.' not in line:
                return None
            after_call = line.split('logger.', 1)[1].split('(', 1)[1]
            return after_call.split(')', 1)[0].strip()
        except Exception:
            # Purely cosmetic lookup; fall back to the formatted message.
            return None

    def emit(self, record: LogRecord) -> None:
        """Mirror *record* as a span if the current span is recording.

        Errors raised while building the span are reported through
        ``handleError``; RecursionError is re-raised, as in the stdlib handlers.
        """
        trace = get_tracer_provider_silent()
        if not trace:
            return
        current_span = trace.get_current_span()
        if not current_span or not current_span.is_recording():
            return

        if not self._tracer:
            self._tracer = trace.get_tracer(name=self._tracer_name)

        origin_msg = record.msg
        try:
            frame = self._find_caller_frame()
            raw_msg = self._extract_raw_message(frame) if frame else None

            # Strip ANSI colours so neither the span name nor 'log.message'
            # carries escape codes. Non-string messages (e.g. dicts passed as
            # msg) are left as-is; LogRecord.getMessage() str()-s them anyway.
            if isinstance(record.msg, str):
                record.msg = self.strip_color(record.msg)
            msg_template = raw_msg if raw_msg else str(record.msg)

            if len(msg_template) > 255:
                msg_template = msg_template[:255] + '...'

            attributes = {
                'code.filepath': frame.f_code.co_filename if frame else record.pathname,
                'code.lineno': frame.f_lineno if frame else record.lineno,
                'code.function': frame.f_code.co_name if frame else record.funcName,
                'log.level': record.levelname,
                'log.logger': record.name,
                'log.message': self.format(record),
            }
            self._create_span(
                span_name=msg_template,
                attributes=attributes,
                exc_info=record.exc_info,
            )
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
        finally:
            # Always undo the colour stripping so later handlers see the original msg.
            record.msg = origin_msg

    def _create_span(self,
                     span_name: str,
                     attributes: Optional[dict[str, AttributeValueType]] = None,
                     exc_info: Union[BaseException, tuple, None] = None):
        """Start a span named *span_name* and end it immediately.

        Args:
            span_name: Name of the span.
            attributes: Attributes set when the span is started.
            exc_info: Exception to record on the span. A ``LogRecord.exc_info``
                tuple ``(type, value, traceback)`` is also accepted; its
                ``value`` is recorded, and ``(None, None, None)`` records nothing.
        """
        # LogRecord.exc_info is a tuple, not an exception instance; passing the
        # tuple to record_exception(exception=...) is a type error.
        if isinstance(exc_info, tuple):
            exc_info = exc_info[1] if len(exc_info) > 1 else None

        start_time = time.time_ns()
        span = self._tracer.start_span(
            name=span_name,
            attributes=attributes,
            start_time=start_time,
        )
        try:
            if exc_info is not None:
                span.record_exception(exception=exc_info, timestamp=start_time)
        finally:
            # End the span even if recording the exception fails.
            span.end()
|
|