Spaces:
Runtime error
Runtime error
| import torch | |
| from collections import OrderedDict | |
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    For every named parameter of ``model``, the same-named parameter of
    ``ema_model`` is updated in place as::

        ema = decay * ema + (1 - decay) * param

    The update runs under ``torch.no_grad()``. This keeps autograd from
    recording it. It also keeps autograd from rejecting the in-place op when
    the EMA parameters still have ``requires_grad=True``.

    Args:
        ema_model: object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``) that holds the moving-average weights;
            its parameters are modified in place.
        model: object exposing ``named_parameters()`` whose current
            parameter values are blended into ``ema_model``.
        decay: fraction of the previous EMA value kept at each step.

    Raises:
        KeyError: if ``model`` has a parameter name that ``ema_model`` lacks.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid
        # small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def requires_grad(model, flag=True):
    """
    Toggle gradient tracking for every parameter of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: value assigned to each parameter's ``requires_grad`` attribute.
    """
    for tensor in model.parameters():
        tensor.requires_grad = flag