Spaces:
Running
Running
feat(train): log norm and histograms (#143)
Browse files* feat(train): log norm and histograms
* feat: update shampoo
tools/train/scalable_shampoo/distributed_shampoo.py
CHANGED
|
@@ -832,8 +832,11 @@ def distributed_shampoo(
|
|
| 832 |
if not _skip_preconditioning(param):
|
| 833 |
sizes = [s[0] for s in shapes]
|
| 834 |
shapes = preconditioner.shapes_for_preconditioners()
|
| 835 |
-
statistics = [
|
| 836 |
-
|
|
|
|
|
|
|
|
|
|
| 837 |
padded_statistics.extend(statistics)
|
| 838 |
padded_preconditioners.extend(preconditioners)
|
| 839 |
exponent = (
|
|
@@ -1244,8 +1247,10 @@ def distributed_shampoo(
|
|
| 1244 |
preconditioners = []
|
| 1245 |
if not _skip_preconditioning(param):
|
| 1246 |
shapes = preconditioner.shapes_for_preconditioners()
|
| 1247 |
-
statistics = [
|
| 1248 |
-
|
|
|
|
|
|
|
| 1249 |
|
| 1250 |
diagonal_statistics = []
|
| 1251 |
if _graft_type_has_diagonal_statistics():
|
|
|
|
| 832 |
if not _skip_preconditioning(param):
|
| 833 |
sizes = [s[0] for s in shapes]
|
| 834 |
shapes = preconditioner.shapes_for_preconditioners()
|
| 835 |
+
statistics = [
|
| 836 |
+
matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
|
| 837 |
+
for s in shapes
|
| 838 |
+
]
|
| 839 |
+
preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes]
|
| 840 |
padded_statistics.extend(statistics)
|
| 841 |
padded_preconditioners.extend(preconditioners)
|
| 842 |
exponent = (
|
|
|
|
| 1247 |
preconditioners = []
|
| 1248 |
if not _skip_preconditioning(param):
|
| 1249 |
shapes = preconditioner.shapes_for_preconditioners()
|
| 1250 |
+
statistics = [
|
| 1251 |
+
matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
|
| 1252 |
+
]
|
| 1253 |
+
preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes]
|
| 1254 |
|
| 1255 |
diagonal_statistics = []
|
| 1256 |
if _graft_type_has_diagonal_statistics():
|
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py
CHANGED
|
@@ -16,10 +16,11 @@
|
|
| 16 |
"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""
|
| 17 |
|
| 18 |
import functools
|
| 19 |
-
from typing import List, Union
|
| 20 |
|
| 21 |
import jax
|
| 22 |
import jax.numpy as jnp
|
|
|
|
| 23 |
from flax import struct
|
| 24 |
from jax import lax
|
| 25 |
|
|
@@ -41,6 +42,7 @@ class SlicedSymmetricMatrix:
|
|
| 41 |
def product_with_transpose(
|
| 42 |
mat1,
|
| 43 |
mat2,
|
|
|
|
| 44 |
precision=lax.Precision.DEFAULT,
|
| 45 |
):
|
| 46 |
"""Returns mat1 * mat2^T for two matrices (possibly batched).
|
|
@@ -50,50 +52,85 @@ def product_with_transpose(
|
|
| 50 |
Args:
|
| 51 |
mat1: First matrix.
|
| 52 |
mat2: Second matrix.
|
|
|
|
| 53 |
precision: JAX precision to use for the multiplication.
|
| 54 |
"""
|
| 55 |
-
return jnp.
|
| 56 |
|
| 57 |
|
| 58 |
-
@functools.partial(jax.jit, static_argnames=("block_size", "precision"))
|
| 59 |
def sliced_transposed_product(
|
| 60 |
mat,
|
| 61 |
block_size,
|
|
|
|
| 62 |
precision=lax.Precision.DEFAULT,
|
| 63 |
):
|
| 64 |
-
"""Returns the blocked slices representing a symmetric
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
Args:
|
| 67 |
-
mat: The matrix for which we will compute
|
| 68 |
-
square, and may be batched.
|
| 69 |
block_size: The size of row blocks to compute.
|
|
|
|
| 70 |
precision: The precision to use in each computation.
|
| 71 |
|
| 72 |
Raises:
|
| 73 |
ValueError: Raised when the specified block size does not evenly divide
|
| 74 |
the number of rows of the input mat.
|
| 75 |
"""
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
if num_rows % block_size != 0:
|
| 78 |
raise ValueError(
|
| 79 |
"The row dimension must be divisible by block_size. "
|
| 80 |
f"Instead got row dimension={num_rows} and block_size={block_size}."
|
| 81 |
)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
)
|
| 88 |
-
|
| 89 |
-
]
|
| 90 |
return SlicedSymmetricMatrix(block_rows=block_rows)
|
| 91 |
|
| 92 |
|
| 93 |
-
@functools.partial(jax.jit, static_argnames=("block_size", "precision"))
|
| 94 |
def sliced_transposed_product_concat(
|
| 95 |
mat,
|
| 96 |
block_size,
|
|
|
|
| 97 |
precision=lax.Precision.DEFAULT,
|
| 98 |
):
|
| 99 |
"""Returns the concatenated slices representing mat*mat^T.
|
|
@@ -102,6 +139,7 @@ def sliced_transposed_product_concat(
|
|
| 102 |
mat: The matrix for which we will compute mat*mat^T. It does not need to be
|
| 103 |
square, and may be batched.
|
| 104 |
block_size: The size of row blocks to compute.
|
|
|
|
| 105 |
precision: The precision to use in each computation.
|
| 106 |
|
| 107 |
Raises:
|
|
@@ -109,7 +147,7 @@ def sliced_transposed_product_concat(
|
|
| 109 |
the number of rows of the input mat.
|
| 110 |
"""
|
| 111 |
sliced_symmetric_matrix = sliced_transposed_product(
|
| 112 |
-
mat=mat, block_size=block_size, precision=precision
|
| 113 |
)
|
| 114 |
return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
|
| 115 |
|
|
@@ -179,12 +217,13 @@ def materialize_matrix_from_concat(
|
|
| 179 |
return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
|
| 180 |
|
| 181 |
|
| 182 |
-
@functools.partial(jax.jit, static_argnames=("alpha", "beta"))
|
| 183 |
def update_sliced_rows(
|
| 184 |
symmetric_matrix,
|
| 185 |
mat,
|
| 186 |
alpha,
|
| 187 |
beta,
|
|
|
|
| 188 |
):
|
| 189 |
"""Implements the blocked equivalent of SYRK.
|
| 190 |
|
|
@@ -197,15 +236,45 @@ def update_sliced_rows(
|
|
| 197 |
should match that of symmetric_matrix.
|
| 198 |
alpha: The weight for the update.
|
| 199 |
beta: The weight for the original symmetric matrix.
|
|
|
|
| 200 |
|
| 201 |
Returns:
|
| 202 |
The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
|
| 203 |
"""
|
| 204 |
block_size = symmetric_matrix.block_rows[0].shape[-2]
|
| 205 |
-
sym_prod = sliced_transposed_product(mat=mat, block_size=block_size)
|
| 206 |
return SlicedSymmetricMatrix(
|
| 207 |
block_rows=[
|
| 208 |
update * alpha + row * beta
|
| 209 |
for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
|
| 210 |
]
|
| 211 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""
|
| 17 |
|
| 18 |
import functools
|
| 19 |
+
from typing import Any, List, Sequence, Union
|
| 20 |
|
| 21 |
import jax
|
| 22 |
import jax.numpy as jnp
|
| 23 |
+
import numpy as np
|
| 24 |
from flax import struct
|
| 25 |
from jax import lax
|
| 26 |
|
|
|
|
| 42 |
def product_with_transpose(
|
| 43 |
mat1,
|
| 44 |
mat2,
|
| 45 |
+
axes,
|
| 46 |
precision=lax.Precision.DEFAULT,
|
| 47 |
):
|
| 48 |
"""Returns mat1 * mat2^T for two matrices (possibly batched).
|
|
|
|
| 52 |
Args:
|
| 53 |
mat1: First matrix.
|
| 54 |
mat2: Second matrix.
|
| 55 |
+
axes: The axes over which to apply the product.
|
| 56 |
precision: JAX precision to use for the multiplication.
|
| 57 |
"""
|
| 58 |
+
return jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision)
|
| 59 |
|
| 60 |
|
| 61 |
+
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
|
| 62 |
def sliced_transposed_product(
|
| 63 |
mat,
|
| 64 |
block_size,
|
| 65 |
+
axes=(-1,),
|
| 66 |
precision=lax.Precision.DEFAULT,
|
| 67 |
):
|
| 68 |
+
"""Returns the blocked slices representing a symmetric contraction.
|
| 69 |
+
|
| 70 |
+
Specifically, the output is a contraction of the input mat with itself, in the
|
| 71 |
+
specified axes.
|
| 72 |
|
| 73 |
Args:
|
| 74 |
+
mat: The matrix for which we will compute a contraction with itself.
|
|
|
|
| 75 |
block_size: The size of row blocks to compute.
|
| 76 |
+
axes: Axes to use for the contraction.
|
| 77 |
precision: The precision to use in each computation.
|
| 78 |
|
| 79 |
Raises:
|
| 80 |
ValueError: Raised when the specified block size does not evenly divide
|
| 81 |
the number of rows of the input mat.
|
| 82 |
"""
|
| 83 |
+
rank = len(mat.shape)
|
| 84 |
+
|
| 85 |
+
def _make_axis_positive(ax):
|
| 86 |
+
assert -rank <= ax < rank
|
| 87 |
+
return ax + rank if ax < 0 else ax
|
| 88 |
+
|
| 89 |
+
positive_axes = [_make_axis_positive(ax) for ax in axes]
|
| 90 |
+
assert len(positive_axes) == len(axes)
|
| 91 |
+
remaining_axes = set(range(rank)) - set(positive_axes)
|
| 92 |
+
assert len(remaining_axes) == 1
|
| 93 |
+
remaining_ax = remaining_axes.pop()
|
| 94 |
+
|
| 95 |
+
num_rows = mat.shape[remaining_ax]
|
| 96 |
if num_rows % block_size != 0:
|
| 97 |
raise ValueError(
|
| 98 |
"The row dimension must be divisible by block_size. "
|
| 99 |
f"Instead got row dimension={num_rows} and block_size={block_size}."
|
| 100 |
)
|
| 101 |
+
|
| 102 |
+
block_rows = []
|
| 103 |
+
for i in range(num_rows // block_size):
|
| 104 |
+
start_indices = [0] * rank
|
| 105 |
+
start_indices[remaining_ax] = i * block_size
|
| 106 |
+
|
| 107 |
+
slice_sizes = list(mat.shape)
|
| 108 |
+
slice_sizes[remaining_ax] = block_size
|
| 109 |
+
|
| 110 |
+
slice_sizes_full = list(mat.shape)
|
| 111 |
+
slice_sizes_full[remaining_ax] = (i + 1) * block_size
|
| 112 |
+
|
| 113 |
+
block_rows.append(
|
| 114 |
+
product_with_transpose(
|
| 115 |
+
lax.dynamic_slice(
|
| 116 |
+
mat, start_indices=start_indices, slice_sizes=slice_sizes
|
| 117 |
+
),
|
| 118 |
+
lax.dynamic_slice(
|
| 119 |
+
mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
|
| 120 |
+
),
|
| 121 |
+
axes=(axes, axes),
|
| 122 |
+
precision=precision,
|
| 123 |
+
)
|
| 124 |
)
|
| 125 |
+
|
|
|
|
| 126 |
return SlicedSymmetricMatrix(block_rows=block_rows)
|
| 127 |
|
| 128 |
|
| 129 |
+
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
|
| 130 |
def sliced_transposed_product_concat(
|
| 131 |
mat,
|
| 132 |
block_size,
|
| 133 |
+
axes=(-1,),
|
| 134 |
precision=lax.Precision.DEFAULT,
|
| 135 |
):
|
| 136 |
"""Returns the concatenated slices representing mat*mat^T.
|
|
|
|
| 139 |
mat: The matrix for which we will compute mat*mat^T. It does not need to be
|
| 140 |
square, and may be batched.
|
| 141 |
block_size: The size of row blocks to compute.
|
| 142 |
+
axes: Axes to use for the contraction.
|
| 143 |
precision: The precision to use in each computation.
|
| 144 |
|
| 145 |
Raises:
|
|
|
|
| 147 |
the number of rows of the input mat.
|
| 148 |
"""
|
| 149 |
sliced_symmetric_matrix = sliced_transposed_product(
|
| 150 |
+
mat=mat, block_size=block_size, axes=axes, precision=precision
|
| 151 |
)
|
| 152 |
return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
|
| 153 |
|
|
|
|
| 217 |
return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
|
| 218 |
|
| 219 |
|
| 220 |
+
@functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
|
| 221 |
def update_sliced_rows(
|
| 222 |
symmetric_matrix,
|
| 223 |
mat,
|
| 224 |
alpha,
|
| 225 |
beta,
|
| 226 |
+
axes=(-1,),
|
| 227 |
):
|
| 228 |
"""Implements the blocked equivalent of SYRK.
|
| 229 |
|
|
|
|
| 236 |
should match that of symmetric_matrix.
|
| 237 |
alpha: The weight for the update.
|
| 238 |
beta: The weight for the original symmetric matrix.
|
| 239 |
+
axes: Axes to use for the contraction of the update.
|
| 240 |
|
| 241 |
Returns:
|
| 242 |
The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
|
| 243 |
"""
|
| 244 |
block_size = symmetric_matrix.block_rows[0].shape[-2]
|
| 245 |
+
sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes)
|
| 246 |
return SlicedSymmetricMatrix(
|
| 247 |
block_rows=[
|
| 248 |
update * alpha + row * beta
|
| 249 |
for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
|
| 250 |
]
|
| 251 |
)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def find_num_blocks(block_rows_concat):
|
| 255 |
+
"""Returns the number of (row) blocks representing the concatenated matrix.
|
| 256 |
+
|
| 257 |
+
For example, an input with dimensions [256, 2560] represents 10 square blocks,
|
| 258 |
+
which matches 4 lower-triangular block rows (1+2+3+4). So this function will
|
| 259 |
+
return 4.
|
| 260 |
+
|
| 261 |
+
Use ordinary numpy functions here so that the returned value is static.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
block_rows_concat: The concatenated block array.
|
| 265 |
+
|
| 266 |
+
Raises:
|
| 267 |
+
ValueError: When the dimensions of the matrix do not correspond to a lower
|
| 268 |
+
triangular block representation.
|
| 269 |
+
"""
|
| 270 |
+
# Compute the number of square blocks used to represent the matrix.
|
| 271 |
+
total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
|
| 272 |
+
# Determine the number of block rows by inverting y = x*(x+1)/2.
|
| 273 |
+
num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
|
| 274 |
+
if num_blocks * (num_blocks + 1) / 2 != total_blocks:
|
| 275 |
+
raise ValueError(
|
| 276 |
+
"Could not determine an appropriate number of blocks for "
|
| 277 |
+
"the concatenated matrix."
|
| 278 |
+
)
|
| 279 |
+
else:
|
| 280 |
+
return num_blocks
|
tools/train/train.py
CHANGED
|
@@ -37,7 +37,7 @@ import optax
|
|
| 37 |
import transformers
|
| 38 |
import wandb
|
| 39 |
from datasets import Dataset
|
| 40 |
-
from flax.core.frozen_dict import FrozenDict, freeze
|
| 41 |
from flax.serialization import from_bytes, to_bytes
|
| 42 |
from flax.training import train_state
|
| 43 |
from flax.training.common_utils import onehot
|
|
@@ -405,6 +405,12 @@ class TrainingArguments:
|
|
| 405 |
default=False,
|
| 406 |
metadata={"help": "Log model to wandb at `save_steps` frequency."},
|
| 407 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
seed_model: int = field(
|
| 410 |
default=42,
|
|
@@ -514,10 +520,22 @@ class MetricsLogger:
|
|
| 514 |
|
| 515 |
def log(self, metrics, prefix=None):
|
| 516 |
if jax.process_index() == 0:
|
| 517 |
-
log_metrics = {
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
wandb.log({**log_metrics, **self.state_dict})
|
| 522 |
|
| 523 |
|
|
@@ -1024,8 +1042,9 @@ def main():
|
|
| 1024 |
lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
|
| 1025 |
)
|
| 1026 |
|
| 1027 |
-
# update state
|
| 1028 |
grads = with_sharding_constraint(grads, param_spec)
|
|
|
|
|
|
|
| 1029 |
state = state.apply_gradients(
|
| 1030 |
grads=grads,
|
| 1031 |
dropout_rng=dropout_rng,
|
|
@@ -1033,11 +1052,49 @@ def main():
|
|
| 1033 |
train_samples=state.train_samples + batch_size_per_step,
|
| 1034 |
)
|
| 1035 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1036 |
metrics = {
|
| 1037 |
"loss": loss,
|
| 1038 |
"learning_rate": learning_rate_fn(state.step),
|
|
|
|
|
|
|
| 1039 |
}
|
| 1040 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1041 |
return state, metrics
|
| 1042 |
|
| 1043 |
# Define eval fn
|
|
|
|
| 37 |
import transformers
|
| 38 |
import wandb
|
| 39 |
from datasets import Dataset
|
| 40 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
| 41 |
from flax.serialization import from_bytes, to_bytes
|
| 42 |
from flax.training import train_state
|
| 43 |
from flax.training.common_utils import onehot
|
|
|
|
| 405 |
default=False,
|
| 406 |
metadata={"help": "Log model to wandb at `save_steps` frequency."},
|
| 407 |
)
|
| 408 |
+
log_histograms: bool = field(
|
| 409 |
+
default=False,
|
| 410 |
+
metadata={
|
| 411 |
+
"help": "Log parameters and gradients histograms. Slows down training."
|
| 412 |
+
},
|
| 413 |
+
)
|
| 414 |
|
| 415 |
seed_model: int = field(
|
| 416 |
default=42,
|
|
|
|
| 520 |
|
| 521 |
def log(self, metrics, prefix=None):
|
| 522 |
if jax.process_index() == 0:
|
| 523 |
+
log_metrics = {}
|
| 524 |
+
for k, v in metrics.items():
|
| 525 |
+
if prefix is not None:
|
| 526 |
+
k = f"{prefix}/{k}"
|
| 527 |
+
if "_norm" in k:
|
| 528 |
+
log_metrics[f"{k}/"] = unfreeze(v)
|
| 529 |
+
elif "_hist" in k:
|
| 530 |
+
v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
|
| 531 |
+
v = jax.tree_map(
|
| 532 |
+
lambda x: wandb.Histogram(np_histogram=x),
|
| 533 |
+
v,
|
| 534 |
+
is_leaf=lambda x: isinstance(x, tuple),
|
| 535 |
+
)
|
| 536 |
+
log_metrics[f"{k}/"] = v
|
| 537 |
+
else:
|
| 538 |
+
log_metrics[k] = v
|
| 539 |
wandb.log({**log_metrics, **self.state_dict})
|
| 540 |
|
| 541 |
|
|
|
|
| 1042 |
lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
|
| 1043 |
)
|
| 1044 |
|
|
|
|
| 1045 |
grads = with_sharding_constraint(grads, param_spec)
|
| 1046 |
+
|
| 1047 |
+
# update state
|
| 1048 |
state = state.apply_gradients(
|
| 1049 |
grads=grads,
|
| 1050 |
dropout_rng=dropout_rng,
|
|
|
|
| 1052 |
train_samples=state.train_samples + batch_size_per_step,
|
| 1053 |
)
|
| 1054 |
|
| 1055 |
+
# get norm and histogram of grads and params
|
| 1056 |
+
zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)
|
| 1057 |
+
|
| 1058 |
+
def maybe_fn(fn, val, zeros):
|
| 1059 |
+
"""Call fn only if it is a logging step"""
|
| 1060 |
+
return jax.lax.cond(
|
| 1061 |
+
state.step % training_args.logging_steps == 0,
|
| 1062 |
+
fn,
|
| 1063 |
+
lambda _: zeros,
|
| 1064 |
+
val,
|
| 1065 |
+
)
|
| 1066 |
+
|
| 1067 |
+
def norm(val):
|
| 1068 |
+
return jax.tree_map(lambda x: jnp.linalg.norm(x), val)
|
| 1069 |
+
|
| 1070 |
+
gradients_norm = maybe_fn(norm, grads, zeros_norm)
|
| 1071 |
+
params_norm = maybe_fn(norm, state.params, zeros_norm)
|
| 1072 |
+
|
| 1073 |
metrics = {
|
| 1074 |
"loss": loss,
|
| 1075 |
"learning_rate": learning_rate_fn(state.step),
|
| 1076 |
+
"gradients_norm": gradients_norm,
|
| 1077 |
+
"params_norm": params_norm,
|
| 1078 |
}
|
| 1079 |
|
| 1080 |
+
if training_args.log_histograms:
|
| 1081 |
+
zeros_hist = jax.tree_map(
|
| 1082 |
+
lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
|
| 1083 |
+
)
|
| 1084 |
+
|
| 1085 |
+
def histogram(val):
|
| 1086 |
+
return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)
|
| 1087 |
+
|
| 1088 |
+
gradients_hist = maybe_fn(histogram, grads, zeros_hist)
|
| 1089 |
+
params_hist = maybe_fn(histogram, state.params, zeros_hist)
|
| 1090 |
+
|
| 1091 |
+
metrics.update(
|
| 1092 |
+
{
|
| 1093 |
+
"params_hist": params_hist,
|
| 1094 |
+
"gradients_hist": gradients_hist,
|
| 1095 |
+
}
|
| 1096 |
+
)
|
| 1097 |
+
|
| 1098 |
return state, metrics
|
| 1099 |
|
| 1100 |
# Define eval fn
|