import contextlib
import copy
import itertools

import torch
from pytorch_lightning import Callback
from torch.distributed.fsdp import FullyShardedDataParallel


class EMACallback(Callback):
    """Keeps an exponential moving average (EMA) copy of a submodule of the LightningModule.

    The source module is looked up as ``getattr(pl_module, module_attr_name)`` and the
    EMA copy is registered under ``ema_module_attr_name``. The EMA weights are updated
    after every training batch once ``start_ema_step`` is reached.
    """

    def __init__(
        self,
        module_attr_name,
        ema_module_attr_name,
        decay=0.999,
        start_ema_step=0,
        init_ema_random=True,
    ):
        super().__init__()
        self.decay = decay
        self.module_attr_name = module_attr_name
        self.ema_module_attr_name = ema_module_attr_name
        self.start_ema_step = start_ema_step
        self.init_ema_random = init_ema_random

    def on_train_start(self, trainer, pl_module):
        if pl_module.global_step == 0:
            if not hasattr(pl_module, self.module_attr_name):
                raise ValueError(
                    f"Module {pl_module} does not have attribute {self.module_attr_name}"
                )
            if not hasattr(pl_module, self.ema_module_attr_name):
                # Register the EMA copy as a frozen, eval-mode deep copy of the source module.
                pl_module.add_module(
                    self.ema_module_attr_name,
                    copy.deepcopy(getattr(pl_module, self.module_attr_name))
                    .eval()
                    .requires_grad_(False),
                )
            self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if pl_module.global_step == self.start_ema_step:
            # From this step on, the EMA restarts from the current weights.
            self.reset_ema(pl_module)
        elif (
            pl_module.global_step < self.start_ema_step
            and pl_module.global_step % 100 == 0
        ):
            # Before the start step, loosely track the weights every 100 steps
            # with a faster decay so the EMA copy does not go stale.
            self.update_ema(pl_module, decay=0.9)
        elif pl_module.global_step > self.start_ema_step:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        module = getattr(pl_module, self.module_attr_name)
        context_manager = self.get_model_context_manager(module)
        with context_manager, torch.no_grad():
            ema_params = ema_module.state_dict()
            for name, param in itertools.chain(
                module.named_parameters(), module.named_buffers()
            ):
                # Buffers have requires_grad == False and are left untouched here.
                if name in ema_params and param.requires_grad:
                    # Standard EMA update: ema = decay * ema + (1 - decay) * param
                    ema_params[name].copy_(
                        ema_params[name].lerp(param.detach(), 1.0 - decay)
                    )

    def get_model_context_manager(self, module):
        # FSDP shards parameters across ranks, so they must be gathered (summoned)
        # before they can be read; otherwise a no-op context is used.
        fsdp_enabled = is_model_fsdp(module)
        model_context_manager = contextlib.nullcontext()
        if fsdp_enabled:
            model_context_manager = FullyShardedDataParallel.summon_full_params(module)
        return model_context_manager

    def reset_ema(self, pl_module):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            # Re-initialise the EMA copy from scratch; the wrapped module is
            # expected to provide an ``init_weights()`` method.
            ema_module.init_weights()
        else:
            # Otherwise copy the current weights of the source module into the EMA copy.
            module = getattr(pl_module, self.module_attr_name)
            context_manager = self.get_model_context_manager(module)
            with context_manager:
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        ema_params[name].copy_(param.detach())


def is_model_fsdp(model: torch.nn.Module) -> bool:
    """Return True if ``model`` or any of its direct children is wrapped in FSDP."""
    if isinstance(model, FullyShardedDataParallel):
        return True
    for _, obj in model.named_children():
        if isinstance(obj, FullyShardedDataParallel):
            return True
    return False
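

# Illustrative usage sketch (not part of the callback itself): assumes a hypothetical
# LightningModule that exposes its network under ``self.model``, whose wrapped module
# implements ``init_weights()``, and ``import pytorch_lightning as pl``.
#
#     trainer = pl.Trainer(
#         callbacks=[
#             EMACallback(
#                 module_attr_name="model",
#                 ema_module_attr_name="ema_model",
#                 decay=0.999,
#                 start_ema_step=1000,
#             )
#         ]
#     )
#     trainer.fit(lit_module, train_dataloaders=train_loader)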