|
import math |
|
import os |
|
import time |
|
from pathlib import Path |
|
from typing import Literal, Optional |
|
|
|
import numpy as np |
|
import torch |
|
import tyro |
|
from PIL import Image |
|
from torch import Tensor, optim |
|
|
|
from gsplat import rasterization, rasterization_2dgs |
|
|
|
|
|
class SimpleTrainer:
    """Trains random gaussians to fit an image.

    A fixed pinhole camera (90-degree horizontal field of view) renders
    ``num_points`` randomly initialised gaussians, whose parameters are
    optimised with Adam so that the rendered image matches ``gt_image``
    under an MSE loss.
    """

    def __init__(
        self,
        gt_image: Tensor,
        num_points: int = 2000,
    ):
        """Set up the camera intrinsics and the random gaussians.

        Args:
            gt_image: Target image laid out as (H, W, C). It is moved to
                ``cuda:0``. Presumably values are in [0, 1] to match the
                sigmoid-activated colours rendered in ``train``.
            num_points: Number of gaussians to optimise.

        Raises:
            ValueError: If ``gt_image`` is not a 3-D (H, W, C) tensor.
        """
        # Fail early with a clear message instead of an opaque broadcasting
        # error inside the loss computation during training.
        if gt_image.ndim != 3:
            raise ValueError(
                f"gt_image must have shape (H, W, C), got {tuple(gt_image.shape)}"
            )
        self.device = torch.device("cuda:0")
        self.gt_image = gt_image.to(device=self.device)
        self.num_points = num_points

        # A 90-degree horizontal FOV gives a focal length of W / 2 pixels.
        fov_x = math.pi / 2.0
        self.H, self.W = gt_image.shape[0], gt_image.shape[1]
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        # NOTE(review): not read anywhere in this class; kept for external users.
        self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)

        self._init_gaussians()

    def _init_gaussians(self):
        """Create randomly initialised, trainable gaussian parameters.

        The order of the ``torch.rand`` calls is kept fixed so runs with a
        seeded RNG stay reproducible.
        """
        bd = 2

        # Means uniformly distributed in [-bd/2, bd/2)^3 = [-1, 1)^3.
        self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
        # Scales in [0, 1); passed to the rasterizer without any activation.
        self.scales = torch.rand(self.num_points, 3, device=self.device)
        d = 3
        # Colour logits; ``train`` applies a sigmoid before rasterizing.
        self.rgbs = torch.rand(self.num_points, d, device=self.device)

        u = torch.rand(self.num_points, 1, device=self.device)
        v = torch.rand(self.num_points, 1, device=self.device)
        w = torch.rand(self.num_points, 1, device=self.device)

        # Uniformly random unit quaternions (Shoemake's method).
        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            -1,
        )
        # Opacity logits; sigmoid(1) ~= 0.73 is the initial opacity.
        self.opacities = torch.ones(self.num_points, device=self.device)

        # Fixed camera pose: identity rotation, +8 translation along z, so the
        # [-1, 1)^3 cube of means lands at depth ~[7, 9] (assuming gsplat's
        # world-to-camera convention for viewmats).
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=self.device,
        )
        # NOTE(review): not passed to the rasterizer in ``train``.
        self.background = torch.zeros(d, device=self.device)

        self.means.requires_grad = True
        self.scales.requires_grad = True
        self.quats.requires_grad = True
        self.rgbs.requires_grad = True
        self.opacities.requires_grad = True
        self.viewmat.requires_grad = False

    @staticmethod
    def _get_rasterizer(model_type: str):
        """Return the gsplat rasterization function for ``model_type``.

        Raises:
            ValueError: If ``model_type`` is neither ``"3dgs"`` nor ``"2dgs"``.
        """
        if model_type == "3dgs":
            return rasterization
        if model_type == "2dgs":
            return rasterization_2dgs
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected '3dgs' or '2dgs'"
        )

    @staticmethod
    def _to_uint8_frame(img: Tensor) -> np.ndarray:
        """Convert a float image in [0, 1] to a uint8 array for the GIF.

        Values are clipped first: casting out-of-range floats to uint8 wraps
        around (or is platform-dependent) instead of saturating.
        """
        arr = img.detach().cpu().numpy()
        return (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)

    @staticmethod
    def _save_gif(frames) -> None:
        """Write non-empty uint8 ``frames`` to ``results/training.gif`` in the cwd."""
        images = [Image.fromarray(frame) for frame in frames]
        out_dir = os.path.join(os.getcwd(), "results")
        os.makedirs(out_dir, exist_ok=True)
        images[0].save(
            f"{out_dir}/training.gif",
            save_all=True,
            append_images=images[1:],
            optimize=False,
            duration=5,
            loop=0,
        )

    def train(
        self,
        iterations: int = 1000,
        lr: float = 0.01,
        save_imgs: bool = False,
        model_type: Literal["3dgs", "2dgs"] = "3dgs",
    ):
        """Optimise the gaussians against ``self.gt_image``.

        Prints the loss every step and, at the end, the total and per-step
        wall-clock time spent in rasterization and in the backward pass.

        Args:
            iterations: Number of optimisation steps.
            lr: Adam learning rate shared by all parameters.
            save_imgs: If True, every 5th rendered frame is collected and
                written to ``results/training.gif`` under the current working
                directory.
            model_type: ``"3dgs"`` uses ``rasterization``; ``"2dgs"`` uses
                ``rasterization_2dgs``.

        Raises:
            ValueError: If ``model_type`` is not ``"3dgs"`` or ``"2dgs"``.
        """
        # Resolve the rasterizer first so an invalid model_type fails
        # immediately rather than with UnboundLocalError inside the loop.
        rasterize_fnc = self._get_rasterizer(model_type)

        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
        )
        mse_loss = torch.nn.MSELoss()
        frames = []
        raster_time = 0.0
        backward_time = 0.0
        # Pinhole intrinsics with the principal point at the image centre.
        K = torch.tensor(
            [
                [self.focal, 0, self.W / 2],
                [0, self.focal, self.H / 2],
                [0, 0, 1],
            ],
            device=self.device,
        )

        for step in range(iterations):
            start = time.perf_counter()
            # The first element of the returned tuple holds the rendered
            # colours, batched over the single camera passed in.
            renders = rasterize_fnc(
                self.means,
                self.quats / self.quats.norm(dim=-1, keepdim=True),
                self.scales,
                torch.sigmoid(self.opacities),
                torch.sigmoid(self.rgbs),
                self.viewmat[None],
                K[None],
                self.W,
                self.H,
                packed=False,
            )[0]
            out_img = renders[0]
            # CUDA launches are asynchronous; synchronize so timings are real.
            torch.cuda.synchronize()
            raster_time += time.perf_counter() - start

            loss = mse_loss(out_img, self.gt_image)
            optimizer.zero_grad()
            start = time.perf_counter()
            loss.backward()
            torch.cuda.synchronize()
            backward_time += time.perf_counter() - start
            optimizer.step()
            print(f"Iteration {step + 1}/{iterations}, Loss: {loss.item()}")

            if save_imgs and step % 5 == 0:
                frames.append(self._to_uint8_frame(out_img))

        # frames is empty when iterations <= 0; there is nothing to save then.
        if save_imgs and frames:
            self._save_gif(frames)

        # Avoid dividing by zero when no iterations were run.
        steps = max(iterations, 1)
        print(f"Total(s):\nRasterization: {raster_time:.3f}, Backward: {backward_time:.3f}")
        print(
            f"Per step(s):\nRasterization: {raster_time/steps:.5f}, Backward: {backward_time/steps:.5f}"
        )
|
|
|
|
|
def image_path_to_tensor(image_path: Path):
    """Load an image file as a float tensor laid out (H, W, 3) in [0, 1].

    The image is converted to RGB before ``ToTensor`` because ``ToTensor``
    does not expand single-channel, palette or CMYK data into RGB, and it
    does not rescale 16/32-bit integer modes. Without the conversion,
    grayscale/LA files would produce 1 or 2 channels and palette ("P")
    files would produce raw palette indices. An alpha channel is dropped
    (not composited), matching the previous ``[..., :3]`` behaviour for
    RGBA input.
    """
    import torchvision.transforms as transforms

    # Context manager closes the file handle; convert() loads the pixel data
    # into a new in-memory image before the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")

    transform = transforms.ToTensor()
    # ToTensor yields (C, H, W); the trainer expects (H, W, C).
    img_tensor = transform(rgb).permute(1, 2, 0)[..., :3]
    return img_tensor
|
|
|
|
|
def main(
    height: int = 256,
    width: int = 256,
    num_points: int = 100000,
    save_imgs: bool = True,
    img_path: Optional[Path] = None,
    iterations: int = 1000,
    lr: float = 0.01,
    model_type: Literal["3dgs", "2dgs"] = "3dgs",
) -> None:
    """Fit random gaussians to a target image.

    Args:
        height: Height of the synthetic target (ignored when img_path is set).
        width: Width of the synthetic target (ignored when img_path is set).
        num_points: Number of gaussians to optimise.
        save_imgs: Save a GIF of the training progress.
        img_path: Optional image file to use as the target.
        iterations: Number of optimisation steps.
        lr: Learning rate.
        model_type: Which rasterizer to use, "3dgs" or "2dgs".
    """
    if not img_path:
        # Synthetic target: white canvas with a red top-left quadrant and a
        # blue bottom-right quadrant.
        half_h, half_w = height // 2, width // 2
        target = torch.ones((height, width, 3))
        target[:half_h, :half_w, :] = torch.tensor([1.0, 0.0, 0.0])
        target[half_h:, half_w:, :] = torch.tensor([0.0, 0.0, 1.0])
    else:
        target = image_path_to_tensor(img_path)

    SimpleTrainer(gt_image=target, num_points=num_points).train(
        iterations=iterations,
        lr=lr,
        save_imgs=save_imgs,
        model_type=model_type,
    )
|
|
|
|
|
if __name__ == "__main__":
    # tyro derives command-line options from main's signature (and docstring),
    # parses argv, and calls main with the parsed values.
    tyro.cli(main)
|
|