# -------------------------------------------------------------------------------------------------------------------------------------
# Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion):
# -------------------------------------------------------------------------------------------------------------------------------------
import torch
import os
import os.path
import warnings
import pytorch_lightning as pl

from torch import Tensor
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
from typing import Any, Dict, List, Optional

try:
    import amp_C

    apex_available = True
except Exception:
    apex_available = False


class EMA(Callback):
    """
    Implements Exponential Moving Averaging (EMA).

    When training a model, this callback maintains moving averages of the trained parameters.
    When evaluating, we use the moving-average copy of the trained parameters.
    When saving, we save an additional set of parameters with the prefix `ema`.

    Args:
        decay: The exponential decay used when calculating the moving average. Must be between 0 and 1.
        apply_ema_every_n_steps: Apply EMA every `n` global steps.
        start_step: Start applying EMA from ``start_step`` global step onwards.
        save_ema_weights_in_callback_state: Enable saving EMA weights in the callback state.
        evaluate_ema_weights_instead: Validate the EMA weights instead of the original weights.
            Note this means that when saving the model, the validation metrics are calculated with the EMA weights.

    Adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
    """
    def __init__(
        self,
        decay: float = 0.999,
        apply_ema_every_n_steps: int = 1,
        start_step: int = 0,
        # if True, the .ckpt file will additionally store a copy of the EMA weights under its `callbacks` key
        save_ema_weights_in_callback_state: bool = False,
        evaluate_ema_weights_instead: bool = True,
    ):
        if not apex_available:
            rank_zero_warn(
                "EMA has better performance when Apex is installed: https://github.com/NVIDIA/apex#installation."
            )
        if not (0 <= decay <= 1):
            raise MisconfigurationException("EMA decay value must be between 0 and 1")
        self._ema_model_weights: Optional[List[torch.Tensor]] = None
        self._overflow_buf: Optional[torch.Tensor] = None
        self._cur_step: Optional[int] = None
        self._weights_buffer: Optional[List[torch.Tensor]] = None
        self.apply_ema_every_n_steps = apply_ema_every_n_steps
        self.start_step = start_step
        self.save_ema_weights_in_callback_state = save_ema_weights_in_callback_state
        self.evaluate_ema_weights_instead = evaluate_ema_weights_instead
        self.decay = decay

    def on_train_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        rank_zero_info("Creating EMA weights copy.")
        if self._ema_model_weights is None:
            self._ema_model_weights = [
                p.detach().clone() for p in pl_module.state_dict().values()
            ]
        # ensure that all the weights are on the correct device
        self._ema_model_weights = [
            p.to(pl_module.device) for p in self._ema_model_weights
        ]
        self._overflow_buf = torch.IntTensor([0]).to(pl_module.device)

    def ema(self, pl_module: "pl.LightningModule") -> None:
        if apex_available and pl_module.device.type == "cuda":
            return self.apply_multi_tensor_ema(pl_module)
        return self.apply_ema(pl_module)

    def apply_multi_tensor_ema(self, pl_module: "pl.LightningModule") -> None:
        model_weights = list(pl_module.state_dict().values())
        # fused in-place update over all tensors at once: ema = decay * ema + (1 - decay) * model
        amp_C.multi_tensor_axpby(
            65536,  # chunk size used by the fused Apex kernel
            self._overflow_buf,
            [self._ema_model_weights, model_weights, self._ema_model_weights],
            self.decay,
            1 - self.decay,
            -1,
        )

    def apply_ema(self, pl_module: "pl.LightningModule") -> None:
        for orig_weight, ema_weight in zip(
            list(pl_module.state_dict().values()), self._ema_model_weights
        ):
            if (
                ema_weight.data.dtype != torch.long
                and orig_weight.data.dtype != torch.long
            ):
                # ensure that non-trainable parameters (e.g., feature distributions) are not included in EMA weight averaging
                diff = ema_weight.data - orig_weight.data
                diff.mul_(1.0 - self.decay)
                ema_weight.sub_(diff)

    def should_apply_ema(self, step: int) -> bool:
        return (
            step != self._cur_step
            and step >= self.start_step
            and step % self.apply_ema_every_n_steps == 0
        )

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self.should_apply_ema(trainer.global_step):
            self._cur_step = trainer.global_step
            self.ema(pl_module)

    def state_dict(self) -> Dict[str, Any]:
        if self.save_ema_weights_in_callback_state:
            return dict(cur_step=self._cur_step, ema_weights=self._ema_model_weights)
        return dict(cur_step=self._cur_step)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._cur_step = state_dict["cur_step"]
        # when loading within apps such as NeMo, EMA weights will be loaded by the experiment manager separately
        if self._ema_model_weights is None:
            self._ema_model_weights = state_dict.get("ema_weights")

    def on_load_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any],
    ) -> None:
        checkpoint_callback = trainer.checkpoint_callback
        if trainer.ckpt_path and checkpoint_callback is not None:
            ext = checkpoint_callback.FILE_EXTENSION
            if trainer.ckpt_path.endswith(f"-EMA{ext}"):
                rank_zero_info(
                    "Loading EMA-based weights. "
                    "The callback will treat the loaded EMA weights as the main weights"
                    " and create a new EMA copy when training."
                )
                return
            ema_path = trainer.ckpt_path.replace(ext, f"-EMA{ext}")
            if os.path.exists(ema_path):
                ema_state_dict = torch.load(ema_path, map_location=torch.device("cpu"))
                self._ema_model_weights = ema_state_dict["state_dict"].values()
                del ema_state_dict
                rank_zero_info(
                    "EMA weights have been loaded successfully. Continuing training with saved EMA weights."
                )
            else:
                warnings.warn(
                    "We were unable to find the associated EMA weights when re-loading; "
                    "training will start with new EMA weights.",
                    UserWarning,
                )

    def replace_model_weights(self, pl_module: "pl.LightningModule") -> None:
        # stash the original weights on the CPU and swap the EMA weights into the module
        self._weights_buffer = [
            p.detach().clone().to("cpu") for p in pl_module.state_dict().values()
        ]
        new_state_dict = {
            k: v for k, v in zip(pl_module.state_dict().keys(), self._ema_model_weights)
        }
        pl_module.load_state_dict(new_state_dict)

    def restore_original_weights(self, pl_module: "pl.LightningModule") -> None:
        state_dict = pl_module.state_dict()
        new_state_dict = {k: v for k, v in zip(state_dict.keys(), self._weights_buffer)}
        pl_module.load_state_dict(new_state_dict)
        del self._weights_buffer

    @property
    def ema_initialized(self) -> bool:
        return self._ema_model_weights is not None

    def on_validation_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        if self.ema_initialized and self.evaluate_ema_weights_instead:
            self.replace_model_weights(pl_module)

    def on_validation_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        if self.ema_initialized and self.evaluate_ema_weights_instead:
            self.restore_original_weights(pl_module)

    def on_test_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        if self.ema_initialized and self.evaluate_ema_weights_instead:
            self.replace_model_weights(pl_module)

    def on_test_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        if self.ema_initialized and self.evaluate_ema_weights_instead:
            self.restore_original_weights(pl_module)


class EMAModelCheckpoint(ModelCheckpoint):
    """
    Light wrapper around Lightning's `ModelCheckpoint` to, upon request, save an EMA copy of the model as well.

    Adapted from: https://github.com/NVIDIA/NeMo/blob/be0804f61e82dd0f63da7f9fe8a4d8388e330b18/nemo/utils/exp_manager.py#L744
    """
    def __init__(self, **kwargs):
        # call the parent class constructor with the provided kwargs
        super().__init__(**kwargs)

    def _get_ema_callback(self, trainer: "pl.Trainer") -> Optional[EMA]:
        ema_callback = None
        for callback in trainer.callbacks:
            if isinstance(callback, EMA):
                ema_callback = callback
        return ema_callback

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        super()._save_checkpoint(trainer, filepath)
        ema_callback = self._get_ema_callback(trainer)
        if ema_callback is not None:
            # save EMA copy of the model as well
            ema_callback.replace_model_weights(trainer.lightning_module)
            filepath = self._ema_format_filepath(filepath)
            if self.verbose:
                rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            super()._save_checkpoint(trainer, filepath)
            ema_callback.restore_original_weights(trainer.lightning_module)

    def _ema_format_filepath(self, filepath: str) -> str:
        return filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}")

    # identical to Lightning's implementation except for the final lines, which also remove the `-EMA` copy
    # of any checkpoint that falls out of the top-k set
    def _update_best_and_save(
        self,
        current: Tensor,
        trainer: "pl.Trainer",
        monitor_candidates: Dict[str, Tensor],
    ) -> None:
        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

        del_filepath = None
        if len(self.best_k_models) == k and k > 0:
            del_filepath = self.kth_best_model_path
            self.best_k_models.pop(del_filepath)

        # do not save nan, replace with +/- inf
        if isinstance(current, Tensor) and torch.isnan(current):
            current = torch.tensor(
                float("inf" if self.mode == "min" else "-inf"), device=current.device
            )

        filepath = self._get_metric_interpolated_filepath_name(
            monitor_candidates, trainer, del_filepath
        )

        # save the current score
        self.current_score = current
        self.best_k_models[filepath] = current

        if len(self.best_k_models) == k:
            # monitor dict has reached k elements
            _op = max if self.mode == "min" else min
            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
            self.kth_value = self.best_k_models[self.kth_best_model_path]

        _op = min if self.mode == "min" else max
        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
        self.best_model_score = self.best_k_models[self.best_model_path]

        if self.verbose:
            epoch = monitor_candidates["epoch"]
            step = monitor_candidates["step"]
            rank_zero_info(
                f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
                f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
            )
        self._save_checkpoint(trainer, filepath)

        if del_filepath is not None and filepath != del_filepath:
            self._remove_checkpoint(trainer, del_filepath)
            # also remove the EMA copy of the deleted checkpoint
            self._remove_checkpoint(
                trainer,
                del_filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}"),
            )
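

# -------------------------------------------------------------------------------------------------------------------------------------
# Example (illustrative sketch, not part of Bio-Diffusion's training entry point): how the two callbacks above can be
# wired into a `pl.Trainer`. The monitor key, checkpoint directory, and filename pattern are placeholder values.
# -------------------------------------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    ema_callback = EMA(
        decay=0.999,                        # smoothing factor for the moving average
        apply_ema_every_n_steps=1,          # update the EMA copy after every optimizer step
        start_step=0,                       # begin averaging from the first step
        evaluate_ema_weights_instead=True,  # run validation/testing with the EMA weights swapped in
    )
    checkpoint_callback = EMAModelCheckpoint(
        monitor="val/loss",                 # placeholder metric name
        mode="min",
        save_top_k=3,
        dirpath="checkpoints",              # placeholder output directory
        filename="model-{epoch:02d}-{step}",
    )
    trainer = pl.Trainer(callbacks=[ema_callback, checkpoint_callback])
    # trainer.fit(model, datamodule=datamodule)  # supply your own LightningModule and LightningDataModule here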