# gg_prior/deepshape.py
import math

import torch
import numpy as np  # used only by the optional dask path below
import dask.array as da  # optional: out-of-core unique() for very large models
from scipy.special import gamma as Gamma

from utils.encode.quantizer import LinearQuantizer
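# A minimal sketch, an assumption rather than code from this repo, of how the
# lookup tables loaded in DeepShape.__init__ could be generated.  For a
# zero-mean GGD with shape s, the moment ratio
#     rho(s) = E[x^2] / E[|x|]^2 = Gamma(1/s) * Gamma(3/s) / Gamma(2/s)^2
# is monotone in s, so tabulating (s, rho(s)) densely lets the shape be
# recovered by nearest-neighbour search.  `build_ggd_tables` is hypothetical.
def build_ggd_tables(s_min=0.1, s_max=3.0, steps=10000):
    gamma_table = torch.linspace(s_min, s_max, steps)  # candidate shape values
    s = gamma_table.double().numpy()
    rho_table = Gamma(1.0 / s) * Gamma(3.0 / s) / Gamma(2.0 / s) ** 2
    return gamma_table, torch.from_numpy(rho_table)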
class DeepShape:
    """Fit a generalized Gaussian distribution (GGD) to model weights and
    remap them toward a new GGD via precomputed lookup tables."""

    def __init__(self):
        # Lookup tables mapping the moment ratio rho to the GGD shape
        # parameter (see Calc_GG_params)
        self.gamma_table = torch.load('utils/gamma_table.pt')
        self.rho_table = torch.load('utils/rho_table.pt')

    def Calc_GG_params(self, model, adj_minnum=0):
        """Estimate GGD parameters (mu, shape, beta) of a model's weights."""
        # Gather all weights into one flat vector
        params = []
        for param in model.parameters():
            params.append(param.flatten())
        params = torch.cat(params).detach()
        params_org = params.clone()
        # Quantize to 13-bit levels; the /(2**13) divisions below undo this
        lq = LinearQuantizer(params, 13)
        params = lq.quant(params)
        # Histogram of quantized levels, sorted by frequency (descending)
        elements, counts = torch.unique(params, return_counts=True)
        # Out-of-core alternative via dask if the parameter tensor is too
        # large for memory:
        # dask_params = da.from_array(params.numpy(), chunks=int(1e8))
        # elements, counts = da.unique(dask_params, return_counts=True)
        # elements = torch.from_numpy(elements.compute())
        # counts = torch.from_numpy(counts.compute())
        indices = torch.argsort(counts, descending=True)
        elements = elements[indices]
        counts = counts[indices]
        if adj_minnum > 0:
            # Truncate the tails: keep only weights whose quantized level
            # occurs more than adj_minnum times
            param_max = torch.min(elements[(counts <= adj_minnum) & (elements > 0)]).long()
            elements_cut = params_org[torch.abs(params_org) <= (param_max.float() / (2 ** 13))]
        else:
            elements_cut = params_org
        # Moment-matching estimation: rho = n * sum(x^2) / sum(|x|)^2
        # approximates E[x^2] / E[|x|]^2, which for a zero-mean GGD is a
        # monotone function of the shape; invert it via the lookup tables
        n = len(elements_cut)
        sum_sq = torch.sum(torch.pow(elements_cut, 2))
        sum_abs = torch.sum(torch.abs(elements_cut))
        self.gamma_table = self.gamma_table.to(elements_cut.device)
        self.rho_table = self.rho_table.to(elements_cut.device)
        rho = n * sum_sq / sum_abs ** 2
        pos = torch.argmin(torch.abs(rho - self.rho_table)).item()
        shape = self.gamma_table[pos].item()
        # Scale parameter: beta = std * sqrt(Gamma(1/shape) / Gamma(3/shape))
        std = torch.sqrt(sum_sq / n)
        beta = math.sqrt(Gamma(1 / shape) / Gamma(3 / shape)) * std
        mu = torch.mean(elements_cut)
        print("mu:", mu)
        print("shape:", shape)
        print("beta:", beta)
        return mu, shape, beta
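    # A minimal sketch, not in the original file: the GGD density whose
    # exp(-(|x - mu| / beta)^shape) kernel the remap below reproduces.
    # `ggd_pdf` is a hypothetical helper added for illustration only.
    @staticmethod
    def ggd_pdf(x, mu, shape, beta):
        # f(x) = shape / (2 * beta * Gamma(1/shape)) * exp(-(|x - mu| / beta)^shape)
        norm = shape / (2.0 * beta * Gamma(1.0 / shape))
        return norm * torch.exp(-torch.pow(torch.abs(x - mu) / beta, shape))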
"""GGD deepshape remap"""
def GGD_deepshape(self, model, shape_scale=0.8, std_scale=0.6, adj_minnum = 1000):
        # Gather all weights into one flat vector
        params = []
        for param in model.parameters():
            params.append(param.flatten())
        params = torch.cat(params).detach()
        params_org = params.clone()
        # Quantize to 13-bit levels
        lq = LinearQuantizer(params, 13)
        params = lq.quant(params)
        # Histogram of quantized levels, sorted by frequency (descending)
        elements, counts = torch.unique(params, return_counts=True)
        indices = torch.argsort(counts, descending=True)
        elements = elements[indices]
        counts = counts[indices]
        if adj_minnum > 0:
            # Truncate the tails, as in Calc_GG_params
            param_max = int(torch.min(elements[(counts <= adj_minnum) & (elements > 0)]))
            elements_cut = params_org[torch.abs(params_org) <= (param_max / (2 ** 13))]
        else:
            elements_cut = params_org
            param_max = 0
        # Estimate the original GGD (same moment matching as Calc_GG_params)
        n = len(elements_cut)
        sum_sq = torch.sum(torch.pow(elements_cut, 2))
        sum_abs = torch.sum(torch.abs(elements_cut))
        self.gamma_table = self.gamma_table.to(elements_cut.device)
        self.rho_table = self.rho_table.to(elements_cut.device)
        rho = n * sum_sq / sum_abs ** 2
        pos = torch.argmin(torch.abs(rho - self.rho_table)).item()
        shape = self.gamma_table[pos].item()
        std = torch.sqrt(sum_sq / n)
        beta = math.sqrt(Gamma(1 / shape) / Gamma(3 / shape)) * std
        mu_est = torch.mean(elements_cut)
        print("org mu:", mu_est)
        print("org shape:", shape)
        print("org beta:", beta)
        # Move mu and beta into the 13-bit quantized integer domain
        beta = beta * (2 ** 13)
        mu_est = int(mu_est * (2 ** 13))
        # Indices of the parameters to remap (those within
        # [mu_est - param_max, mu_est + param_max]), sorted by value ascending
        if adj_minnum > 0:
            in_range = (params >= mu_est - param_max) & (params <= mu_est + param_max)
            adj_indices = torch.nonzero(in_range, as_tuple=False).squeeze()
            adj_indices = adj_indices[torch.argsort(params[in_range], descending=False)]
        else:
            adj_indices = torch.argsort(params, descending=False)
        adj_num = len(adj_indices)
        # Remap onto the new GGD: compute target bin proportions from the
        # GGD kernel over the integer levels, then refill the value-sorted
        # indices bin by bin (a rank-preserving histogram specification)
        new_params = params.clone()
        new_shape = shape * shape_scale
        new_beta = beta * std_scale
        if new_beta <= 0:  # guard the divisor below
            new_beta = 1
        x = torch.arange(mu_est - param_max, mu_est + param_max + 1, device=params.device)
        new_ratio = torch.exp(-torch.pow(torch.abs(x.float() - mu_est) / new_beta, new_shape))
        new_ratio = new_ratio / torch.sum(new_ratio)
        # .long() truncates, so sum(new_num) <= adj_num; leftover indices at
        # the top of the ordering keep their original values
        new_num = (adj_num * new_ratio).long()
        num_temp = 0
        for i in range(0, 2 * param_max + 1):
            new_params[adj_indices[num_temp: num_temp + new_num[i]]] = i + mu_est - param_max
            num_temp += new_num[i]
        # Back from quantized levels to float weights
        new_params = new_params.float() / (2 ** 13)
        # Write the remapped weights back into the model, preserving shapes
        j = 0
        for param in model.parameters():
            numel = param.data.numel()
            param.data = new_params[j: j + numel].reshape(param.data.shape)
            j += numel
        print("new mu:", float(mu_est) / (2 ** 13))
        print("new shape:", new_shape)
        print("new beta:", float(new_beta) / (2 ** 13))
        return float(mu_est) / (2 ** 13), new_shape, float(new_beta) / (2 ** 13)
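# A minimal usage sketch, assuming utils/gamma_table.pt, utils/rho_table.pt,
# and LinearQuantizer are importable from the working directory; the toy
# model below is illustrative, not from the original repo.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
    ds = DeepShape()
    mu, shape, beta = ds.Calc_GG_params(model, adj_minnum=1000)  # fit the GGD
    ds.GGD_deepshape(model, shape_scale=0.8, std_scale=0.6)  # remap in place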