Create system.py
tsr/system.py (ADDED, +205 -0)
@@ -0,0 +1,205 @@
import math
import os
from dataclasses import dataclass, field
from typing import List, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import trimesh
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image

from .models.isosurface import MarchingCubeHelper
from .utils import (
    BaseModule,
    ImagePreprocessor,
    find_class,
    get_spherical_cameras,
    scale_tensor,
)

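# TSR assembles the full single-image-to-3D pipeline: an image tokenizer, a
# latent-token tokenizer, a transformer backbone, a post-processor, a triplane
# decoder, and a renderer, each instantiated from class names in the config.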
class TSR(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        cond_image_size: int

        image_tokenizer_cls: str
        image_tokenizer: dict

        tokenizer_cls: str
        tokenizer: dict

        backbone_cls: str
        backbone: dict

        post_processor_cls: str
        post_processor: dict

        decoder_cls: str
        decoder: dict

        renderer_cls: str
        renderer: dict

    cfg: Config

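    # Load a config/checkpoint pair either from a local directory or from the
    # Hugging Face Hub (for the public release this is typically a repo id such
    # as "stabilityai/TripoSR" with "config.yaml" and "model.ckpt").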
    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        ckpt = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(ckpt)
        return model

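    # Build every sub-module from the class names given in the config; the
    # isosurface helper is created lazily, only when a mesh is extracted.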
    def configure(self):
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
        self.image_processor = ImagePreprocessor()
        self.isosurface_helper = None

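    # Map conditioning image(s) to scene codes: preprocess into a single-view
    # batch, tokenize the image, run the backbone with cross-attention to the
    # image tokens, then detokenize and post-process into triplane scene codes.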
    def forward(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        device: str,
    ) -> torch.FloatTensor:
        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
            device
        )
        batch_size = rgb_cond.shape[0]

        input_image_tokens: torch.Tensor = self.image_tokenizer(
            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
        )

        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
        )

        tokens: torch.Tensor = self.tokenizer(batch_size)

        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
        )

        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
        return scene_codes

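    # Render each scene code from n_views cameras orbiting the object at a
    # fixed elevation and distance; return_type selects torch ("pt"),
    # NumPy ("np"), or PIL ("pil") outputs.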
    def render(
        self,
        scene_codes,
        n_views: int,
        elevation_deg: float = 0.0,
        camera_distance: float = 1.9,
        fovy_deg: float = 40.0,
        height: int = 256,
        width: int = 256,
        return_type: str = "pil",
    ):
        rays_o, rays_d = get_spherical_cameras(
            n_views, elevation_deg, camera_distance, fovy_deg, height, width
        )
        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)

        def process_output(image: torch.FloatTensor):
            if return_type == "pt":
                return image
            elif return_type == "np":
                return image.detach().cpu().numpy()
            elif return_type == "pil":
                return Image.fromarray(
                    (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
                )
            else:
                raise NotImplementedError

        images = []
        for scene_code in scene_codes:
            images_ = []
            for i in range(n_views):
                with torch.no_grad():
                    image = self.renderer(
                        self.decoder, scene_code, rays_o[i], rays_d[i]
                    )
                images_.append(process_output(image))
            images.append(images_)

        return images

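    # Rebuild the marching-cubes helper only when the requested resolution
    # actually changes.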
    def set_marching_cubes_resolution(self, resolution: int):
        if (
            self.isosurface_helper is not None
            and self.isosurface_helper.resolution == resolution
        ):
            return
        self.isosurface_helper = MarchingCubeHelper(resolution)

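    # Extract one triangle mesh per scene code: query the triplane for density
    # on the marching-cubes grid (rescaled into the renderer's world radius),
    # run marching cubes at the given density threshold, rescale the vertices
    # back to world space, and optionally query per-vertex colors.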
    def extract_mesh(
        self,
        scene_codes,
        has_vertex_color,
        resolution: int = 256,
        threshold: float = 25.0,
    ):
        self.set_marching_cubes_resolution(resolution)
        meshes = []
        for scene_code in scene_codes:
            with torch.no_grad():
                density = self.renderer.query_triplane(
                    self.decoder,
                    scale_tensor(
                        self.isosurface_helper.grid_vertices.to(scene_codes.device),
                        self.isosurface_helper.points_range,
                        (-self.renderer.cfg.radius, self.renderer.cfg.radius),
                    ),
                    scene_code,
                )["density_act"]
            v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
            v_pos = scale_tensor(
                v_pos,
                self.isosurface_helper.points_range,
                (-self.renderer.cfg.radius, self.renderer.cfg.radius),
            )
            color = None
            if has_vertex_color:
                with torch.no_grad():
                    color = self.renderer.query_triplane(
                        self.decoder,
                        v_pos,
                        scene_code,
                    )["color"]
            mesh = trimesh.Trimesh(
                vertices=v_pos.cpu().numpy(),
                faces=t_pos_idx.cpu().numpy(),
                vertex_colors=color.cpu().numpy() if has_vertex_color else None,
            )
            meshes.append(mesh)
        return meshes
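
For reference, a minimal end-to-end sketch of how the class added here is typically driven. The repo id, file names, and input path below are assumptions based on how the TripoSR release is usually packaged, not part of this diff:

# Minimal usage sketch (assumed repo id, file names, and input path).
import torch
from PIL import Image

from tsr.system import TSR

device = "cuda" if torch.cuda.is_available() else "cpu"

model = TSR.from_pretrained(
    "stabilityai/TripoSR",      # assumed Hub repo id
    config_name="config.yaml",  # assumed config file name
    weight_name="model.ckpt",   # assumed checkpoint file name
).to(device)

image = Image.open("input.png").convert("RGB")  # hypothetical input image

with torch.no_grad():
    scene_codes = model(image, device=device)  # calls TSR.forward

renders = model.render(scene_codes, n_views=4)  # per scene: a list of PIL images
meshes = model.extract_mesh(scene_codes, has_vertex_color=True, resolution=256)
meshes[0].export("mesh.obj")  # trimesh handles the export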