File size: 565 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from torch import nn
def convert_to_buffer(module: nn.Module, persistent: bool = True) -> None:
    """Turn every parameter and buffer of ``module`` (recursively) into a buffer.

    Each parameter/buffer of ``module`` and of all its submodules is replaced,
    in place, by a buffer holding a detached clone of the original tensor. The
    new buffers are plain tensors (``requires_grad=False``), so the module no
    longer exposes any parameters. Existing buffers are re-registered too,
    which lets this function also change their persistence.

    A tensor registered under several names or in several submodules (e.g.
    tied weights) is cloned once, and that single clone is registered under
    every name, so the sharing survives the conversion.

    Args:
        module: Root module to convert; modified in place.
        persistent: Whether the resulting buffers are part of
            ``state_dict()``.
    """
    # id(original tensor) -> (original, clone). The original is kept in the
    # tuple so its id cannot be recycled by a new object while we still look
    # tensors up by id.
    converted = {}
    # ``modules()`` yields ``module`` itself plus every descendant, visiting a
    # submodule that is reachable through several parents only once.
    for submodule in module.modules():
        # Read the raw registries instead of ``named_parameters()`` /
        # ``named_buffers()``: those drop duplicate tensors by default, which
        # would leave a second name aliasing the same Parameter unconverted.
        # Snapshot them because the loop below mutates both dicts.
        entries = [*submodule._parameters.items(), *submodule._buffers.items()]
        for name, tensor in entries:
            if tensor is None:
                # Slots registered as None are left as they are.
                continue
            key = id(tensor)
            if key not in converted:
                converted[key] = (tensor, tensor.detach().clone())
            replacement = converted[key][1]
            # Remove the old entry first; register_buffer refuses names that
            # already exist as attributes.
            delattr(submodule, name)
            submodule.register_buffer(name, replacement, persistent=persistent)
|