wujun commited on
Commit
8e64547
·
1 Parent(s): 223265d
Files changed (3) hide show
  1. backslash.py +11 -4
  2. data/gamma_table.pt +3 -0
  3. data/r_gamma_table.pt +3 -0
backslash.py CHANGED
@@ -1,4 +1,11 @@
1
  import torch
 
 
 
 
 
 
 
2
  def backslash(model):
3
  with torch.no_grad():
4
  device = torch.device("cuda:0")
@@ -11,16 +18,16 @@ def backslash(model):
11
  var += torch.sum((param ** 2).to(device))
12
  mean += torch.sum(torch.abs(param).to(device))
13
  r_gamma = (n * var / mean ** 2).to(device=torch.device("cpu"))
14
- pos = torch.argmin(torch.abs(r_gamma - model.r_gamma_table))
15
- shape = model.gamma_table[pos]
16
  std = torch.sqrt(var / n)
17
  n = torch.tensor(n)
18
 
19
  # Rate Constrained Optimization
20
  for param in model.parameters():
21
- constant = model.rdo * shape / n * torch.sign(param.data)
22
  param_reg = torch.pow(
23
- torch.abs(param.data) + model.clip, shape - 1)
24
  param.data -= constant * param_reg
25
  distribution = {"shape": shape, "standard": std}
26
  return distribution
 
1
  import torch
2
+
3
+ rdo = 2e3
4
+ clip = 1
5
+ gamma_table = torch.load("data/gamma_table.pt")
6
+ r_gamma_table = torch.load("data/r_gamma_table.pt")
7
+
8
+
9
  def backslash(model):
10
  with torch.no_grad():
11
  device = torch.device("cuda:0")
 
18
  var += torch.sum((param ** 2).to(device))
19
  mean += torch.sum(torch.abs(param).to(device))
20
  r_gamma = (n * var / mean ** 2).to(device=torch.device("cpu"))
21
+ pos = torch.argmin(torch.abs(r_gamma - r_gamma_table))
22
+ shape = gamma_table[pos]
23
  std = torch.sqrt(var / n)
24
  n = torch.tensor(n)
25
 
26
  # Rate Constrained Optimization
27
  for param in model.parameters():
28
+ constant = rdo * shape / n * torch.sign(param.data)
29
  param_reg = torch.pow(
30
+ torch.abs(param.data) + clip, shape - 1)
31
  param.data -= constant * param_reg
32
  distribution = {"shape": shape, "standard": std}
33
  return distribution
data/gamma_table.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4525076935eaf656655b070bdbb390bf3b48088991b39894024bfca5a2c0225
3
+ size 2352
data/r_gamma_table.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:086241d0f92a1570315b68613c6de94eec3e048bc9fef8e2f3f83ac1180bb718
3
+ size 2362