Spaces:
Running
Running
wujun
commited on
Commit
·
223265d
1
Parent(s):
30b1d24
Code Refactoring - Initial Version
Browse files- README.md +40 -13
- backslash.py +26 -0
- deepshape.py +152 -0
- gg_init.py +30 -0
- index.html +0 -435
- rf8.py +20 -0
README.md
CHANGED
@@ -9,19 +9,46 @@ license: cc-by-nc-sa-4.0
|
|
9 |
short_description: The code of gg prior.
|
10 |
---
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
# Website License
|
27 |
<a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/">Creative Commons Attribution-ShareAlike 4.0 International License</a>.
|
|
|
9 |
short_description: The code of gg prior.
|
10 |
---
|
11 |
|
12 |
+
# Introduction
|
13 |
+
- Reference: It Takes a Good Model to Train a Good Model: Generalized Gaussian Priors for Optimized LLMs
|
14 |
+
- Authors: Jun Wu, Yirong Xiong, Jiangtao Wen, Yuxing Han
|
15 |
+
- Paper Link: [https://services.arxiv.org/html/submission/6499264/view](https://services.arxiv.org/html/submission/6499264/view)
|
16 |
+
|
17 |
+
This repository provides a complete implementation of the methods described in the corresponding paper. Specifically, we implement the Generalized Gaussian Initialization, DeepShape, and the RF8 floating-point format as proposed in the paper. Furthermore, we adapt and reproduce the BackSlash training algorithm, and incorporate it seamlessly into our framework based on generalized Gaussian priors.
|
18 |
+
|
19 |
+
|
20 |
+
# BackSlash
|
21 |
+
- Reference: BackSlash: Rate Constrained Optimized Training of Large Language Models
|
22 |
+
- Authors: Jun Wu, Jiangtao Wen, Yuxing Han
|
23 |
+
- Paper Link: [https://arxiv.org/abs/2504.16968](https://arxiv.org/abs/2504.16968)
|
24 |
+
|
25 |
+
We reproduced the BackSlash training algorithm based on the algorithm diagram provided in the source paper, and assisted us in conducting more in-depth research on generalized Gaussian priors.
|
26 |
+
|
27 |
+
In BackSlash, estimating the shape parameters of the model parameter distribution requires looking up the mapping between $\rho(\nu)$ and $\nu$. To achieve this, we precompute the values of $\nu$ and $\rho(\nu)$ over the interval $[0.1,\, 3.0]$ with a step size of $0.01$, and store them in `data/gamma_table.pt` (for $\nu$) and `data/r\_gamma\_table.pt` (for $\rho(\nu)$), respectively.
|
28 |
+
|
29 |
+
The code for reproducing the shape parameter estimation and the BackSlash algorithm is stored in the `backslash.py` module. During model training, after each batch iteration, the BackSlash function is invoked to perform rate suppression on the model parameters. After a few epochs of BackSlash training, we further fine-tune the model using several epochs of standard training. This helps the model achieve significantly improved performance while maintaining a low bit rate. The same procedure is applied consistently across all experiments.
|
30 |
+
|
31 |
+
# Generalized Gaussian Initialization
|
32 |
+
|
33 |
+
The Generalized Gaussian Initialization algorithm is implemented in the file `gg_init.py`. In practice, this initialization function is applied to all linear layers of the model prior to the start of training.
|
34 |
+
|
35 |
+
The generalized Gaussian initialization takes two parameters: `shape` and `xi`.
|
36 |
+
|
37 |
+
- `shape` represents the user-specified shape parameter for generalized Gaussian initialization, which affects the parameter distribution during initialization.
|
38 |
+
- `xi` represents the user-specified activation function coefficient, whose value is determined by the specific type of activation function.
|
39 |
+
|
40 |
+
# DeepShape
|
41 |
+
|
42 |
+
The implementation of DeepShape is preserved in the file `deepshape.py`. We apply the classical image processing technique "histogram equalization" to adjust the parameter distribution of post-trained models.
|
43 |
+
|
44 |
+
While DeepShape effectively compresses parameter bitrate, it inevitably impacts model prediction accuracy. This effect on model performance can be mitigated through a limited number of post-training epochs.
|
45 |
+
|
46 |
+
# 8-bit Residual Floating-Point Format
|
47 |
+
|
48 |
+
The implementation method of the RF8 floating-point format is documented in the file `rf8.py`. In this paper, we only quantize the model parameters to RF8 format before inference, without considering the training scenario or quantizing the activation values during forward propagation in inference.
|
49 |
+
|
50 |
+
In RF8 format, all model parameters will be preserved as only the sum of the first two significant digits in their binary representation. The multiplicative relationship between these two significant digits does not exceed $2^{4}$.
|
51 |
+
|
52 |
|
53 |
# Website License
|
54 |
<a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/">Creative Commons Attribution-ShareAlike 4.0 International License</a>.
|
backslash.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
def backslash(model):
|
3 |
+
with torch.no_grad():
|
4 |
+
device = torch.device("cuda:0")
|
5 |
+
|
6 |
+
# Evaluate the shape parameter
|
7 |
+
n, var, mean = 0, 0, 0
|
8 |
+
for param in model.parameters():
|
9 |
+
param = param.flatten().detach()
|
10 |
+
n += param.shape[0]
|
11 |
+
var += torch.sum((param ** 2).to(device))
|
12 |
+
mean += torch.sum(torch.abs(param).to(device))
|
13 |
+
r_gamma = (n * var / mean ** 2).to(device=torch.device("cpu"))
|
14 |
+
pos = torch.argmin(torch.abs(r_gamma - model.r_gamma_table))
|
15 |
+
shape = model.gamma_table[pos]
|
16 |
+
std = torch.sqrt(var / n)
|
17 |
+
n = torch.tensor(n)
|
18 |
+
|
19 |
+
# Rate Constrained Optimization
|
20 |
+
for param in model.parameters():
|
21 |
+
constant = model.rdo * shape / n * torch.sign(param.data)
|
22 |
+
param_reg = torch.pow(
|
23 |
+
torch.abs(param.data) + model.clip, shape - 1)
|
24 |
+
param.data -= constant * param_reg
|
25 |
+
distribution = {"shape": shape, "standard": std}
|
26 |
+
return distribution
|
deepshape.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from utils.encode.quantizer import LinearQuantizer
|
3 |
+
import math
|
4 |
+
from scipy.special import gamma as Gamma
|
5 |
+
import numpy as np
|
6 |
+
import dask.array as da
|
7 |
+
|
8 |
+
|
9 |
+
class DeepShape:
|
10 |
+
def __init__(self):
|
11 |
+
self.gamma_table = torch.load('utils/gamma_table.pt')
|
12 |
+
self.rho_table = torch.load('utils/rho_table.pt')
|
13 |
+
|
14 |
+
"""estimate GGD parameters"""
|
15 |
+
def Calc_GG_params(self, model, adj_minnum = 0):
|
16 |
+
#get parameters
|
17 |
+
params = []
|
18 |
+
for param in model.parameters():
|
19 |
+
params.append(param.flatten())
|
20 |
+
params = torch.cat(params).detach()
|
21 |
+
params_org = params.clone()
|
22 |
+
|
23 |
+
# Quantization
|
24 |
+
lq = LinearQuantizer(params, 13)
|
25 |
+
params = lq.quant(params)
|
26 |
+
|
27 |
+
#sorting
|
28 |
+
elements, counts = torch.unique(params, return_counts=True)
|
29 |
+
# dask_params = da.from_array(params.numpy(), chunks=int(1e8)) #if param's size is big
|
30 |
+
# elements, counts = da.unique(dask_params, return_counts=True)
|
31 |
+
# elements = torch.from_numpy(elements.compute())
|
32 |
+
# counts = torch.from_numpy(counts.compute())
|
33 |
+
indices = torch.argsort(counts, descending=True)
|
34 |
+
elements = elements[indices]
|
35 |
+
counts = counts[indices]
|
36 |
+
|
37 |
+
if adj_minnum > 0:
|
38 |
+
param_max = torch.min(elements[(counts<=adj_minnum) & (elements>0)]).long()
|
39 |
+
# print("param_max", (param_max/(2**13)))
|
40 |
+
# print('max_param, num_max_param', (elements[0]/(2**13)), counts[0])
|
41 |
+
elements_cut = params_org[torch.abs(params_org)<=(param_max.float()/(2**13))]
|
42 |
+
else:
|
43 |
+
elements_cut = params_org
|
44 |
+
|
45 |
+
#estimate
|
46 |
+
n = len(elements_cut)
|
47 |
+
var = torch.sum(torch.pow(elements_cut, 2))
|
48 |
+
mean = torch.sum(torch.abs(elements_cut))
|
49 |
+
self.gamma_table = self.gamma_table.to(elements_cut.device)
|
50 |
+
self.rho_table = self.rho_table.to(elements_cut.device)
|
51 |
+
rho = n * var / mean ** 2
|
52 |
+
pos = torch.argmin(torch.abs(rho - self.rho_table)).item()
|
53 |
+
shape = self.gamma_table[pos].item()
|
54 |
+
std = torch.sqrt(var / n)
|
55 |
+
beta = math.sqrt(Gamma(1/shape) / Gamma(3/shape))* std
|
56 |
+
mu = torch.mean(elements_cut)
|
57 |
+
print("mu:", mu)
|
58 |
+
print('shape:', shape)
|
59 |
+
print('beta',(beta))
|
60 |
+
|
61 |
+
return mu, shape, beta
|
62 |
+
|
63 |
+
|
64 |
+
"""GGD deepshape remap"""
|
65 |
+
def GGD_deepshape(self, model, shape_scale=0.8, std_scale=0.6, adj_minnum = 1000):
|
66 |
+
#get parameters
|
67 |
+
params = []
|
68 |
+
for param in model.parameters():
|
69 |
+
params.append(param.flatten())
|
70 |
+
params = torch.cat(params).detach()
|
71 |
+
params_org = params.clone()
|
72 |
+
|
73 |
+
# Quantization
|
74 |
+
lq = LinearQuantizer(params, 13)
|
75 |
+
params = lq.quant(params)
|
76 |
+
|
77 |
+
#sorting
|
78 |
+
elements, counts = torch.unique(params, return_counts=True)
|
79 |
+
indices = torch.argsort(counts, descending=True)
|
80 |
+
elements = elements[indices]
|
81 |
+
counts = counts[indices]
|
82 |
+
|
83 |
+
if adj_minnum > 0:
|
84 |
+
param_max = torch.min(elements[(counts<=adj_minnum) & (elements>0)]).long()
|
85 |
+
elements_cut = params_org[torch.abs(params_org)<=(param_max.float()/(2**13))]
|
86 |
+
else:
|
87 |
+
elements_cut = params_org
|
88 |
+
param_max=0
|
89 |
+
|
90 |
+
#estimate org GGD
|
91 |
+
n = len(elements_cut)
|
92 |
+
var = torch.sum(torch.pow(elements_cut, 2))
|
93 |
+
mean = torch.sum(torch.abs(elements_cut))
|
94 |
+
self.gamma_table = self.gamma_table.to(elements_cut.device)
|
95 |
+
self.rho_table = self.rho_table.to(elements_cut.device)
|
96 |
+
rho = n * var / mean ** 2
|
97 |
+
pos = torch.argmin(torch.abs(rho - self.rho_table)).item()
|
98 |
+
shape = self.gamma_table[pos].item()
|
99 |
+
std = torch.sqrt(var / n)
|
100 |
+
beta = math.sqrt(Gamma(1/shape) / Gamma(3/shape))* std
|
101 |
+
mu_est = torch.mean(elements_cut)
|
102 |
+
|
103 |
+
print("org mu:", mu_est)
|
104 |
+
print('org shape:', shape)
|
105 |
+
print('org beta',beta)
|
106 |
+
|
107 |
+
beta = (beta * (2**13))
|
108 |
+
mu_est = int(mu_est*(2**13))
|
109 |
+
|
110 |
+
#sorting params in [-param_pax, param_max]
|
111 |
+
if adj_minnum>0:
|
112 |
+
adj_indices = torch.nonzero((params>=mu_est-param_max)&(params<=mu_est+param_max), as_tuple=False).squeeze()
|
113 |
+
adj_indices = adj_indices[torch.argsort(params[(params>=mu_est-param_max)&(params<=mu_est+param_max)], descending=False)]
|
114 |
+
adj_num = len(adj_indices)
|
115 |
+
else:
|
116 |
+
adj_indices = torch.argsort(params, descending=False)
|
117 |
+
adj_num = len(adj_indices)
|
118 |
+
|
119 |
+
#remape new GGD
|
120 |
+
new_params = params.clone()
|
121 |
+
new_shape = shape * shape_scale
|
122 |
+
new_beta = beta * std_scale
|
123 |
+
if(beta<=0):
|
124 |
+
beta=1
|
125 |
+
|
126 |
+
x = torch.arange(mu_est-param_max, mu_est+param_max+1, device=params.device)
|
127 |
+
new_ratio = -torch.pow(torch.abs(x.float()-mu_est)/new_beta, new_shape)
|
128 |
+
new_ratio = torch.exp(new_ratio)
|
129 |
+
new_ratio = new_ratio / torch.sum(new_ratio)
|
130 |
+
new_num = (adj_num * new_ratio).long()
|
131 |
+
num_temp = 0
|
132 |
+
for i in range(0, 2*param_max+1):
|
133 |
+
new_params[adj_indices[num_temp : num_temp+new_num[i]]]=i+mu_est-param_max
|
134 |
+
num_temp += new_num[i]
|
135 |
+
new_params=new_params.float()/(2**13)
|
136 |
+
|
137 |
+
#modify model parameters
|
138 |
+
j=0
|
139 |
+
for name, param in model.named_parameters():
|
140 |
+
shape=param.data.shape
|
141 |
+
param_flatten = torch.flatten(param.data)
|
142 |
+
param_flatten = new_params[j: j+len(param_flatten)]
|
143 |
+
j+=len(param_flatten)
|
144 |
+
param_flatten = param_flatten.reshape(shape)
|
145 |
+
param.data= param_flatten
|
146 |
+
|
147 |
+
print("new mu:", float(mu_est)/(2**13))
|
148 |
+
print('new_shape:', new_shape)
|
149 |
+
print('new beta', float(new_beta)/(2**13))
|
150 |
+
return float(mu_est)/(2**13), new_shape, float(new_beta)/(2**13)
|
151 |
+
|
152 |
+
|
gg_init.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import init
|
3 |
+
from scipy.special import gamma as Gamma
|
4 |
+
from scipy.stats import gennorm
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def gg_init(model, shape=2, xi=2):
|
9 |
+
"""Generalized Gaussian Initialization for ReLU"""
|
10 |
+
# shape for the shape of parameter distribution
|
11 |
+
# xi = 1 for Sigmoid or no activation
|
12 |
+
# xi = 2 for ReLU
|
13 |
+
# xi = 2 / (1 + k^2) for LeakyReLU
|
14 |
+
with torch.no_grad():
|
15 |
+
for name, param in model.named_parameters():
|
16 |
+
param_device = param.device
|
17 |
+
param_dtype = param.dtype
|
18 |
+
if len(param.shape) == 2:
|
19 |
+
n_dim = param.shape[0]
|
20 |
+
alpha = np.sqrt(xi/n_dim*Gamma(1/shape) / Gamma(3/shape))
|
21 |
+
gennorm_params = gennorm.rvs(
|
22 |
+
shape, loc=0, scale=alpha, size=param.shape)
|
23 |
+
param.data = torch.from_numpy(gennorm_params)
|
24 |
+
else:
|
25 |
+
if "weight" in name:
|
26 |
+
param.data = torch.ones(param.shape)
|
27 |
+
elif "bias" in name:
|
28 |
+
param.data = torch.zeros(param.shape)
|
29 |
+
|
30 |
+
param.data = param.data.to(param_dtype).to(param_device)
|
index.html
DELETED
@@ -1,435 +0,0 @@
|
|
1 |
-
<!DOCTYPE html>
|
2 |
-
<html>
|
3 |
-
<head>
|
4 |
-
<meta charset="utf-8">
|
5 |
-
<meta name="description"
|
6 |
-
content="Deformable Neural Radiance Fields creates free-viewpoint portraits (nerfies) from casually captured videos.">
|
7 |
-
<meta name="keywords" content="Nerfies, D-NeRF, NeRF">
|
8 |
-
<meta name="viewport" content="width=device-width, initial-scale=1">
|
9 |
-
<title>Nerfies: Deformable Neural Radiance Fields</title>
|
10 |
-
|
11 |
-
<link href="https://fonts.googleapis.com/css?family=Google+Sans|Noto+Sans|Castoro"
|
12 |
-
rel="stylesheet">
|
13 |
-
|
14 |
-
<link rel="stylesheet" href="./static/css/bulma.min.css">
|
15 |
-
<link rel="stylesheet" href="./static/css/bulma-carousel.min.css">
|
16 |
-
<link rel="stylesheet" href="./static/css/bulma-slider.min.css">
|
17 |
-
<link rel="stylesheet" href="./static/css/fontawesome.all.min.css">
|
18 |
-
<link rel="stylesheet"
|
19 |
-
href="https://cdn.jsdelivr.net/gh/jpswalsh/academicons@1/css/academicons.min.css">
|
20 |
-
<link rel="stylesheet" href="./static/css/index.css">
|
21 |
-
<link rel="icon" href="./static/images/favicon.svg">
|
22 |
-
|
23 |
-
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
|
24 |
-
<script defer src="./static/js/fontawesome.all.min.js"></script>
|
25 |
-
<script src="./static/js/bulma-carousel.min.js"></script>
|
26 |
-
<script src="./static/js/bulma-slider.min.js"></script>
|
27 |
-
<script src="./static/js/index.js"></script>
|
28 |
-
</head>
|
29 |
-
<body>
|
30 |
-
|
31 |
-
<section class="hero">
|
32 |
-
<div class="hero-body">
|
33 |
-
<div class="container is-max-desktop">
|
34 |
-
<div class="columns is-centered">
|
35 |
-
<div class="column has-text-centered">
|
36 |
-
<h1 class="title is-1 publication-title">Nerfies: Deformable Neural Radiance Fields</h1>
|
37 |
-
<div class="is-size-5 publication-authors">
|
38 |
-
<span class="author-block">
|
39 |
-
<a href="https://keunhong.com" target="_blank">Keunhong Park</a><sup>1</sup>,</span>
|
40 |
-
<span class="author-block">
|
41 |
-
<a href="https://utkarshsinha.com" target="_blank">Utkarsh Sinha</a><sup>2</sup>,</span>
|
42 |
-
<span class="author-block">
|
43 |
-
<a href="https://jonbarron.info" target="_blank">Jonathan T. Barron</a><sup>2</sup>,
|
44 |
-
</span>
|
45 |
-
<span class="author-block">
|
46 |
-
<a href="http://sofienbouaziz.com" target="_blank">Sofien Bouaziz</a><sup>2</sup>,
|
47 |
-
</span>
|
48 |
-
<span class="author-block">
|
49 |
-
<a href="https://www.danbgoldman.com" target="_blank">Dan B Goldman</a><sup>2</sup>,
|
50 |
-
</span>
|
51 |
-
<span class="author-block">
|
52 |
-
<a href="https://homes.cs.washington.edu/~seitz/" target="_blank">Steven M. Seitz</a><sup>1,2</sup>,
|
53 |
-
</span>
|
54 |
-
<span class="author-block">
|
55 |
-
<a href="http://www.ricardomartinbrualla.com" target="_blank">Ricardo Martin-Brualla</a><sup>2</sup>
|
56 |
-
</span>
|
57 |
-
</div>
|
58 |
-
|
59 |
-
<div class="is-size-5 publication-authors">
|
60 |
-
<span class="author-block"><sup>1</sup>University of Washington,</span>
|
61 |
-
<span class="author-block"><sup>2</sup>Google Research</span>
|
62 |
-
</div>
|
63 |
-
|
64 |
-
<div class="column has-text-centered">
|
65 |
-
<div class="publication-links">
|
66 |
-
<!-- PDF Link. -->
|
67 |
-
<span class="link-block">
|
68 |
-
<a href="https://arxiv.org/pdf/2011.12948" target="_blank"
|
69 |
-
class="external-link button is-normal is-rounded is-dark">
|
70 |
-
<span class="icon">
|
71 |
-
<i class="fas fa-file-pdf"></i>
|
72 |
-
</span>
|
73 |
-
<span>Paper</span>
|
74 |
-
</a>
|
75 |
-
</span>
|
76 |
-
<span class="link-block">
|
77 |
-
<a href="https://arxiv.org/abs/2011.12948" target="_blank"
|
78 |
-
class="external-link button is-normal is-rounded is-dark">
|
79 |
-
<span class="icon">
|
80 |
-
<i class="ai ai-arxiv"></i>
|
81 |
-
</span>
|
82 |
-
<span>arXiv</span>
|
83 |
-
</a>
|
84 |
-
</span>
|
85 |
-
<!-- Video Link. -->
|
86 |
-
<span class="link-block">
|
87 |
-
<a href="https://www.youtube.com/watch?v=MrKrnHhk8IA" target="_blank"
|
88 |
-
class="external-link button is-normal is-rounded is-dark">
|
89 |
-
<span class="icon">
|
90 |
-
<i class="fab fa-youtube"></i>
|
91 |
-
</span>
|
92 |
-
<span>Video</span>
|
93 |
-
</a>
|
94 |
-
</span>
|
95 |
-
<!-- Code Link. -->
|
96 |
-
<span class="link-block">
|
97 |
-
<a href="https://github.com/google/nerfies" target="_blank"
|
98 |
-
class="external-link button is-normal is-rounded is-dark">
|
99 |
-
<span class="icon">
|
100 |
-
<i class="fab fa-github"></i>
|
101 |
-
</span>
|
102 |
-
<span>Code</span>
|
103 |
-
</a>
|
104 |
-
</span>
|
105 |
-
<!-- Dataset Link. -->
|
106 |
-
<span class="link-block">
|
107 |
-
<a href="https://github.com/google/nerfies/releases/tag/0.1" target="_blank"
|
108 |
-
class="external-link button is-normal is-rounded is-dark">
|
109 |
-
<span class="icon">
|
110 |
-
<i class="far fa-images"></i>
|
111 |
-
</span>
|
112 |
-
<span>Data</span>
|
113 |
-
</a>
|
114 |
-
</div>
|
115 |
-
|
116 |
-
</div>
|
117 |
-
</div>
|
118 |
-
</div>
|
119 |
-
</div>
|
120 |
-
</div>
|
121 |
-
</section>
|
122 |
-
|
123 |
-
<section class="hero teaser">
|
124 |
-
<div class="container is-max-desktop">
|
125 |
-
<div class="hero-body">
|
126 |
-
<video id="teaser" autoplay muted loop playsinline height="100%">
|
127 |
-
<source src="./static/videos/teaser.mp4"
|
128 |
-
type="video/mp4">
|
129 |
-
</video>
|
130 |
-
<h2 class="subtitle has-text-centered">
|
131 |
-
<span class="dnerf">Nerfies</span> turns selfie videos from your phone into
|
132 |
-
free-viewpoint
|
133 |
-
portraits.
|
134 |
-
</h2>
|
135 |
-
</div>
|
136 |
-
</div>
|
137 |
-
</section>
|
138 |
-
|
139 |
-
|
140 |
-
<section class="hero is-light is-small">
|
141 |
-
<div class="hero-body">
|
142 |
-
<div class="container">
|
143 |
-
<div id="results-carousel" class="carousel results-carousel">
|
144 |
-
<div class="item item-steve">
|
145 |
-
<video poster="" id="steve" autoplay controls muted loop playsinline height="100%">
|
146 |
-
<source src="./static/videos/steve.mp4"
|
147 |
-
type="video/mp4">
|
148 |
-
</video>
|
149 |
-
</div>
|
150 |
-
<div class="item item-chair-tp">
|
151 |
-
<video poster="" id="chair-tp" autoplay controls muted loop playsinline height="100%">
|
152 |
-
<source src="./static/videos/chair-tp.mp4"
|
153 |
-
type="video/mp4">
|
154 |
-
</video>
|
155 |
-
</div>
|
156 |
-
<div class="item item-shiba">
|
157 |
-
<video poster="" id="shiba" autoplay controls muted loop playsinline height="100%">
|
158 |
-
<source src="./static/videos/shiba.mp4"
|
159 |
-
type="video/mp4">
|
160 |
-
</video>
|
161 |
-
</div>
|
162 |
-
<div class="item item-fullbody">
|
163 |
-
<video poster="" id="fullbody" autoplay controls muted loop playsinline height="100%">
|
164 |
-
<source src="./static/videos/fullbody.mp4"
|
165 |
-
type="video/mp4">
|
166 |
-
</video>
|
167 |
-
</div>
|
168 |
-
<div class="item item-blueshirt">
|
169 |
-
<video poster="" id="blueshirt" autoplay controls muted loop playsinline height="100%">
|
170 |
-
<source src="./static/videos/blueshirt.mp4"
|
171 |
-
type="video/mp4">
|
172 |
-
</video>
|
173 |
-
</div>
|
174 |
-
<div class="item item-mask">
|
175 |
-
<video poster="" id="mask" autoplay controls muted loop playsinline height="100%">
|
176 |
-
<source src="./static/videos/mask.mp4"
|
177 |
-
type="video/mp4">
|
178 |
-
</video>
|
179 |
-
</div>
|
180 |
-
<div class="item item-coffee">
|
181 |
-
<video poster="" id="coffee" autoplay controls muted loop playsinline height="100%">
|
182 |
-
<source src="./static/videos/coffee.mp4"
|
183 |
-
type="video/mp4">
|
184 |
-
</video>
|
185 |
-
</div>
|
186 |
-
<div class="item item-toby">
|
187 |
-
<video poster="" id="toby" autoplay controls muted loop playsinline height="100%">
|
188 |
-
<source src="./static/videos/toby2.mp4"
|
189 |
-
type="video/mp4">
|
190 |
-
</video>
|
191 |
-
</div>
|
192 |
-
</div>
|
193 |
-
</div>
|
194 |
-
</div>
|
195 |
-
</section>
|
196 |
-
|
197 |
-
|
198 |
-
<section class="section">
|
199 |
-
<div class="container is-max-desktop">
|
200 |
-
<!-- Abstract. -->
|
201 |
-
<div class="columns is-centered has-text-centered">
|
202 |
-
<div class="column is-four-fifths">
|
203 |
-
<h2 class="title is-3">Abstract</h2>
|
204 |
-
<div class="content has-text-justified">
|
205 |
-
<p>
|
206 |
-
We present the first method capable of photorealistically reconstructing a non-rigidly
|
207 |
-
deforming scene using photos/videos captured casually from mobile phones.
|
208 |
-
</p>
|
209 |
-
<p>
|
210 |
-
Our approach augments neural radiance fields
|
211 |
-
(NeRF) by optimizing an
|
212 |
-
additional continuous volumetric deformation field that warps each observed point into a
|
213 |
-
canonical 5D NeRF.
|
214 |
-
We observe that these NeRF-like deformation fields are prone to local minima, and
|
215 |
-
propose a coarse-to-fine optimization method for coordinate-based models that allows for
|
216 |
-
more robust optimization.
|
217 |
-
By adapting principles from geometry processing and physical simulation to NeRF-like
|
218 |
-
models, we propose an elastic regularization of the deformation field that further
|
219 |
-
improves robustness.
|
220 |
-
</p>
|
221 |
-
<p>
|
222 |
-
We show that <span class="dnerf">Nerfies</span> can turn casually captured selfie
|
223 |
-
photos/videos into deformable NeRF
|
224 |
-
models that allow for photorealistic renderings of the subject from arbitrary
|
225 |
-
viewpoints, which we dub <i>"nerfies"</i>. We evaluate our method by collecting data
|
226 |
-
using a
|
227 |
-
rig with two mobile phones that take time-synchronized photos, yielding train/validation
|
228 |
-
images of the same pose at different viewpoints. We show that our method faithfully
|
229 |
-
reconstructs non-rigidly deforming scenes and reproduces unseen views with high
|
230 |
-
fidelity.
|
231 |
-
</p>
|
232 |
-
</div>
|
233 |
-
</div>
|
234 |
-
</div>
|
235 |
-
<!--/ Abstract. -->
|
236 |
-
|
237 |
-
<!-- Paper video. -->
|
238 |
-
<div class="columns is-centered has-text-centered">
|
239 |
-
<div class="column is-four-fifths">
|
240 |
-
<h2 class="title is-3">Video</h2>
|
241 |
-
<div class="publication-video">
|
242 |
-
<iframe src="https://www.youtube.com/embed/MrKrnHhk8IA?rel=0&showinfo=0"
|
243 |
-
frameborder="0" allow="autoplay; encrypted-media" allowfullscreen></iframe>
|
244 |
-
</div>
|
245 |
-
</div>
|
246 |
-
</div>
|
247 |
-
<!--/ Paper video. -->
|
248 |
-
</div>
|
249 |
-
</section>
|
250 |
-
|
251 |
-
|
252 |
-
<section class="section">
|
253 |
-
<div class="container is-max-desktop">
|
254 |
-
|
255 |
-
<div class="columns is-centered">
|
256 |
-
|
257 |
-
<!-- Visual Effects. -->
|
258 |
-
<div class="column">
|
259 |
-
<div class="content">
|
260 |
-
<h2 class="title is-3">Visual Effects</h2>
|
261 |
-
<p>
|
262 |
-
Using <i>nerfies</i> you can create fun visual effects. This Dolly zoom effect
|
263 |
-
would be impossible without nerfies since it would require going through a wall.
|
264 |
-
</p>
|
265 |
-
<video id="dollyzoom" autoplay controls muted loop playsinline height="100%">
|
266 |
-
<source src="./static/videos/dollyzoom-stacked.mp4"
|
267 |
-
type="video/mp4">
|
268 |
-
</video>
|
269 |
-
</div>
|
270 |
-
</div>
|
271 |
-
<!--/ Visual Effects. -->
|
272 |
-
|
273 |
-
<!-- Matting. -->
|
274 |
-
<div class="column">
|
275 |
-
<h2 class="title is-3">Matting</h2>
|
276 |
-
<div class="columns is-centered">
|
277 |
-
<div class="column content">
|
278 |
-
<p>
|
279 |
-
As a byproduct of our method, we can also solve the matting problem by ignoring
|
280 |
-
samples that fall outside of a bounding box during rendering.
|
281 |
-
</p>
|
282 |
-
<video id="matting-video" controls playsinline height="100%">
|
283 |
-
<source src="./static/videos/matting.mp4"
|
284 |
-
type="video/mp4">
|
285 |
-
</video>
|
286 |
-
</div>
|
287 |
-
|
288 |
-
</div>
|
289 |
-
</div>
|
290 |
-
</div>
|
291 |
-
<!--/ Matting. -->
|
292 |
-
|
293 |
-
<!-- Animation. -->
|
294 |
-
<div class="columns is-centered">
|
295 |
-
<div class="column is-full-width">
|
296 |
-
<h2 class="title is-3">Animation</h2>
|
297 |
-
|
298 |
-
<!-- Interpolating. -->
|
299 |
-
<h3 class="title is-4">Interpolating states</h3>
|
300 |
-
<div class="content has-text-justified">
|
301 |
-
<p>
|
302 |
-
We can also animate the scene by interpolating the deformation latent codes of two input
|
303 |
-
frames. Use the slider here to linearly interpolate between the left frame and the right
|
304 |
-
frame.
|
305 |
-
</p>
|
306 |
-
</div>
|
307 |
-
<div class="columns is-vcentered interpolation-panel">
|
308 |
-
<div class="column is-3 has-text-centered">
|
309 |
-
<img src="./static/images/interpolate_start.jpg"
|
310 |
-
class="interpolation-image"
|
311 |
-
alt="Interpolate start reference image."/>
|
312 |
-
<p>Start Frame</p>
|
313 |
-
</div>
|
314 |
-
<div class="column interpolation-video-column">
|
315 |
-
<div id="interpolation-image-wrapper">
|
316 |
-
Loading...
|
317 |
-
</div>
|
318 |
-
<input class="slider is-fullwidth is-large is-info"
|
319 |
-
id="interpolation-slider"
|
320 |
-
step="1" min="0" max="100" value="0" type="range">
|
321 |
-
</div>
|
322 |
-
<div class="column is-3 has-text-centered">
|
323 |
-
<img src="./static/images/interpolate_end.jpg"
|
324 |
-
class="interpolation-image"
|
325 |
-
alt="Interpolation end reference image."/>
|
326 |
-
<p class="is-bold">End Frame</p>
|
327 |
-
</div>
|
328 |
-
</div>
|
329 |
-
<br/>
|
330 |
-
<!--/ Interpolating. -->
|
331 |
-
|
332 |
-
<!-- Re-rendering. -->
|
333 |
-
<h3 class="title is-4">Re-rendering the input video</h3>
|
334 |
-
<div class="content has-text-justified">
|
335 |
-
<p>
|
336 |
-
Using <span class="dnerf">Nerfies</span>, you can re-render a video from a novel
|
337 |
-
viewpoint such as a stabilized camera by playing back the training deformations.
|
338 |
-
</p>
|
339 |
-
</div>
|
340 |
-
<div class="content has-text-centered">
|
341 |
-
<video id="replay-video"
|
342 |
-
controls
|
343 |
-
muted
|
344 |
-
preload
|
345 |
-
playsinline
|
346 |
-
width="75%">
|
347 |
-
<source src="./static/videos/replay.mp4"
|
348 |
-
type="video/mp4">
|
349 |
-
</video>
|
350 |
-
</div>
|
351 |
-
<!--/ Re-rendering. -->
|
352 |
-
|
353 |
-
</div>
|
354 |
-
</div>
|
355 |
-
<!--/ Animation. -->
|
356 |
-
|
357 |
-
|
358 |
-
<!-- Concurrent Work. -->
|
359 |
-
<div class="columns is-centered">
|
360 |
-
<div class="column is-full-width">
|
361 |
-
<h2 class="title is-3">Related Links</h2>
|
362 |
-
|
363 |
-
<div class="content has-text-justified">
|
364 |
-
<p>
|
365 |
-
There's a lot of excellent work that was introduced around the same time as ours.
|
366 |
-
</p>
|
367 |
-
<p>
|
368 |
-
<a href="https://arxiv.org/abs/2104.09125" target="_blank">Progressive Encoding for Neural Optimization</a> introduces an idea similar to our windowed position encoding for coarse-to-fine optimization.
|
369 |
-
</p>
|
370 |
-
<p>
|
371 |
-
<a href="https://www.albertpumarola.com/research/D-NeRF/index.html" target="_blank">D-NeRF</a> and <a href="https://gvv.mpi-inf.mpg.de/projects/nonrigid_nerf/" target="_blank">NR-NeRF</a>
|
372 |
-
both use deformation fields to model non-rigid scenes.
|
373 |
-
</p>
|
374 |
-
<p>
|
375 |
-
Some works model videos with a NeRF by directly modulating the density, such as <a href="https://video-nerf.github.io/" target="_blank">Video-NeRF</a>, <a href="https://www.cs.cornell.edu/~zl548/NSFF/" target="_blank">NSFF</a>, and <a href="https://neural-3d-video.github.io/" target="_blank">DyNeRF</a>
|
376 |
-
</p>
|
377 |
-
<p>
|
378 |
-
There are probably many more by the time you are reading this. Check out <a href="https://dellaert.github.io/NeRF/" target="_blank">Frank Dellart's survey on recent NeRF papers</a>, and <a href="https://github.com/yenchenlin/awesome-NeRF" target="_blank">Yen-Chen Lin's curated list of NeRF papers</a>.
|
379 |
-
</p>
|
380 |
-
</div>
|
381 |
-
</div>
|
382 |
-
</div>
|
383 |
-
<!--/ Concurrent Work. -->
|
384 |
-
|
385 |
-
</div>
|
386 |
-
</section>
|
387 |
-
|
388 |
-
|
389 |
-
<section class="section" id="BibTeX">
|
390 |
-
<div class="container is-max-desktop content">
|
391 |
-
<h2 class="title">BibTeX</h2>
|
392 |
-
<pre><code>@article{park2021nerfies,
|
393 |
-
author = {Park, Keunhong and Sinha, Utkarsh and Barron, Jonathan T. and Bouaziz, Sofien and Goldman, Dan B and Seitz, Steven M. and Martin-Brualla, Ricardo},
|
394 |
-
title = {Nerfies: Deformable Neural Radiance Fields},
|
395 |
-
journal = {ICCV},
|
396 |
-
year = {2021},
|
397 |
-
}</code></pre>
|
398 |
-
</div>
|
399 |
-
</section>
|
400 |
-
|
401 |
-
|
402 |
-
<footer class="footer">
|
403 |
-
<div class="container">
|
404 |
-
<div class="content has-text-centered">
|
405 |
-
<a class="icon-link" target="_blank"
|
406 |
-
href="./static/videos/nerfies_paper.pdf">
|
407 |
-
<i class="fas fa-file-pdf"></i>
|
408 |
-
</a>
|
409 |
-
<a class="icon-link" href="https://github.com/keunhong" target="_blank" class="external-link" disabled>
|
410 |
-
<i class="fab fa-github"></i>
|
411 |
-
</a>
|
412 |
-
</div>
|
413 |
-
<div class="columns is-centered">
|
414 |
-
<div class="column is-8">
|
415 |
-
<div class="content">
|
416 |
-
<p>
|
417 |
-
This website is licensed under a <a rel="license" target="_blank"
|
418 |
-
href="http://creativecommons.org/licenses/by-sa/4.0/">Creative
|
419 |
-
Commons Attribution-ShareAlike 4.0 International License</a>.
|
420 |
-
</p>
|
421 |
-
<p>
|
422 |
-
This means you are free to borrow the <a target="_blank"
|
423 |
-
href="https://github.com/nerfies/nerfies.github.io">source code</a> of this website,
|
424 |
-
we just ask that you link back to this page in the footer.
|
425 |
-
Please remember to remove the analytics code included in the header of the website which
|
426 |
-
you do not want on your website.
|
427 |
-
</p>
|
428 |
-
</div>
|
429 |
-
</div>
|
430 |
-
</div>
|
431 |
-
</div>
|
432 |
-
</footer>
|
433 |
-
|
434 |
-
</body>
|
435 |
-
</html>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rf8.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
def get_residual(weights):
|
6 |
+
"""Get the order of the first significant digit of the tensors"""
|
7 |
+
signs = torch.sign(weights)
|
8 |
+
exps = torch.round(torch.log2(torch.abs(weights)))
|
9 |
+
pow_weights = signs * torch.pow(2, exps)
|
10 |
+
return pow_weights, exps
|
11 |
+
|
12 |
+
|
13 |
+
def rf8(model, n=4):
|
14 |
+
"""Residual Float-Point 8-bit Model Quantization"""
|
15 |
+
with torch.no_grad():
|
16 |
+
for param in model.parameters():
|
17 |
+
data1, exps1 = get_residual(param.data)
|
18 |
+
data2, exps2 = get_residual(param.data - data1)
|
19 |
+
flags = (exps1-exps2 <= n)
|
20 |
+
param.data = data1 + flags * data2
|