from torch import nn


def convert_to_buffer(module: nn.Module, persistent: bool = True) -> None:
    """Re-register every parameter and buffer in ``module``'s tree as a buffer.

    Each tensor is replaced by a detached copy registered with
    ``register_buffer(..., persistent=persistent)``. Existing buffers are
    re-registered too, which lets this function change their persistence.
    The conversion happens in place on ``module`` and all of its submodules.

    A tensor reachable under several names becomes a single shared buffer
    under all of those names. This covers tied weights across modules and one
    tensor assigned to two attributes of the same module. Tensors that were
    not shared each get their own copy.

    Args:
        module: Root module to convert in place.
        persistent: Whether the resulting buffers are included in
            ``state_dict()``.
    """
    # Group every registered tensor by identity before mutating anything.
    # All originals are alive at this point, so their ids are unique.
    # ``module.modules()`` yields each distinct submodule once, so a
    # submodule shared by several parents is processed a single time.
    groups = {}  # id(tensor) -> (tensor, [(owner_module, attribute_name), ...])
    for submodule in module.modules():
        # Read the registries directly. The public ``named_parameters`` /
        # ``named_buffers`` skip names that alias an already-seen tensor,
        # which would leave those aliases unconverted.
        for registry in (submodule._parameters, submodule._buffers):
            for name, tensor in registry.items():
                if tensor is None:  # placeholder entries (e.g. ``bias=False``)
                    continue
                _, owners = groups.setdefault(id(tensor), (tensor, []))
                owners.append((submodule, name))

    # Process groups in insertion (FIFO) order so each module's buffers end up
    # in the same relative order as before. Popping each group lets the
    # original tensor be released once all of its names have been replaced.
    for key in list(groups):
        original, owners = groups.pop(key)
        value = original.detach().clone()
        del original
        for owner, name in owners:
            # delattr removes the name from _parameters/_buffers so that
            # register_buffer does not reject it as an existing attribute.
            delattr(owner, name)
            owner.register_buffer(name, value, persistent=persistent)