|
import torch |
|
import warnings |
|
|
|
|
|
|
|
|
|
|
|
# Assertion callable used by assert_shape for tensor-valued size checks.
# Prefer `torch._assert`; if this PyTorch build does not provide it, fall back
# to `torch.Assert` (which must then exist, otherwise AttributeError propagates).
if hasattr(torch, '_assert'):
    symbolic_assert = torch._assert
else:
    symbolic_assert = torch.Assert
|
|
|
|
|
|
|
|
|
class suppress_tracer_warnings(warnings.catch_warnings):
    """Context manager that ignores ``torch.jit.TracerWarning`` within its body.

    Built on :class:`warnings.catch_warnings`: the active warning filters are
    saved on entry and restored on exit. While inside the ``with`` body, an
    ``ignore`` filter for ``TracerWarning`` sits in front of any pre-existing
    filters. The ``with ... as`` target is the context-manager instance itself.
    """

    def __enter__(self):
        # Snapshot the current filter state first; the parent's return value
        # is intentionally not forwarded — callers always receive ``self``.
        super().__enter__()
        ignored_category = torch.jit.TracerWarning
        warnings.simplefilter('ignore', category=ignored_category)
        return self
|
|
|
|
|
|
|
|
|
|
|
|
|
def assert_shape(tensor, ref_shape):
    """Verify that ``tensor`` matches the shape described by ``ref_shape``.

    Args:
        tensor: Object exposing ``ndim`` and ``shape`` (e.g. a ``torch.Tensor``).
        ref_shape: Sized iterable with one entry per dimension. Each entry is
            one of:
            ``None``, meaning the dimension is not checked;
            a ``torch.Tensor``, compared against the actual size with
            ``torch.equal`` through ``symbolic_assert``;
            any other value, compared against the actual size with ``!=``.

    Raises:
        AssertionError: If the number of dimensions differs, or if a non-tensor
            size differs from its reference. Tensor-involving comparisons are
            delegated to ``symbolic_assert``.
    """
    ndim = tensor.ndim
    expected_ndim = len(ref_shape)
    if ndim != expected_ndim:
        raise AssertionError(f'Wrong number of dimensions: got {ndim}, expected {expected_ndim}')

    for dim, (actual, expected) in enumerate(zip(tensor.shape, ref_shape)):
        if expected is None:
            continue  # wildcard: any size is accepted for this dimension

        if isinstance(expected, torch.Tensor):
            # Reference size is a tensor, so compare tensor-to-tensor.
            # NOTE(review): the warning suppression presumably targets
            # TracerWarnings raised by torch.as_tensor during tracing — confirm.
            with suppress_tracer_warnings():
                symbolic_assert(torch.equal(torch.as_tensor(actual), expected), f'Wrong size for dimension {dim}')
        elif isinstance(actual, torch.Tensor):
            # Actual size is itself a tensor; lift the reference to a tensor.
            with suppress_tracer_warnings():
                symbolic_assert(torch.equal(actual, torch.as_tensor(expected)), f'Wrong size for dimension {dim}: expected {expected}')
        elif actual != expected:
            raise AssertionError(f'Wrong size for dimension {dim}: got {actual}, expected {expected}')
|
|