Spaces:
Running
Running
import numpy as np
import torch
from scipy.special import gamma as Gamma
from scipy.stats import gennorm
from torch.nn import init
def gg_init(model, shape=2, xi=2, random_state=None):
    """Re-initialize a model's parameters with generalized Gaussian weights.

    Every 2-D parameter is overwritten with i.i.d. zero-mean samples from a
    generalized normal distribution (``scipy.stats.gennorm``).  The scale
    ``alpha`` is chosen so that each entry has variance ``xi / n_dim``, where
    ``n_dim`` is the size of the parameter's first dimension.

    Every other (non-2-D) parameter is set to ones when its name contains
    ``"weight"``, to zeros when its name contains ``"bias"``, and is left
    untouched otherwise.  All updates are done in place, so each parameter
    keeps its dtype, device and underlying storage.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs (e.g. a ``torch.nn.Module``).
        shape: Shape parameter (beta) of the generalized normal distribution;
            must be positive.  ``2`` gives a Gaussian, ``1`` a Laplace
            distribution.
        xi: Variance gain.  Use ``1`` for sigmoid or no activation, ``2`` for
            ReLU, and ``2 / (1 + k**2)`` for LeakyReLU (``k`` presumably being
            the negative slope -- confirm against the reference derivation).
        random_state: Seed or generator forwarded to ``gennorm.rvs``.  ``None``
            (the default) uses NumPy's global random state.  An integer seed
            is converted once into a single generator shared by all layers.

    Raises:
        ValueError: If ``shape`` is not positive.
    """
    if shape <= 0:
        raise ValueError(f"shape must be positive, got {shape!r}")

    # Build ONE generator from an integer seed.  Handing the bare int to every
    # ``rvs`` call would restart the stream each time and give same-shaped
    # layers identical weights.
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    # Loop-invariant pieces of the scale formula.
    gamma_one = Gamma(1 / shape)
    gamma_three = Gamma(3 / shape)

    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.numel() == 0:
                # Nothing to fill; also avoids dividing by a zero dimension.
                continue

            if param.dim() == 2:
                # NOTE(review): for a ``torch.nn.Linear`` weight, laid out as
                # (out_features, in_features), dimension 0 is the output width
                # rather than the fan-in -- confirm this is the intended
                # normalizer.
                n_dim = param.shape[0]
                # Var[gennorm] = alpha**2 * Gamma(3/shape) / Gamma(1/shape),
                # so this alpha yields a per-entry variance of xi / n_dim.
                alpha = np.sqrt(xi / n_dim * gamma_one / gamma_three)
                samples = gennorm.rvs(
                    shape,
                    loc=0,
                    scale=alpha,
                    size=tuple(param.shape),
                    random_state=random_state,
                )
                # copy_ casts the float64 samples to the parameter's own dtype
                # and device while keeping its storage.
                param.copy_(torch.from_numpy(samples))
            elif "weight" in name:
                # NOTE(review): parameters with more than two dimensions (e.g.
                # convolution kernels) whose name contains "weight" also land
                # here and become all ones -- confirm that is intended.
                param.fill_(1)
            elif "bias" in name:
                param.zero_()