Spaces:
Running
on
T4
Running
on
T4
| # Copyright 2021 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Function to stack repeats of a layer function without shared parameters.""" | |
| import collections | |
| import contextlib | |
| import functools | |
| import inspect | |
| from typing import Any, Callable, Optional, Tuple, Union | |
| import haiku as hk | |
| import jax | |
| import jax.numpy as jnp | |
# Loop state threaded through `hk.scan`: the running value `x` and the RNG key
# (or None) that `_LayerStack` splits once per layer.
LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng'])
# Per-layer scanned inputs: the layer index `i` (used to slice stacked
# parameters) and the per-layer arguments `args_ys`.
LayerStackScanned = collections.namedtuple('LayerStackScanned',
                                           ['i', 'args_ys'])

# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the
# exact same type. We cannot express this with `typing`. So we just use it
# to inform the user. In reality, the typing below will accept anything.
NestedArray = Any
# The wrapped function returns either one nested array or a tuple of them of
# any length (e.g. `(new_x, z)`), hence `Tuple[NestedArray, ...]` rather than
# the 1-tuple `Tuple[NestedArray]`.
WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray, ...]]]
def _check_no_varargs(f):
  """Raises if `f` declares a variadic positional (`*args`) parameter.

  `layer_stack` requires functions that use only explicit positional
  parameters, so any `*args` parameter is rejected.

  Args:
    f: The callable to inspect.

  Raises:
    ValueError: If any parameter of `f` is of kind `VAR_POSITIONAL`.
  """
  params = inspect.signature(f).parameters.values()
  # Inspect every parameter, not just the first: this also rejects
  # `def f(x, *rest)`, and a parameterless `f` no longer hits an IndexError.
  if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params):
    raise ValueError(
        'The function `f` should not have any `varargs` (that is *args) '
        'argument. Instead, it should only use explicit positional '
        'arguments.')
@contextlib.contextmanager
def nullcontext():
  """Context manager that does nothing.

  Used as the fallback context when no RNG key is available. Without the
  `contextlib.contextmanager` decorator this would be a plain generator
  function, and `with nullcontext():` would fail.

  Yields:
    None.
  """
  yield
def maybe_with_rng(key):
  """Returns a context that installs `key` as Haiku's RNG when it is given.

  Args:
    key: A PRNG key accepted by `hk.with_rng`, or None.

  Returns:
    `hk.with_rng(key)` when `key` is not None, otherwise a no-op context
    manager.
  """
  if key is None:
    # Use a real context manager so `with maybe_with_rng(None):` works.
    return contextlib.nullcontext()
  return hk.with_rng(key)
def maybe_fold_in(key, data):
  """Derives a new key with `jax.random.fold_in`, passing a None key through.

  Args:
    key: A JAX PRNG key, or None.
    data: The value to fold into the key (e.g. a slice index).

  Returns:
    `jax.random.fold_in(key, data)`, or None when `key` is None.
  """
  return None if key is None else jax.random.fold_in(key, data)
class _LayerStack(hk.Module):
  """Module to compose parameterized functions, implemented as a scan.

  Subclasses implement `_call_wrapped`. Every parameter created by the wrapped
  function is stored with an extra axis of size `count`, one slice per layer,
  so parameters are not shared across layers.
  """

  def __init__(self,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    """Configures a stack of `count` layers with non-shared parameters.

    Args:
      count: Number of layers (iterations of the wrapped function).
      unroll: Forwarded as `unroll` to `hk.scan` at apply time.
      name: Optional Haiku module name.
    """
    super().__init__(name=name)
    self._count = count
    self._unroll = unroll

  def __call__(self, x, *args_ys):
    """Runs the stack.

    Args:
      x: Value threaded from layer to layer; passed first to `_call_wrapped`.
      *args_ys: Per-layer inputs, sliced along their leading axis for each
        layer. The single-element `(None,)` means "no per-layer inputs".

    Returns:
      `(x, zs)`: the final carried value, and the per-layer states stacked
      along a leading axis of size `count` (or None if the wrapped function
      returns no state).
    """
    count = self._count
    if hk.running_init():
      # At initialization time, we run just one layer but add an extra first
      # dimension to every initialized tensor, making sure to use different
      # random keys for different slices.
      def creator(next_creator, shape, dtype, init, context):
        # Haiku custom creator: request a parameter with a `count` axis
        # prepended and initialize it slice by slice with `multi_init`.
        del context

        def multi_init(shape, dtype):
          # `shape` here is the `count`-prefixed shape requested below.
          assert shape[0] == count
          # May be None when no RNG is available; `maybe_fold_in` and
          # `maybe_with_rng` pass None through.
          key = hk.maybe_next_rng_key()

          def rng_context_init(slice_idx):
            # Each slice gets its own key derived from its index, so layers
            # are initialized differently.
            slice_key = maybe_fold_in(key, slice_idx)
            with maybe_with_rng(slice_key):
              return init(shape[1:], dtype)
          return jax.vmap(rng_context_init)(jnp.arange(count))

        return next_creator((count,) + tuple(shape), dtype, multi_init)

      def getter(next_getter, value, context):
        # Haiku custom getter: the one layer run at init reads slice 0 of the
        # stacked parameter. The layer axis is located from the end
        # (original rank + 1) rather than assumed to be axis 0 — presumably
        # so extra leading axes added by outer transforms are tolerated.
        trailing_dims = len(context.original_shape) + 1
        sliced_value = jax.lax.index_in_dim(
            value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
        return next_getter(sliced_value)

      with hk.experimental.custom_creator(
          creator), hk.experimental.custom_getter(getter):
        if len(args_ys) == 1 and args_ys[0] is None:
          # No per-layer inputs: keep the `None` placeholder.
          args0 = (None,)
        else:
          # Layer 0's slice of each per-layer input.
          args0 = [
              jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
              for ys in args_ys
          ]
        x, z = self._call_wrapped(x, *args0)
        if z is None:
          return x, z

        # Broadcast state to hold each layer state. Only layer 0 ran, so its
        # state is replicated to get the `count`-long leading axis that
        # `hk.scan` produces for stacked outputs at apply time.
        def broadcast_state(layer_state):
          return jnp.broadcast_to(
              layer_state, [count,] + list(layer_state.shape))

        zs = jax.tree_util.tree_map(broadcast_state, z)
        return x, zs
    else:
      # Use scan during apply, threading through random seed so that it's
      # unique for each layer.
      def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
        rng = carry.rng

        def getter(next_getter, value, context):
          # Getter slices the full param at the current loop index.
          trailing_dims = len(context.original_shape) + 1
          assert value.shape[value.ndim - trailing_dims] == count, (
              f'Attempting to use a parameter stack of size '
              f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
              f'size {count}.')
          sliced_value = jax.lax.dynamic_index_in_dim(
              value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False)
          return next_getter(sliced_value)

        with hk.experimental.custom_getter(getter):
          if rng is None:
            out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
          else:
            # Split once per layer: `rng_` is used by this layer, `rng` is
            # carried forward to the next one.
            rng, rng_ = jax.random.split(rng)
            with hk.with_rng(rng_):
              out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
        return LayerStackCarry(x=out_x, rng=rng), z

      carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
      # `i` is the layer index the getter above uses to slice parameters.
      scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
                                  args_ys=args_ys)

      carry, zs = hk.scan(
          layer, carry, scanned, length=count, unroll=self._unroll)
      return carry.x, zs

  def _call_wrapped(self,
                    x: jnp.ndarray,
                    *args,
                    ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Applies one layer; implemented by subclasses.

    Must return `(new_x, state)`, where `state` may be None.
    """
    raise NotImplementedError()
class _LayerStackNoState(_LayerStack):
  """Stack variant for functions that receive no per-layer inputs.

  The value threaded between layers is the tuple of all positional arguments
  of `f`, and `f` must return a structure of the same form.
  """

  def __init__(self,
               f: WrappedFn,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    super().__init__(count=count, unroll=unroll, name=name)
    _check_no_varargs(f)
    self._f = f

  def _call_wrapped(self, args, y):
    # `y` is the per-layer input slot, which this variant ignores.
    del y
    result = self._f(*args)
    # A one-argument `f` returns its value unwrapped; restore the 1-tuple so
    # the carried structure matches the incoming `args`.
    carried = (result,) if len(args) == 1 else result
    return carried, None
class _LayerStackWithState(_LayerStack):
  """Stack variant that forwards per-layer inputs to the wrapped function.

  `f` is called as `f(x, *per_layer_inputs)` and its result is unpacked by
  the base class as `(new_x, state)`.
  """

  def __init__(self,
               f: WrappedFn,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    super().__init__(count=count, unroll=unroll, name=name)
    self._f = f

  def _call_wrapped(self, x, *args):
    layer_fn = self._f
    return layer_fn(x, *args)
def layer_stack(num_layers: int,
                with_state: bool = False,
                unroll: int = 1,
                name: Optional[str] = None):
  """Utility to wrap a Haiku function and recursively apply it to an input.

  A function is valid if it uses only explicit positional parameters, and
  its return type matches its input type. The positional parameters can be
  arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note
  that kwargs are not supported, and neither are functions with a variable
  number of parameters (specified by `*args`).

  If `with_state=False` then the new, wrapped function can be understood as
  performing the following:
  ```
  for i in range(num_layers):
    x = f(x)
  return x
  ```

  And if `with_state=True`, assuming `f` takes two arguments on top of `x`:
  ```
  for i in range(num_layers):
    x, zs[i] = f(x, ys_0[i], ys_1[i])
  return x, zs
  ```
  The code using `layer_stack` for the above function would be:
  ```
  def f(x, y_0, y_1):
    ...
    return new_x, z
  x, zs = layer_stack.layer_stack(num_layers,
                                  with_state=True)(f)(x, ys_0, ys_1)
  ```

  Crucially, any parameters created inside `f` will not be shared across
  iterations.

  Args:
    num_layers: The number of times to iterate the wrapped function.
    with_state: Whether or not to pass per-layer state to the wrapped function.
    unroll: the unroll used by `scan`.
    name: Name of the Haiku context.

  Returns:
    Callable that will produce a layer stack when called with a valid function.
    The produced function carries the name, docstring and signature metadata
    of `f` (via `functools.wraps`).

  Raises:
    ValueError: When `with_state=True`, raised by the produced function if a
      leaf of the per-layer inputs lacks a leading axis of size `num_layers`.
  """

  def iterate(f):
    if with_state:
      @functools.wraps(f)
      def wrapped(x, *args):
        # Each leaf of each per-layer input is scanned along its leading axis,
        # so it must have exactly `num_layers` entries there. Walking the
        # leaves also supports nested (pytree) per-layer inputs.
        for leaf in jax.tree_util.tree_leaves(args):
          shape = jnp.shape(leaf)
          if not shape or shape[0] != num_layers:
            raise ValueError(
                f'Per-layer inputs must have a leading axis of size '
                f'{num_layers}; got an input of shape {shape}.')
        return _LayerStackWithState(
            f, num_layers, unroll=unroll, name=name)(x, *args)
    else:
      _check_no_varargs(f)

      # `wraps` sets `__wrapped__`, so `inspect.signature` on this wrapper
      # reports `f`'s explicit parameters. This lets a stateless layer stack
      # be nested inside another without tripping `_check_no_varargs`.
      @functools.wraps(f)
      def wrapped(*args):
        ret = _LayerStackNoState(
            f, num_layers, unroll=unroll, name=name)(args, None)[0]
        if len(args) == 1:
          # If the function takes a single argument, we must also return a
          # single value, and not a tuple of length 1.
          ret = ret[0]
        return ret
    return wrapped

  return iterate