AnySplat / src /misc /nn_module_tools.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame contribute delete
565 Bytes
from torch import nn
def convert_to_buffer(module: nn.Module, persistent: bool = True) -> None:
    """Convert every parameter and buffer in ``module``'s tree into a buffer, in place.

    Child modules are processed first (recursively), then ``module`` itself.
    For each module, every tensor it registers directly -- parameter or
    buffer -- is detached, cloned, deleted, and re-registered under the same
    name via ``register_buffer(..., persistent=persistent)``. Afterwards the
    tree exposes no parameters (nothing is returned by ``module.parameters()``),
    and pre-existing buffers have had their persistence reset to ``persistent``.

    A tensor registered under several names on the same module is converted
    under every one of those names; each name receives its own clone.

    Args:
        module: Root module to convert; modified in place.
        persistent: Forwarded to ``register_buffer``. Non-persistent buffers
            are left out of ``state_dict()``.
    """
    for child in module.children():
        convert_to_buffer(child, persistent)

    # Snapshot before mutating: ``delattr`` edits the very dicts these
    # iterators walk. ``remove_duplicate=False`` (PyTorch 2.0+) makes an
    # aliased tensor (one tensor under two names) appear under every name;
    # with the default only the first name is yielded, so the remaining
    # aliases would silently stay parameters / keep their old persistence.
    entries = [
        *module.named_parameters(recurse=False, remove_duplicate=False),
        *module.named_buffers(recurse=False, remove_duplicate=False),
    ]
    for name, tensor in entries:
        # detach() drops autograd tracking; clone() gives the buffer its own
        # storage independent of the tensor being removed.
        value = tensor.detach().clone()
        delattr(module, name)
        module.register_buffer(name, value, persistent=persistent)