from torch import nn | |
def convert_to_buffer(module: nn.Module, persistent: bool = True) -> None:
    """Convert every parameter of ``module`` (recursively) into a buffer.

    The conversion happens in place. Each registered parameter is replaced by
    a detached copy registered as a buffer. It therefore no longer appears in
    ``parameters()`` and does not require gradients. Existing buffers are
    re-registered as well, so that they all end up with the requested
    persistence.

    Args:
        module: Module to modify in place. Child modules are processed before
            their parent.
        persistent: Whether the resulting buffers are included in
            ``state_dict()``.
    """
    # Recurse over child modules first.
    for child in module.children():
        convert_to_buffer(child, persistent)

    # Read this module's own registries rather than calling
    # ``named_parameters`` / ``named_buffers``. Those helpers skip a tensor
    # already yielded under another name, which would leave an aliased
    # attribute (``self.b = self.a``) behind as a Parameter. The snapshot is
    # taken up front because the loop below mutates both registries.
    entries = [
        *module._parameters.items(),
        *module._buffers.items(),
    ]
    for name, tensor in entries:
        if tensor is None:
            # Unset slots (e.g. a ``bias`` registered as None) are left as-is.
            continue
        # Detach from autograd and copy, so the new buffer does not share
        # storage with the original tensor or with other references to it.
        value = tensor.detach().clone()
        delattr(module, name)
        module.register_buffer(name, value, persistent=persistent)