# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loss functions."""

import functools
import inspect
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
import ml_collections

# Registry mapping a loss name to its (decorated) loss function. Populated by
# @register_loss at import time and read by compute_loss_term.
_LOSS_FUNCTIONS = {}

Array = Any  # jnp.ndarray somehow doesn't work anymore for pytype.
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ArrayDict = Dict[str, Array]
DictTree = Dict[str, Union[Array, "DictTree"]]  # pytype: disable=not-supported-yet
PRNGKey = Array
LossFn = Callable[[Dict[str, ArrayTree], Dict[str, ArrayTree]],
                  Tuple[Array, ArrayTree]]
ConfigAttr = Any
MetricSpec = Dict[str, str]


def standardize_loss_config(
    loss_config: Union[Sequence[str], Mapping[str, Any],
                       ml_collections.ConfigDict]
) -> ml_collections.FrozenConfigDict:
    """Standardize loss configs into a common ConfigDict format.

    Args:
      loss_config: List of strings or ConfigDict specifying loss configuration.
        Valid input formats are:
          - Option 1 (list of strings), for example,
            `loss_config = ["box", "presence"]`
          - Option 2 (losses with weights only), for example,
            `loss_config = ConfigDict({"box": 5, "presence": 2})`
          - Option 3 (losses with weights and other parameters), for example,
            `loss_config = ConfigDict({"box": {"weight": 5, "metric": "l1"},
                                       "presence": {"weight": 2}})`
        A single bare string (e.g. `"box"`) is treated like the one-element
        list `["box"]`.

    Returns:
      Standardized ConfigDict containing the loss configuration.

    Raises:
      ValueError: If loss_config is a list that contains non-string entries.
    """
    # A bare string is itself a Sequence (of characters); without this guard
    # "box" would be expanded into the three losses "b", "o" and "x".
    if isinstance(loss_config, str):
        loss_config = [loss_config]

    if isinstance(loss_config, Sequence):  # Option 1
        if not all(isinstance(loss_type, str) for loss_type in loss_config):
            raise ValueError(
                f"Loss types all need to be str but got {loss_config}")
        return ml_collections.FrozenConfigDict({k: {} for k in loss_config})

    # Convert all option-2-style weights to option-3-style dictionaries.
    loss_config = {
        k: {"weight": v} if isinstance(v, (float, int)) else v
        for k, v in loss_config.items()
    }
    return ml_collections.FrozenConfigDict(loss_config)


def update_loss_aux(loss_aux: Dict[str, Any], update: Dict[str, Any]) -> None:
    """Merge `update` into `loss_aux` in place, refusing to overwrite keys.

    Args:
      loss_aux: Dictionary of auxiliary values accumulated so far. Mutated.
      update: New entries to add to `loss_aux`.

    Raises:
      KeyError: If any key of `update` is already present in `loss_aux`.
    """
    existing_keys = set(update).intersection(loss_aux)
    if existing_keys:
        raise KeyError(
            f"Can't overwrite existing keys in loss_aux: {existing_keys}")
    loss_aux.update(update)


def compute_full_loss(
    preds: Dict[str, ArrayTree],
    targets: Dict[str, ArrayTree],
    loss_config: Union[Sequence[str], Mapping[str, Any],
                       ml_collections.ConfigDict],
) -> Tuple[Array, Dict[str, Any]]:
    """Loss function that parses and combines weighted loss terms.

    Args:
      preds: Dictionary of tensors containing model predictions.
      targets: Dictionary of tensors containing prediction targets.
      loss_config: List of strings or ConfigDict specifying loss configuration.
        See @register_loss decorated functions below for valid loss names.
        Valid losses formats are:
          - Option 1 (list of strings), for example,
            `loss_config = ["box", "presence"]`
          - Option 2 (losses with weights only), for example,
            `loss_config = ConfigDict({"box": 5, "presence": 2})`
          - Option 3 (losses with weights and other parameters), for example,
            `loss_config = ConfigDict({"box": {"weight": 5, "metric": "l1"},
                                       "presence": {"weight": 2}})`
          - Option 4 (like 3 but decoupling name and loss_type), for example,
            `loss_config = ConfigDict({"recon_flow": {"loss_type": "recon",
                                                      "key": "flow"},
                                       "recon_video": {"loss_type": "recon",
                                                       "key": "video"}})`

    Returns:
      A 2-tuple of the sum of all individual loss terms and a dictionary of
      auxiliary losses and metrics. For every loss `<name>` the dictionary
      contains `<name>_value` (the mean of the unweighted loss term) and
      `<name>_weight` (ones with the same shape as `<name>_value`), plus any
      auxiliary entries returned by the loss function itself.

    Raises:
      KeyError: If two loss terms try to write the same auxiliary key.
    """
    loss = jnp.zeros([], jnp.float32)
    loss_aux = {}
    loss_config = standardize_loss_config(loss_config)

    # The context is identical for every loss term and is only read (copied
    # into the per-loss kwargs) by compute_loss_term, so build it once.
    context_kwargs = {"preds": preds, "targets": targets}

    for loss_name, cfg in loss_config.items():
        weight, loss_term, loss_aux_update = compute_loss_term(
            loss_name=loss_name,
            context_kwargs=context_kwargs,
            config_kwargs=cfg)

        unweighted_loss = jnp.mean(loss_term)
        loss += weight * unweighted_loss
        loss_aux_update[loss_name + "_value"] = unweighted_loss
        loss_aux_update[loss_name + "_weight"] = jnp.ones_like(unweighted_loss)
        update_loss_aux(loss_aux, loss_aux_update)
    return loss, loss_aux


def register_loss(func: Optional[Callable[..., Any]] = None,
                  *,
                  name: Optional[str] = None,
                  check_unused_kwargs: bool = True):
    """Decorator for registering a loss function.

    Can be used without arguments:
    ```
    @register_loss
    def my_loss(**_):
      return 0
    ```
    or with keyword arguments:
    ```
    @register_loss(name="my_renamed_loss")
    def my_loss(**_):
      return 0
    ```

    Loss functions may accept
      - context kwargs: `preds` and `targets`
      - config kwargs: any argument specified in the config
      - the special `config_kwargs` parameter that contains the entire loss
        config
    Loss functions also _need_ to accept a **kwarg argument to support
    extending the interface.
    They should return either:
      - just the computed loss (pre-reduction)
      - or a tuple of the computed loss and a loss_aux_updates dict

    Args:
      func: the decorated function
      name (str): Optional name to be used for this loss in the config.
        Defaults to the name of the function.
      check_unused_kwargs (bool): By default compute_loss_term raises an error
        if there are any unused config kwargs. If this flag is set to False
        that step is skipped. This is useful if the config_kwargs should be
        passed onward to another function.

    Returns:
      The decorated function (or a partial of the decorator)

    Raises:
      TypeError: If the decorated function has no **kwargs parameter.
    """
    # If this decorator has been called with parameters but no function, then
    # we return the decorator again (but with partially filled parameters).
    # This allows using both @register_loss and @register_loss(name="foo")
    if func is None:
        return functools.partial(
            register_loss, name=name, check_unused_kwargs=check_unused_kwargs)

    # No (further) arguments: this is the actual decorator
    # ensure that the loss function includes a **kwargs argument
    loss_name = name if name is not None else func.__name__
    parameters = inspect.signature(func).parameters.values()
    if not any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters):
        raise TypeError(
            f"Loss function '{loss_name}' needs to include a **kwargs argument")

    # Metadata is attached to the function object itself; compute_loss_term
    # reads `check_unused_kwargs` back from it.
    func.name = loss_name
    func.check_unused_kwargs = check_unused_kwargs
    _LOSS_FUNCTIONS[loss_name] = func
    return func


def compute_loss_term(
    loss_name: str,
    context_kwargs: Dict[str, Any],
    config_kwargs: Mapping[str, Any],
) -> Tuple[float, Array, Dict[str, Any]]:
    """Compute a loss function given its config and context parameters.

    Takes care of:
      - finding the correct loss function based on "loss_type" or name
      - the optional "weight" parameter
      - checking for typos and collisions in config parameters
      - adding the optional loss_aux_updates if omitted by the loss_fn

    Args:
      loss_name: Name of the loss, i.e. its key in the config.losses dict.
      context_kwargs: Dictionary of context variables (`preds` and `targets`)
      config_kwargs: The config dict for this loss.

    Returns:
      1. the loss weight (float)
      2. loss term (Array)
      3. loss aux updates (Dict[str, Array])

    Raises:
      KeyError: Unknown loss_type
      KeyError: Unused config entries, i.e. not used by the loss function.
        Not raised if using @register_loss(check_unused_kwargs=False)
      KeyError: Config entry with a name that conflicts with a context_kwarg
      ValueError: Non-numerical weight in config_kwargs
    """
    # Make a dict copy of config_kwargs so the pops below don't touch the
    # (possibly frozen) config.
    kwargs = dict(config_kwargs.items())

    # Get the loss function
    loss_type = kwargs.pop("loss_type", loss_name)
    if loss_type not in _LOSS_FUNCTIONS:
        raise KeyError(f"Unknown loss_type '{loss_type}'.")
    loss_fn = _LOSS_FUNCTIONS[loss_type]

    # Take care of "weight" term
    weight = kwargs.pop("weight", 1.0)
    if not isinstance(weight, (int, float)):
        raise ValueError(f"Weight for loss {loss_name} should be a number, "
                         f"but was {weight}.")

    # Check for unused config entries (to prevent typos etc.)
    config_keys = set(kwargs)
    if loss_fn.check_unused_kwargs:
        param_names = set(inspect.signature(loss_fn).parameters)
        unused_config_keys = config_keys - param_names
        if unused_config_keys:
            raise KeyError(f"Unrecognized config entries {unused_config_keys} "
                           f"for loss {loss_name}.")

    # Check for key collisions between context and config
    conflicting_config_keys = config_keys.intersection(context_kwargs)
    if conflicting_config_keys:
        raise KeyError(f"The config keys {conflicting_config_keys} conflict "
                       f"with the context parameters ({context_kwargs.keys()}) "
                       f"for loss {loss_name}.")

    # Construct the arguments for the loss function.
    # NOTE: a config entry literally named "config_kwargs" would be replaced
    # by the full config here.
    kwargs.update(context_kwargs)
    kwargs["config_kwargs"] = config_kwargs

    # Call loss
    results = loss_fn(**kwargs)

    # Add empty loss_aux_updates if necessary
    if isinstance(results, tuple):
        loss, loss_aux_update = results
    else:
        loss, loss_aux_update = results, {}

    return weight, loss, loss_aux_update


# -------- Loss functions --------
@register_loss
def recon(preds: Dict[str, ArrayTree],
          targets: Dict[str, ArrayTree],
          key: str = "video",
          reduction_type: str = "sum",
          **_) -> Array:
    """Reconstruction loss (MSE).

    Args:
      preds: Model predictions; `preds["outputs"][key]` is compared against
        the target.
      targets: Prediction targets; `targets[key]` is used as the target.
      key: Name of the entry to reconstruct.
      reduction_type: Either "sum" or "mean"; forwarded to `squared_l2_norm`.
      **_: Ignored additional keyword arguments (e.g. `config_kwargs`).

    Returns:
      The mean over the leading axis of the per-example squared L2 error.

    Raises:
      ValueError: If `reduction_type` is neither "sum" nor "mean".
    """
    # vmap applies the norm independently to every slice along the leading
    # axis of both inputs.
    squared_l2_norm_fn = jax.vmap(functools.partial(
        squared_l2_norm, reduction_type=reduction_type))
    targets = targets[key]
    loss = squared_l2_norm_fn(preds["outputs"][key], targets)
    if reduction_type == "mean":
        # This rescaling reflects taking the sum over feature axis &
        # mean over space/time axes.
        loss *= targets.shape[-1]  # pytype: disable=attribute-error  # allow-recursive-types
    return jnp.mean(loss)


def squared_l2_norm(preds: Array,
                    targets: Array,
                    reduction_type: str = "sum") -> Array:
    """Reduce the element-wise squared difference of `preds` and `targets`.

    Args:
      preds: Predicted values.
      targets: Target values; `preds - targets` must be a valid operation.
      reduction_type: "sum" to add up all squared differences, "mean" to
        average them.

    Returns:
      The summed or averaged squared difference.

    Raises:
      ValueError: If `reduction_type` is neither "sum" nor "mean".
    """
    if reduction_type == "sum":
        return jnp.sum(jnp.square(preds - targets))
    elif reduction_type == "mean":
        return jnp.mean(jnp.square(preds - targets))
    else:
        raise ValueError(f"Unsupported reduction_type: {reduction_type}")