# raw2logit/utils/debug.py

import torch
import numpy as np
import inspect
from functools import reduce, wraps
from collections.abc import Iterable
from IPython import embed
# get_ipython() only exists inside an IPython/Jupyter session.
try:
    get_ipython()  # pylint: disable=undefined-variable
    interactive_notebook = True
except NameError:
    interactive_notebook = False

_NONE = "__UNSET_VARIABLE__"  # sentinel: distinguishes "no argument given" from an explicit None


def debug_init():
    """Reset all debug configuration flags to their defaults."""
    debug.disable = False            # True turns the whole module into a no-op
    debug.silent = False             # True logs to debug._stack only, without printing
    debug.verbose = 2                # 0-3: how much tensor metadata to report
    debug.expand_ignore = ["DataLoader", "Dataset", "Subset"]  # iterable types not to expand
    debug.max_expand = 10            # maximum number of items shown per iterable
    debug.show_tensor = False        # also print a sample of tensor values
    debug.raise_exception = True     # raise on NaN/Inf instead of only logging
    debug.full_stack = True          # keep the entire call history in debug._stack
    debug.restore_defaults_on_exception = not interactive_notebook
    debug._indent = 0
    debug._stack = ""
    debug.embed = embed
    debug.show = debug_show
    debug.pause = debug_pause


def debug_pause():
    """Block until the user presses Enter."""
    input("Press Enter to continue...")


def debug(*args, assert_true=False):
    """Decorator and inspection helper for debugging functions and tensors.

    Raises an exception as soon as a NaN or Inf entry is encountered.
    Iterables are expanded and their elements are searched as well.

    Usage:
        debug(x)
    Or:
        @debug
        def function():
            ...

    When used as a function wrapper, all arguments are inspected and printed
    on every call, as is the return value.
    """
single_arg = len(args) == 1
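    # With debugging disabled, act as an identity function / transparent decorator.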
if debug.disable:
return args[0] if single_arg else None
    try:
        # Grab the caller's source line; argument names are parsed from it below.
        call_line = ''.join(inspect.stack()[1][4]).strip()
    except Exception:
        # No source context is available (e.g. in a plain REPL).
        call_line = '...'
    used_as_wrapper = call_line.startswith('def ')
    # An assignment prefix such as `f = debug(f)` means the caller expects the
    # (single) argument to be handed back after wrapping.
    expect_return_arg = single_arg and 'debug' in call_line and call_line.split('debug')[0].strip() != ''
    is_func = single_arg and callable(args[0])
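    # Callable in decorator position: return an instrumented wrapper.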
if is_func and (used_as_wrapper or expect_return_arg):
func = args[0]
sig_parameters = inspect.signature(func).parameters
sig_argnames = [p.name for p in sig_parameters.values()]
sig_defaults = {
k: v.default
for k, v in sig_parameters.items()
if v.default is not inspect.Parameter.empty
}
@wraps(func)
def _func(*args, **kwargs):
if debug.disable:
return func(*args, **kwargs)
if debug._indent == 0:
debug._stack = ""
stack_before = debug._stack
indent = ' ' * 4 * debug._indent
debug._indent += 1
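        # Rebuild the complete argument mapping: positionals matched to the
        # signature, explicit keywords, and defaults that were not overridden.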
args_kw = dict(zip(sig_argnames, args))
defaults = {k: v for k, v in sig_defaults.items()
if k not in kwargs
if k not in args_kw}
all_args = {**args_kw, **kwargs, **defaults}
func_name = None
if hasattr(func, '__name__'):
func_name = func.__name__
elif hasattr(func, '__class__'):
func_name = func.__class__.__name__
if func_name is None:
func_name = '... ' + call_line + '...'
else:
func_name = '@' + func_name + '()'
_debug_log('', indent=indent)
_debug_log(func_name, indent=indent)
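        # Remember this call so _debug_crash_save() can expose debug.recall.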
debug._last_call = func
debug._last_args = all_args
debug._last_args_sig = sig_argnames
for argtype, params in [("args", args_kw.items()),
("kwargs", kwargs.items()),
("defaults", defaults.items())]:
if params:
_debug_log(f"{argtype}:", indent=indent + ' ' * 6)
                for argname, arg in params:
                    if argname == 'self':
                        continue  # skip printing the bound instance
                    _debug_log(f"- {argname}: ", arg, indent + ' ' * 8, assert_true)
        try:
            out = func(*args, **kwargs)
        except:  # noqa: E722 -- crash state is saved, then the exception is re-raised
            _debug_crash_save()
            debug._stack = ""
            debug._indent = 0
            raise
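        # Keep the most recent return value accessible as debug.out.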
debug.out = out
_debug_log("returned: ", out, indent, assert_true)
_debug_log('', indent=indent)
debug._indent -= 1
if not debug.full_stack:
debug._stack = stack_before
return out
return _func
else:
if debug._indent == 0:
debug._stack = ""
        # Recover the expression between the outermost parentheses of the call site.
        argname = ')'.join('('.join(call_line.split('(')[1:]).split(')')[:-1])
        if assert_true:
            # Drop the trailing `assert_true=True` from the displayed expression.
            argname = ','.join(argname.split(',')[:-1])
_debug_log(f"assert{{{argname}}} ", args[0], ' ' * 4 * debug._indent, assert_true)
else:
for arg in args:
_debug_log(f"{{{argname}}} = ", arg, ' ' * 4 * debug._indent, assert_true)
if expect_return_arg:
return args[0]
return


def is_iterable(x):
    # Parenthesized so that strings are excluded by either test.
    return (isinstance(x, Iterable) or hasattr(x, '__getitem__')) and not isinstance(x, str)


def ndarray_repr(t, assert_all=False):
    """One-line summary of a numpy array; flags non-finite entries."""
    exception_encountered = False
    info = []
    shape = tuple(t.shape)
    single_entry = shape == () or shape == (1,)
if single_entry:
info.append(f"[{t.item():.4f}]")
else:
info.append(f"({', '.join(map(repr, shape))})")
invalid_sum = (~np.isfinite(t)).sum().item()
if invalid_sum:
info.append(
f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
exception_encountered = True
if debug.verbose > 1:
if not invalid_sum and not single_entry:
info.append(f"|x|={np.linalg.norm(t):.1f}")
if t.size:
info.append(f"x in [{t.min():.1f}, {t.max():.1f}]")
    if debug.verbose and t.dtype != np.float64:  # np.float was removed in NumPy 1.24
        info.append(f"dtype={str(t.dtype)}".replace("'", ''))
if assert_all:
assert_val = t.all()
if not assert_val:
exception_encountered = True
if assert_all and not exception_encountered:
output = "passed"
else:
if assert_all and not assert_val:
output = f"ndarray({info[0]})"
else:
output = f"ndarray({', '.join(info)})"
    if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
        if debug.restore_defaults_on_exception:
            # Avoid raising again during post-mortem inspection.
            debug.raise_exception = False
            debug.silent = False
debug.x = t
msg = output
debug._stack += output
if debug._stack and '\n' in debug._stack:
msg += '\nSTACK: ' + debug._stack
if assert_all:
assert assert_val, "Assert did not pass on " + msg
raise Exception("Invalid entries encountered in " + msg)
return output


def tensor_repr(t, assert_all=False):
    """One-line summary of a torch tensor; flags non-finite values and grads."""
    exception_encountered = False
    info = []
    shape = tuple(t.shape)
    single_entry = shape == () or shape == (1,)
if single_entry:
info.append(f"[{t.item():.3f}]")
else:
info.append(f"({', '.join(map(repr, shape))})")
invalid_sum = (~torch.isfinite(t)).sum().item()
if invalid_sum:
info.append(
f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
exception_encountered = True
if debug.verbose and t.requires_grad:
info.append('req_grad')
if debug.verbose > 2:
if t.is_leaf:
info.append('leaf')
if hasattr(t, 'retains_grad') and t.retains_grad:
info.append('retains_grad')
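    # .grad is only populated on leaf tensors and tensors that retain grad.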
    has_grad = (t.is_leaf or (hasattr(t, 'retains_grad') and t.retains_grad)) and t.grad is not None
if has_grad:
grad_invalid_sum = (~torch.isfinite(t.grad)).sum().item()
if grad_invalid_sum:
info.append(
f"GRAD {grad_invalid_sum} INVALID ENTR{'Y' if grad_invalid_sum == 1 else 'IES'}")
exception_encountered = True
if debug.verbose > 1:
if not invalid_sum and not single_entry:
info.append(f"|x|={t.float().norm():.1f}")
if t.numel():
info.append(f"x in [{t.min():.2f}, {t.max():.2f}]")
if has_grad and not grad_invalid_sum:
if single_entry:
info.append(f"grad={t.grad.float().item():.3f}")
else:
info.append(f"|grad|={t.grad.float().norm():.1f}")
if debug.verbose and t.dtype != torch.float:
info.append(f"dtype={str(t.dtype).split('.')[-1]}")
if debug.verbose and t.device.type != 'cpu':
info.append(f"device={t.device.type}")
if assert_all:
assert_val = t.all()
if not assert_val:
exception_encountered = True
if assert_all and not exception_encountered:
output = "passed"
else:
if assert_all and not assert_val:
output = f"tensor({info[0]})"
else:
output = f"tensor({', '.join(info)})"
    if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
        if debug.restore_defaults_on_exception:
            # Avoid raising again during post-mortem inspection.
            debug.raise_exception = False
            debug.silent = False
debug.x = t
msg = output
debug._stack += output
if debug._stack and '\n' in debug._stack:
msg += '\nSTACK: ' + debug._stack
if assert_all:
assert assert_val, "Assert did not pass on " + msg
raise Exception("Invalid entries encountered in " + msg)
return output


def _debug_crash_save():
    """On a crash inside a wrapped call, stash state on `debug` for post-mortem use."""
    if debug._indent:
        debug.args = debug._last_args
        debug.func = debug._last_call

        @wraps(debug.func)
        def _recall(*args, **kwargs):
            # Re-run the failing call, optionally overriding some arguments.
            call_args = {**debug.args, **kwargs, **dict(zip(debug._last_args_sig, args))}
            return debug(debug.func)(**call_args)

        def print_stack(stack=debug._stack):
            print('\nSTACK: ' + stack)

        debug.stack = print_stack
        debug.recall = _recall
debug._indent = 0


def _debug_log(output, var=_NONE, indent='', assert_true=False, expand=True):
    """Append `output` to debug._stack (and print it), then pretty-print `var`."""
debug._stack += indent + output
if not debug.silent:
print(indent + output, end='')
if var is not _NONE:
type_str = type(var).__name__.lower()
if var is None:
_debug_log('None')
elif isinstance(var, str):
_debug_log(f"'{var}'")
elif type_str == 'ndarray':
_debug_log(ndarray_repr(var, assert_true))
if debug.show_tensor:
_debug_show_print(var, indent=indent + 4 * ' ')
# elif type_str in ('tensor', 'parameter'):
elif type_str == 'tensor':
_debug_log(tensor_repr(var, assert_true))
if debug.show_tensor:
_debug_show_print(var, indent=indent + 4 * ' ')
        elif hasattr(var, 'named_parameters'):
            # torch modules: expand their named parameters.
            params = list(var.named_parameters())
            _debug_log(f"{type_str}[{len(params)}] {{")
            for k, v in params:
                _debug_log(f"'{k}': ", v, indent + 6 * ' ')
            _debug_log(indent + 4 * ' ' + '}')
        elif is_iterable(var):
            # expand_ignore == '*' disables expansion entirely.
            expand = debug.expand_ignore != '*' and expand
            if expand:
                # expand_ignore may be a single type name or a list of type names.
                if isinstance(debug.expand_ignore, str):
                    if type_str == str(debug.expand_ignore).lower():
                        expand = False
                elif is_iterable(debug.expand_ignore):
                    for ignore in debug.expand_ignore:
                        if type_str == ignore.lower():
                            expand = False
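            # Generators lack __len__; materialize them so they can be sized and indexed.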
if hasattr(var, '__len__'):
length = len(var)
else:
var = list(var)
length = len(var)
if expand and length > 0:
_debug_log(f"{type_str}[{length}] {{")
if isinstance(var, dict):
for k, v in var.items():
_debug_log(f"'{k}': ", v, indent + 6 * ' ', assert_true)
            else:
                i = 0
                # Show at most debug.max_expand items, then an ellipsis.
                for k, i in zip(var, range(debug.max_expand)):
                    _debug_log('- ', k, indent + 6 * ' ', assert_true)
                if i < length - 1:
                    _debug_log('- ' + ' ' * 6 + '...', indent=indent + 6 * ' ')
_debug_log(indent + 4 * ' ' + '}')
else:
_debug_log(f"{type_str}[{length}]")
else:
_debug_log(str(var))
else:
debug._stack += '\n'
if not debug.silent:
print()


def debug_show(x):
    """Log an iterable and additionally print a sample of its values."""
    assert is_iterable(x)
    debug(x)
    _debug_show_print(x, indent=' ' * 4 * debug._indent)


def _debug_show_print(x, indent=''):
    is_tensor = type(x).__name__ in ('Tensor', 'ndarray')
    if is_tensor:
        x = x.flatten()
    if type(x).__name__ == 'Tensor' and x.dim() == 0:
        return
    if len(x) == 0:  # avoid dividing by zero on empty inputs
        return
    # Print up to 10 evenly spaced sample entries.
    n_samples = min(10, len(x))
    di = len(x) // n_samples
    var = list(x[i * di] for i in range(n_samples))
    if is_tensor or isinstance(var[0], float):
        var = [round(float(v), 4) for v in var]
    _debug_log('--> ', str(var), indent, expand=False)


# Populate the default configuration at import time.
debug_init()
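

# A minimal usage sketch (illustrative only, not part of the module): it assumes
# this file is run directly with torch installed, and demonstrates the three entry
# points -- inspecting a value, wrapping a function, and asserting a condition.
if __name__ == "__main__":
    x = torch.randn(3, 4)
    debug(x)  # prints shape, norm, and value range; raises on NaN/Inf

    @debug
    def scale(t, factor=2.0):
        return t * factor

    y = scale(x)  # logs args, defaults, and the returned tensor
    debug((y == y).float(), assert_true=True)  # no NaNs, so this prints "passed"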