# Copyright (c) Meta Platforms, Inc. and affiliates.

# This file from the xFormers repo is just an example of how to implement
# probing of the activations of a model, without changing anything.
# By default, the linear inputs/outputs/gradients are logged, as well as
# the attention logits+entropy. It is possible to log an additional tensor, eg:
#   x = log_stats(x, "name")
#
# Known limitations:
# * Only a subset of the attention biases is supported
# * Torch-compile is disabled automatically when this is enabled
# * Only tested with bf16/f16/f32 datatypes
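#
# Typical usage (see `test_toy_model` at the bottom of this file):
#   probe = AutoProbeD(model, "./probe.json")
#   with probe:
#       probe.metadata = {"it": step}
#       y = model(x)
#       y.backward(grad)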

import contextlib
import functools
import json
import math
import os
import uuid
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)
from torch.fx.operator_schemas import normalize_function
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils.module_tracker import ModuleTracker

from xformers.ops import fmha


# The `torchprobe::log` custom op is a no-op: it only exists so that calls to
# it are visible to `__torch_dispatch__`, which performs the actual logging.
# (The registration below assumes the `torch.library.custom_op` API,
# available in PyTorch >= 2.4.)
@torch.library.custom_op("torchprobe::log", mutates_args=())
def _log(x: torch.Tensor, name: str, uid: str) -> None:
    pass


@_log.register_fake
def _log_fake(x: torch.Tensor, name: str, uid: str) -> None:
    pass


class _LogStats(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, name: str):
        uid = str(uuid.uuid4())
        torch.ops.torchprobe.log(x, name, uid)
        ctx.name = name
        ctx.uid = uid
        return x

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid)
        return grad, None
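

# The global `_PROBING_ENABLED` flag is flipped by `AutoProbeD.__enter__`.
# `log_stats(x, name)` is an identity op: it returns `x` unchanged, but while a
# probe is active it records the stats of `x` in the forward pass and of its
# gradient (logged as "<name>.g") in the backward pass.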
_PROBING_ENABLED = False


def log_stats(x: torch.Tensor, name: str) -> torch.Tensor:
    if not _PROBING_ENABLED:
        return x
    return _LogStats.apply(x, name)
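

# Log-spaced quantiles, dense in both tails, so that rare outliers show up in
# the dumped statistics.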
QUANTILES = [
    0.0000001,
    0.000001,
    0.00001,
    0.0001,
    0.001,
    0.01,
    0.05,
    0.1,
    0.3,
    0.5,
    0.7,
    0.9,
    0.95,
    0.99,
    0.999,
    0.9999,
    0.99999,
    0.999999,
    0.9999999,
]


def _get_quantiles(device: torch.device, dtype) -> torch.Tensor:
    return torch.tensor(QUANTILES, device=device, dtype=dtype)


def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]:
    if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]:
        return {}
    x = x_.flatten()
    if remove_inf:
        x = x[x.abs() < float("inf")]
    if x.dtype is not torch.double:
        x = x.float()
    xabs = x.abs()
    quantiles = _get_quantiles(x.device, x.dtype)
    mean = x.mean()
    std = x.std()
    return {
        "shape": tuple(x_.shape),
        "mean": mean,
        "std": std,
        "skew": (((x - mean) / std) ** 3).double().mean(),
        "kurtosis": (((x - mean) / std) ** 4).double().mean(),
        "abs.mean": xabs.mean(),
        "max": x.max(),
        "min": x.min(),
        # Note: `quantile` takes at most 2**24 elements, see
        # https://github.com/pytorch/pytorch/issues/64947
        "quantiles": torch.quantile(x[: 2**24], quantiles),
    }
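

# Bottom-right aligned causal masking: the query at absolute index `q_idx` may
# attend to the first `q_idx + kv_len - q_len + 1` keys; everything after that
# is set to -inf.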
def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None:
    assert logits.ndim == 4
    logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf


def _mask_attn_logits(
    logits: torch.Tensor,
    q_idx: List[int],
    *,
    causal: bool,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert logits.dtype is torch.float32
    # Handle BlockDiagonalMask
    if cu_seqlens_q is not None:
        assert cu_seqlens_k is not None
        # Expect BHMqMkv
        assert logits.ndim == 4, logits.shape
        qs = cu_seqlens_q.tolist()
        ks = cu_seqlens_k.tolist()
        q_batchid = []
        k_batchid = [-2] * logits.shape[-1]
        q_idx_i = 0
        for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])):
            for k in range(k0, k1):
                k_batchid[k] = bid
            while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1:
                q_batchid.append(bid)
                if causal:
                    _mask_attn_causal_inplace(
                        logits[:, :, q_idx_i : q_idx_i + 1, k0:k1],
                        q_idx[q_idx_i] - q0,
                        q1 - q0,
                        k1 - k0,
                    )
                q_idx_i += 1
        mask_out = (
            torch.tensor(q_batchid, device=logits.device)[None, None, :, None]
            != torch.tensor(k_batchid, device=logits.device)[None, None, None, :]
        )
        logits[mask_out.expand_as(logits)] = -math.inf
        assert q_idx_i == len(q_idx)
    elif causal:
        for q_idx_i in range(len(q_idx)):
            _mask_attn_causal_inplace(
                logits[:, :, q_idx_i : q_idx_i + 1, :],
                q_idx[q_idx_i],
                logits.shape[2],
                logits.shape[3],
            )
    return logits
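

# Evenly sample at most ~128 query positions so the probed logits stay small.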
def _attn_queries_subset(num_queries: int) -> List[int]:
    return list(range(0, num_queries, max(1, num_queries // 128)))


def _compute_attn_stats_sdpa(
    probe,
    path: str,
    # supports the arguments of both the cudnn and flash backends
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    compute_log_sumexp=True,
    return_debug_mask=False,
    **kwargs,
):
    if scale is None:
        scale = 1 / (query.shape[-1] ** 0.5)
    # Filter out unsupported cases
    if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs:
        probe.store[f"{path}::attn"] = {
            "query.shape": tuple(query.shape),
            "key.shape": tuple(key.shape),
            "value.shape": tuple(value.shape),
            "attn_mask": attn_mask.shape if attn_mask is not None else None,
            "dropout_p": dropout_p,
            "is_causal": is_causal,
            "scale": scale,
            "unk_kwargs": list(kwargs.keys()),
        }
        return
    # Take a subset of the queries and compute the logits
    query_s = _attn_queries_subset(query.shape[-2])
    logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale
    logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal)
    p = logits.float().softmax(-1)
    masked_logsoft = logits.log_softmax(-1).where(
        (logits > -math.inf), torch.zeros_like(logits)
    )
    entropy = -(p * masked_logsoft).sum(-1)
    probe.log_tensor(f"{path}::attn_entropy", entropy)
    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
def _compute_attn_stats_flash(
    probe,
    path: str,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    p: float,
    softmax_scale: float,
    is_causal: bool,
    window_left: int,
    window_right: int,
    return_softmax: bool,
    block_tables: Optional[torch.Tensor],
    unpadded_lse: bool = False,
) -> None:
    # Filter out unsupported cases
    if (
        seqused_k is not None
        or p != 0.0
        or window_left >= 0
        or window_right >= 0
        or block_tables is not None
    ):
        probe.store[f"{path}::attn"] = {
            "query.shape": tuple(query.shape),
            "key.shape": tuple(key.shape),
            "value.shape": tuple(value.shape),
            "op": "flash",
        }
        return
    if cu_seqlens_q is not None:
        assert query.ndim == 3, query.shape
        query, key, value = query[None], key[None], value[None]
    assert query.ndim == 4, query.shape
    # Take a subset of the queries and compute the logits
    query_s = _attn_queries_subset(query.shape[1])
    logits = (
        query[:, query_s].transpose(1, 2)
        @ key.transpose(1, 2).transpose(-1, -2)
    ) * softmax_scale
    logits = _mask_attn_logits(
        logits.float(),
        query_s,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        causal=is_causal,
    )
    p = logits.float().softmax(-1)
    masked_logsoft = logits.log_softmax(-1).where(
        (logits > -math.inf), torch.zeros_like(logits)
    )
    entropy = -(p * masked_logsoft).sum(-1)
    probe.log_tensor(f"{path}::attn_entropy", entropy)
    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)


def _tensors_to_python(x):
    if not isinstance(x, torch.Tensor):
        return x
    return x.tolist()


# Which kind of GEMM a `mm` encountered in the backward pass corresponds to
class LinearBwType(Enum):
    DW = 1
    DX = 2
    UNKNOWN = 3
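

# `AutoProbeD` intercepts every aten op via `__torch_dispatch__` and uses
# `ModuleTracker` to attribute each op to the `nn.Module` it runs under.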
class AutoProbeD(TorchDispatchMode):
    def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None:
        self.write_file = Path(write_file) if write_file is not None else None
        self.write_tensors_tmpdir: Optional[Path] = None
        self.compile_disabler = TorchCompileDisabler(module)
        self.mod_tracker = ModuleTracker()
        self.count_per_path: Dict[str, int] = defaultdict(int)
        self.store: Dict[str, Dict[str, Any]] = {}
        self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {}
        self.uid_to_path: Dict[str, str] = {}
        self.metadata: Any = None
        self.enabled = False
        self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0")))

    def __enter__(self):
        global _PROBING_ENABLED
        assert not self.enabled, "Entered probe twice"
        self.compile_disabler.__enter__()
        self.mod_tracker.__enter__()
        super().__enter__()
        self.enabled = True
        _PROBING_ENABLED = True
        # self._setup_tensors_logging()
        return self

    def __exit__(self, *args) -> None:
        global _PROBING_ENABLED
        assert self.enabled, "Exiting probe without entering it"
        super().__exit__(*args)
        self.mod_tracker.__exit__(*args)
        self.compile_disabler.__exit__(*args)
        self._flush_and_clear()
        _PROBING_ENABLED = False
        self.enabled = False

    def _setup_tensors_logging(self):
        if self.write_file is not None:
            self.write_file.parent.mkdir(exist_ok=True)
            self.write_tensors_tmpdir = (
                self.write_file.parent
                / f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}"
            )
            self.write_tensors_tmpdir.mkdir(exist_ok=True)

    def _flush_and_clear(self) -> None:
        if self.write_file is not None:
            dump_data = tree_map(_tensors_to_python, self.store)
            with self.write_file.open("a") as fd:
                json.dump(
                    {
                        "data": dump_data,
                        "meta": self.metadata,
                        "version": 2,
                        "quantiles": QUANTILES,
                    },
                    fd,
                )
                fd.write("\n")
        if self.write_tensors_tmpdir is not None:
            assert self.write_file is not None
            dump_dir = (
                self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump"
            )
            dump_dir.mkdir(exist_ok=True)
            dir_name = ""
            if "it" in self.metadata:
                dir_name = f"it{int(self.metadata['it']):010}"
            if dir_name == "" or (dump_dir / dir_name).exists():
                num_files = len(list(dump_dir.glob(f"{dir_name}v*")))
                dir_name = f"{dir_name}v{num_files}"
            dump_dir = dump_dir / dir_name
            assert not dump_dir.exists()
            self.write_tensors_tmpdir.rename(dump_dir)
            self.write_tensors_tmpdir = None
        self.store.clear()
        self.count_per_path.clear()
        self.uid_to_path.clear()

    def _find_bw_path_and_type(
        self, path: str, out: torch.Tensor, args
    ) -> Tuple[str, LinearBwType]:
        """
        We are in the BW pass, and process a GEMM.
        Let's figure out:
        (1) The path for the FW pass (might differ in case of ModuleTracker bug)
        (2) The type of BW pass (eg `dw` or `dx`)
        """

        def _is_path_correct_dw(path: str) -> bool:
            # dW.t = dY.t @ X
            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
            return out.shape == (w_shape[1], w_shape[0]) and torch.allclose(
                input_sm, args[1][:4, :4]
            )

        def _is_path_correct_dx(path: str) -> bool:
            # dX = dY @ W.t
            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
            return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4])

        if path in self.linear_data:
            if _is_path_correct_dw(path):
                return path, LinearBwType.DW
            if _is_path_correct_dx(path):
                return path, LinearBwType.DX
        for candidate_path in self.mod_tracker.parents:
            if candidate_path not in self.linear_data:
                continue
            if _is_path_correct_dw(candidate_path):
                return candidate_path, LinearBwType.DW
            if _is_path_correct_dx(candidate_path):
                return candidate_path, LinearBwType.DX
        return path, LinearBwType.UNKNOWN

    def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None:
        self.store[name] = _get_stats(x, **kwargs)
        if self.write_tensors_tmpdir is not None:
            name_safe = name.replace("::", "__").replace("/", "")
            torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl")
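
    # Called for every aten op executed under the mode: it runs the op, then
    # pattern-matches GEMMs, SDPA/flash attention and the `torchprobe::log` op
    # to record their statistics under the current module path.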
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        path = None
        # Find longest path
        for p in self.mod_tracker.parents:
            if p == "Global":
                continue
            if path is None or len(p) > len(path):
                path = p
        if path is None:
            path = "Global"
        path = path.replace("._checkpoint_wrapped_module", "")
        out = func(*args, **kwargs)
        # Handle linear layers
        if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]:
            weight: torch.Tensor
            input: torch.Tensor
            if not self.mod_tracker.is_bw:
                # (technically, weight is transposed)
                if func._overloadpacket == torch.ops.aten.addmm:
                    _bias, input, weight = args[:3]
                else:
                    assert func._overloadpacket == torch.ops.aten.mm
                    input, weight = args[:2]
                self.log_tensor(f"{path}::in", input)
                self.log_tensor(f"{path}::w", weight)
                self.log_tensor(f"{path}::out", out)
                self.linear_data[path] = (
                    input.shape,
                    weight.shape,
                    out.shape,
                    input[:4, :4].clone(),
                    weight[:4, :4].T.clone(),
                )
            elif func._overloadpacket == torch.ops.aten.mm:
                # XXX: Try to find the actual path for the linear layer.
                # This is sometimes messed up with Francisco's FSDP.
                new_path, bwtype = self._find_bw_path_and_type(path, out, args)
                if new_path != path:
                    if self.verbose:
                        print(f"E: Fixing path `{path}` -> `{new_path}`")
                    path = new_path
                if bwtype == LinearBwType.DW:
                    # dW.t = dY.t @ X
                    self.log_tensor(f"{path}::w.g", out)
                elif bwtype == LinearBwType.DX:
                    # dX = dY @ W.t
                    self.log_tensor(f"{path}::in.g", out)
                    self.log_tensor(f"{path}::out.g", args[0])
        elif func._overloadpacket in [
            torch.ops.aten._scaled_dot_product_flash_attention,
            torch.ops.aten._scaled_dot_product_cudnn_attention,
        ]:
            _, kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            _compute_attn_stats_sdpa(self, path, **kwargs)
        elif func._overloadpacket == fmha.flash.FwOp.OPERATOR:
            _, kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            _compute_attn_stats_flash(self, path, **kwargs)
        elif func._overloadpacket == torch.ops.torchprobe.log:
            uid = args[2]
            path = self.uid_to_path.setdefault(uid, path)
            self.log_tensor(f"{path}::{args[1]}", args[0])
        if self.verbose:
            print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}")
        return out


def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None:
    if module._compiled_call_impl is not None:
        out.append(module)
    for c in module.children():
        _find_all_submodules_compiled(out, module=c)
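

# The probe needs eager-mode execution to intercept ops, so while it is active
# we temporarily undo `module.compile()` and disable dynamo, restoring the
# compiled call implementations on exit.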
class TorchCompileDisabler:
    def __init__(self, module: nn.Module) -> None:
        self.module = module
        self.submodules_compiled: List[nn.Module] = []
        self.compiled_call_impl: List[Any] = []
        self.disable_compile = torch.compiler.disable()
        torch._dynamo.config.raise_on_ctx_manager_usage = False  # type: ignore

    def __enter__(self) -> None:
        # Remove all `_compiled_call_impl` attributes to effectively
        # "undo" compilation
        self.submodules_compiled.clear()
        _find_all_submodules_compiled(self.submodules_compiled, self.module)
        self.compiled_call_impl = [
            m._compiled_call_impl for m in self.submodules_compiled
        ]
        for m in self.submodules_compiled:
            m._compiled_call_impl = None
        self.disable_compile.__enter__()  # type: ignore

    def __exit__(self, *args) -> None:
        self.disable_compile.__exit__(*args)  # type: ignore
        for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl):
            m._compiled_call_impl = c_impl
        self.compiled_call_impl = []


Probe = AutoProbeD


# EXAMPLE USAGE
d = 512
seqlen = 4
bs = 2


class Attention1(nn.Module):
    def forward(self, x):
        attn_bias = fmha.attn_bias.LowerTriangularFromBottomRightMask()
        return fmha.memory_efficient_attention(x, x, x, attn_bias=attn_bias).reshape(
            [x.shape[0], seqlen, -1]
        )


class Attention2(nn.Module):
    def forward(self, x):
        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            [seqlen] * bs
        ).make_causal()
        xr = x.reshape([1, 2 * seqlen, x.shape[2], x.shape[3]])
        return fmha.memory_efficient_attention(xr, xr, xr, attn_bias=attn_bias).reshape(
            [x.shape[0], seqlen, -1]
        )


class AttentionSDPA(nn.Module):
    def __init__(self):
        super().__init__()
        self.wo = nn.Linear(d, d)

    def forward(self, x):
        x = x.transpose(1, 2)
        return self.wo(
            F.scaled_dot_product_attention(x, x, x)
            .transpose(1, 2)
            .reshape([x.shape[0], seqlen, -1])
        )


class AttentionSDPAFlash(AttentionSDPA):
    def forward(self, x):
        x = x.transpose(1, 2)
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return self.wo(
                F.scaled_dot_product_attention(x, x, x)
                .transpose(1, 2)
                .reshape([x.shape[0], seqlen, -1])
            )


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(d, 16)
        self.trunk = nn.Sequential(
            nn.Linear(d, d),
            nn.Linear(d, d),
        )
        self.q_proj = nn.Linear(d, d, bias=False)
        self.trunk.compile()
        self.attn1 = Attention1()
        self.attn2 = Attention2()
        self.attnSDPA = AttentionSDPA()
        self.attnSDPAflash = AttentionSDPAFlash()

    def forward(self, x):
        B, nHeads, D = x.shape[0], d // 64, 64
        x = self.q_proj(x).reshape([B, seqlen, nHeads, D])
        x = self.attn1(x) + self.attn2(x) + self.attnSDPA(x) + self.attnSDPAflash(x)
        x = log_stats(x, "attns_out")
        return self.head(self.trunk(x))
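

# `test_masking` checks `_mask_attn_logits` against the reference bias
# materialized by xFormers for a block-diagonal, bottom-right causal layout.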
def test_masking() -> None:
    q_seqlen = [1, 1, 14, 12]
    kv_seqlen = [2, 2, 14, 18]
    attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
        q_seqlen, kv_seqlen
    ).make_causal_from_bottomright()
    logits = torch.randn(
        [1, 1, sum(q_seqlen), sum(kv_seqlen)], dtype=torch.float32, device="cuda"
    )
    bias = attn_bias.materialize(
        logits.shape, dtype=logits.dtype, device=logits.device
    )
    logits_masked = logits.clone()
    _mask_attn_logits(
        logits_masked,
        list(range(logits.shape[2])),
        causal=True,
        cu_seqlens_q=attn_bias.q_seqinfo.seqstart,
        cu_seqlens_k=attn_bias.k_seqinfo.seqstart,
    )
    assert (logits + bias == logits_masked).all().item()
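

# End-to-end check on the toy model: run a few steps, probe one of them, and
# verify that every expected entry landed in `probe.store` with correct stats.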
def test_toy_model() -> None:
    kw = dict(device="cuda", dtype=torch.float16)
    x = torch.randn([bs, seqlen, d], **kw)
    m = Model()
    m.head = checkpoint_wrapper(
        m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )
    m.to(**kw)
    m.compile()
    optim = torch.optim.SGD(m.parameters(), lr=0.0)
    probe = AutoProbeD(m, "./probe.json")
    for i in range(4):
        with contextlib.ExitStack() as stack:
            print(f"########### STEP {i}")
            if i % 4 == 1:
                stack.enter_context(probe)
                probe.metadata = {"it": i}
            y = m(x)
            g = torch.randn_like(y)
            y.backward(g)
            if i % 4 == 1:
                assert probe.enabled
                # Make sure we registered all linears
                print(list(probe.store.keys()))
                for key in [
                    "Model::attns_out",
                    "Model::attns_out.g",
                    "Model.attn1::attn_logits",
                    "Model.attn2::attn_logits",
                    "Model.attnSDPA::attn_logits",
                    "Model.attnSDPAflash::attn_logits",
                    "Model.head::w",
                    "Model.head::w.g",
                    "Model.head::in",
                    "Model.head::in.g",
                    "Model.head::out",
                    "Model.head::out.g",
                    "Model.trunk.0::in",
                    "Model.trunk.1::in",
                ]:
                    assert key in probe.store, f"Missing key: '{key}'"
                # .. and that the values are correct
                for key, tensor in [
                    ("Model.head::w", m.head.weight),
                    ("Model.head::w.g", m.head.weight.grad),
                    ("Model.q_proj::in", x),
                    ("Model.q_proj::w.g", m.q_proj.weight.grad),
                    ("Model.head::out", y),
                    ("Model.head::out.g", g),
                ]:
                    assert key in probe.store, f"Missing key: '{key}'"
                    assert torch.allclose(
                        probe.store[key]["abs.mean"], tensor.float().abs().mean()
                    ), f"'{key}' mismatches"
                # Check we don't have `nans`
                for key, value in probe.store.items():
                    if "abs.mean" in value:
                        assert math.isfinite(
                            value["abs.mean"].item()
                        ), f"Inf/Nan for {key}"
        optim.step()
        optim.zero_grad()
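

# Hypothetical entry point (an assumption, not part of the original file) so
# the two tests can be run directly; both require a CUDA device and xFormers.
if __name__ == "__main__":
    test_masking()
    test_toy_model()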