import math
import os
import time
from pathlib import Path
from typing import Literal, Optional

import numpy as np
import torch
import tyro
from PIL import Image
from torch import Tensor, optim

from gsplat import rasterization, rasterization_2dgs


class SimpleTrainer:
    """Trains random gaussians to fit an image."""

    def __init__(
        self,
        gt_image: Tensor,
        num_points: int = 2000,
    ):
        """Set up a pinhole camera and randomly initialise the gaussians.

        Args:
            gt_image: Target image. Its first two dimensions are read as
                height and width; the training loss compares it against the
                rendered image, so it is expected to be ``(H, W, 3)``.
            num_points: Number of gaussians to optimise.
        """
        self.device = torch.device("cuda:0")
        self.gt_image = gt_image.to(device=self.device)
        self.num_points = num_points

        # 90 degree horizontal field of view.
        fov_x = math.pi / 2.0
        self.H, self.W = gt_image.shape[0], gt_image.shape[1]
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)

        self._init_gaussians()

    def _init_gaussians(self):
        """Random gaussians.

        Creates the trainable tensors ``means``, ``scales``, ``rgbs``,
        ``quats`` and ``opacities`` plus the fixed ``viewmat`` and
        ``background``.
        """
        bd = 2

        # Centres uniformly distributed in [-bd/2, bd/2)^3.
        self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
        self.scales = torch.rand(self.num_points, 3, device=self.device)
        d = 3
        # Colour logits; ``train`` passes them through a sigmoid.
        self.rgbs = torch.rand(self.num_points, d, device=self.device)

        # Uniformly distributed random unit quaternions (Shoemake's method).
        u = torch.rand(self.num_points, 1, device=self.device)
        v = torch.rand(self.num_points, 1, device=self.device)
        w = torch.rand(self.num_points, 1, device=self.device)

        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            -1,
        )
        # Opacity logits; ``train`` passes them through a sigmoid.
        self.opacities = torch.ones((self.num_points), device=self.device)

        # Passed to gsplat as the view matrix: a pure +8 translation along z,
        # presumably so the gaussians (initialised around the origin) sit in
        # front of the camera -- confirm against gsplat's camera convention.
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=self.device,
        )
        self.background = torch.zeros(d, device=self.device)

        self.means.requires_grad = True
        self.scales.requires_grad = True
        self.quats.requires_grad = True
        self.rgbs.requires_grad = True
        self.opacities.requires_grad = True
        self.viewmat.requires_grad = False

    def train(
        self,
        iterations: int = 1000,
        lr: float = 0.01,
        save_imgs: bool = False,
        model_type: Literal["3dgs", "2dgs"] = "3dgs",
    ):
        """Optimise the gaussians so their render matches ``self.gt_image``.

        Prints the loss every step and, at the end, total and per-step wall
        time spent in rasterization and in the backward pass.

        Args:
            iterations: Number of optimisation steps.
            lr: Adam learning rate shared by all gaussian parameters.
            save_imgs: If True, keep every 5th render and write them as
                ``results/training.gif`` under the current working directory.
            model_type: ``"3dgs"`` renders with ``gsplat.rasterization``,
                ``"2dgs"`` with ``gsplat.rasterization_2dgs``.

        Raises:
            ValueError: If ``model_type`` is not ``"3dgs"`` or ``"2dgs"``.
        """
        rasterizers = {"3dgs": rasterization, "2dgs": rasterization_2dgs}
        if model_type not in rasterizers:
            raise ValueError(
                f"Unknown model_type {model_type!r}; "
                f"expected one of {sorted(rasterizers)}"
            )
        rasterize_fnc = rasterizers[model_type]

        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
        )
        mse_loss = torch.nn.MSELoss()
        frames = []
        times = [0.0] * 2  # rasterization, backward
        # Pinhole intrinsics with the principal point at the image centre.
        K = torch.tensor(
            [
                [self.focal, 0, self.W / 2],
                [0, self.focal, self.H / 2],
                [0, 0, 1],
            ],
            device=self.device,
        )

        for step in range(iterations):
            start = time.perf_counter()

            renders = rasterize_fnc(
                self.means,
                self.quats / self.quats.norm(dim=-1, keepdim=True),
                self.scales,
                torch.sigmoid(self.opacities),
                torch.sigmoid(self.rgbs),
                self.viewmat[None],
                K[None],
                self.W,
                self.H,
                packed=False,
            )[0]
            # Only one camera is rendered; take its image.
            out_img = renders[0]
            # CUDA kernels run asynchronously; synchronise so the timing is real.
            torch.cuda.synchronize()
            times[0] += time.perf_counter() - start
            loss = mse_loss(out_img, self.gt_image)
            optimizer.zero_grad()
            start = time.perf_counter()
            loss.backward()
            torch.cuda.synchronize()
            times[1] += time.perf_counter() - start
            optimizer.step()
            print(f"Iteration {step + 1}/{iterations}, Loss: {loss.item()}")

            if save_imgs and step % 5 == 0:
                # Clamp so float round-off cannot wrap around in the uint8 cast.
                frame = out_img.detach().clamp(0.0, 1.0).cpu().numpy()
                frames.append((frame * 255).astype(np.uint8))

        # ``frames`` is empty when ``iterations`` is 0; nothing to save then.
        if save_imgs and frames:
            # save them as a gif with PIL
            frames = [Image.fromarray(frame) for frame in frames]
            out_dir = os.path.join(os.getcwd(), "results")
            os.makedirs(out_dir, exist_ok=True)
            frames[0].save(
                f"{out_dir}/training.gif",
                save_all=True,
                append_images=frames[1:],
                optimize=False,
                duration=5,
                loop=0,
            )
        print(f"Total(s):\nRasterization: {times[0]:.3f}, Backward: {times[1]:.3f}")
        # Avoid ZeroDivisionError when no steps were run.
        if iterations > 0:
            print(
                f"Per step(s):\nRasterization: {times[0]/iterations:.5f}, Backward: {times[1]/iterations:.5f}"
            )


def image_path_to_tensor(image_path: Path):
    """Load an image file as a float32 ``(H, W, 3)`` tensor with values in [0, 1].

    The image is converted to RGB first, so grayscale, palette and
    grayscale+alpha files also yield three channels (any alpha is dropped).
    This matches the three colour channels the gaussians are initialised with.
    The file handle is closed before returning.
    """
    with Image.open(image_path) as img:
        rgb = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    return torch.from_numpy(rgb)


def main(
    height: int = 256,
    width: int = 256,
    num_points: int = 100000,
    save_imgs: bool = True,
    img_path: Optional[Path] = None,
    iterations: int = 1000,
    lr: float = 0.01,
    model_type: Literal["3dgs", "2dgs"] = "3dgs",
) -> None:
    """Fit randomly initialised gaussians to an image.

    Args:
        height: Height of the synthetic target image (ignored with img_path).
        width: Width of the synthetic target image (ignored with img_path).
        num_points: Number of gaussians.
        save_imgs: Save a GIF of the training progress to ./results.
        img_path: Optional image file to fit instead of the synthetic target.
        iterations: Number of optimisation steps.
        lr: Adam learning rate.
        model_type: Rasterizer to use, "3dgs" or "2dgs".
    """
    if img_path is not None:
        gt_image = image_path_to_tensor(img_path)
    else:
        gt_image = torch.ones((height, width, 3))
        # make top left and bottom right red, blue
        gt_image[: height // 2, : width // 2, :] = torch.tensor([1.0, 0.0, 0.0])
        gt_image[height // 2 :, width // 2 :, :] = torch.tensor([0.0, 0.0, 1.0])

    trainer = SimpleTrainer(gt_image=gt_image, num_points=num_points)
    trainer.train(
        iterations=iterations,
        lr=lr,
        save_imgs=save_imgs,
        model_type=model_type,
    )


if __name__ == "__main__":
    tyro.cli(main)