ReubenSun committed
Commit 55f226f · 1 Parent(s): bc373eb

texture sync

step1x3d_texture/pipelines/ig2mv_sdxl_pipeline.py CHANGED
@@ -51,6 +51,20 @@ from ..models.attention_processor import (
     DecoupledMVRowSelfAttnProcessor2_0,
     set_unet_2d_condition_attn_processor,
 )
+import random
+from ..texture_sync.project import UVProjection as UVP
+from ..texture_sync.step_sync import step_tex_sync
+from trimesh import Trimesh
+from torchvision.transforms import Compose, Resize, GaussianBlur, InterpolationMode
+from diffusers.utils import (
+    BaseOutput,
+    numpy_to_pil,
+    pt_to_pil,
+    is_accelerate_available,
+    is_accelerate_version,
+    logging,
+    replace_example_docstring
+)
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -70,6 +84,27 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
+@torch.no_grad()
+def composite_rendered_view(scheduler, backgrounds, foregrounds, masks, t):
+    composited_images = []
+    for i, (background, foreground, mask) in enumerate(zip(backgrounds, foregrounds, masks)):
+        if t > 0:
+            alphas_cumprod = scheduler.alphas_cumprod[t]
+            noise = torch.normal(0, 1, background.shape, device=background.device)
+            background = (1 - alphas_cumprod) * noise + alphas_cumprod * background
+        composited = foreground * mask + background * (1 - mask)
+        composited_images.append(composited)
+    composited_tensor = torch.stack(composited_images)
+    return composited_tensor
+
+
+@torch.no_grad()
+def encode_latents(vae, imgs):
+    imgs = (imgs - 0.5) * 2
+    latents = vae.encode(imgs).latent_dist.sample()
+    latents = vae.config.scaling_factor * latents
+    return latents
+
 class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
     def __init__(
         self,
@@ -309,6 +344,8 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
         # Image condition
         reference_image: Optional[PipelineImageInput] = None,
         reference_conditioning_scale: Optional[float] = 1.0,
+        mesh: Optional[Trimesh] = None,
+        texture_sync_config: Optional[dict] = None,
         **kwargs,
     ):
         r"""
@@ -556,6 +593,27 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
             latents,
         )
 
+        # texture params init
+        texture_size = texture_sync_config["texture_size"]
+        latent_size = texture_sync_config["latent_size"]
+        elevations = texture_sync_config["elevations"]
+        azimuths = texture_sync_config["azimuths"]
+        texture_sync_ratio = texture_sync_config["texture_sync_ratio"]
+        camera_poses = [(elv, azim) for elv, azim in zip(elevations, azimuths)]
+        uvp = UVP(texture_size=texture_size, render_size=latent_size, sampling_mode="nearest", channels=4, device=self._execution_device)
+        uvp.load_mesh(mesh, scale_factor=1.0, autouv=True)
+        uvp.set_cameras_and_render_settings(camera_poses, centers=None, camera_distance=texture_sync_config["camera_distance"], scale=((1.0, 1.0, 1.0),))
+
+        latent_tex = uvp.set_noise_texture()
+        noise_views = uvp.render_textured_views()
+        foregrounds = [view[:-1] for view in noise_views]
+        masks = [view[-1:] for view in noise_views]
+
+        if texture_sync_ratio > 0:
+            composited_tensor = composite_rendered_view(self.scheduler, latents, foregrounds, masks, int(timesteps[0].cpu().item()) + 1)
+            latents = composited_tensor.type(latents.dtype)
+        uvp.to("cpu")
+
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
@@ -709,6 +767,36 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
         ).to(device=device, dtype=latents.dtype)
 
         self._num_timesteps = len(timesteps)
+
+
+        # texture sync params
+        exp_start = texture_sync_config["exp_start"]
+        exp_end = texture_sync_config["exp_end"]
+        shuffle_background_change = texture_sync_config["shuffle_background_change"]
+        shuffle_background_end = texture_sync_config["shuffle_background_end"]
+        num_timesteps = self.scheduler.config.num_train_timesteps
+
+        uvp.to(self._execution_device)
+        color_constants = {"black": [-1, -1, -1], "white": [1, 1, 1], "maroon": [0, -1, -1],
+                           "red": [1, -1, -1], "olive": [0, 0, -1], "yellow": [1, 1, -1],
+                           "green": [-1, 0, -1], "lime": [-1, 1, -1], "teal": [-1, 0, 0],
+                           "aqua": [-1, 1, 1], "navy": [-1, -1, 0], "blue": [-1, -1, 1],
+                           "purple": [0, -1, 0], "fuchsia": [1, -1, 1]}
+        color_names = list(color_constants.keys())
+        background_colors = [random.choice(list(color_constants.keys())) for i in range(len(camera_poses))]
+        intermediate_results = []
+        self.upcast_vae()
+        self.vae.config.force_upcast = True
+        color_images = torch.FloatTensor([color_constants[name] for name in color_names]).reshape(-1, 3, 1, 1).to(dtype=torch.float32, device=self._execution_device)
+        color_images = torch.ones(
+            (1, 1, latent_size * 8, latent_size * 8),
+            device=self._execution_device,
+            dtype=torch.float32
+        ) * color_images
+        color_images = ((0.5 * color_images) + 0.5)
+        color_latents = encode_latents(self.vae, color_images).to(dtype=self.text_encoder_2.dtype)
+        color_latents = {color[0]: color[1] for color in zip(color_names, [latent for latent in color_latents])}
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -768,9 +856,49 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
-                )[0]
+
+                # texture sync
+                current_exp = ((exp_end - exp_start) * i / num_inference_steps) + exp_start
+                if t > (1 - texture_sync_ratio) * num_timesteps:
+                    step_results = step_tex_sync(
+                        scheduler=self.scheduler,
+                        uvp=uvp,
+                        model_output=noise_pred,
+                        timestep=t,
+                        sample=latents,
+                        texture=latent_tex,
+                        return_dict=True,
+                        main_views=[],
+                        exp=current_exp,
+                        **extra_step_kwargs
+                    )
+
+                    pred_original_sample = step_results["pred_original_sample"]
+                    latents = step_results["prev_sample"]
+                    latent_tex = step_results["prev_tex"]
+
+                    # Composite latent foreground with random color background
+                    background_latents = [color_latents[color] for color in background_colors]
+                    composited_tensor = composite_rendered_view(self.scheduler, background_latents, latents, masks, t)
+                    latents = composited_tensor.type(latents.dtype)
+
+                    intermediate_results.append((latents.to("cpu"), pred_original_sample.to("cpu")))
+                else:
+                    step_results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+                    pred_original_sample = step_results["pred_original_sample"]
+                    latents = step_results["prev_sample"]
+                    latent_tex = None
+                    intermediate_results.append((latents.to("cpu"), pred_original_sample.to("cpu")))
+
+                # Shuffle background colors; only black and white are used after a certain timestep
+                if (1 - t / num_timesteps) < shuffle_background_change:
+                    background_colors = [random.choice(list(color_constants.keys())) for i in range(len(camera_poses))]
+                elif (1 - t / num_timesteps) < shuffle_background_end:
+                    background_colors = [random.choice(["black", "white"]) for i in range(len(camera_poses))]
+                else:
+                    background_colors = background_colors
+                del noise_pred
+
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
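
Note on the hunks above: a denoising step is routed through step_tex_sync only while t > (1 - texture_sync_ratio) * num_train_timesteps, i.e. during the early, high-noise part of the schedule; there the per-view latents are aggregated into the shared UV texture latent_tex and each view is re-composited over a randomly colored background latent. Later steps fall back to the plain scheduler.step call and latent_tex is dropped. A minimal sketch of that cutoff, assuming the scheduler's default of 1000 training timesteps and a hypothetical 20-step inference schedule (illustrative snippet, not part of the commit):

    num_train_timesteps = 1000             # assumed DDPM default
    texture_sync_ratio = 0.5               # default set in texture_sync_config (next file)
    timesteps = list(range(999, -1, -50))  # hypothetical 20-step schedule, high noise first
    synced = [t for t in timesteps if t > (1 - texture_sync_ratio) * num_train_timesteps]
    print(f"{len(synced)} of {len(timesteps)} steps use step_tex_sync")  # 10 of 20
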
step1x3d_texture/pipelines/step1x_3d_texture_synthesis_pipeline.py CHANGED
@@ -24,7 +24,6 @@ import trimesh
 import xatlas
 import scipy.sparse
 from scipy.sparse.linalg import spsolve
-
 from step1x3d_geometry.models.pipelines.pipeline_utils import smart_load_model
 
 
@@ -36,7 +35,7 @@ class Step1X3DTextureConfig:
         self.unet_model = None
         self.lora_model = None
         self.adapter_path = "stepfun-ai/Step1X-3D"
-        self.scheduler = None
+        self.scheduler = "ddpm"
         self.num_views = 6
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.dtype = torch.float16
@@ -61,6 +60,20 @@ class Step1X3DTextureConfig:
         self.bake_exp = 4
         self.merge_method = "fast"
 
+        # texture sync params
+        self.texture_sync_config = {
+            "texture_size": 1536,
+            "latent_size": 768 // 8,
+            "elevations": [0, 0, 0, 0, 90, -90],
+            "azimuths": [0, 90, 180, 270, 0, 0],
+            "texture_sync_ratio": 0.5,
+            "exp_end": 6.0,
+            "exp_start": 0,
+            "shuffle_background_change": 0.4,
+            "shuffle_background_end": 0.99,
+            "camera_distance": 1.8
+        }
+
 
 class Step1X3DTexturePipeline:
     def __init__(self, config):
@@ -120,11 +133,9 @@ class Step1X3DTexturePipeline:
         if unet_model is not None:
             pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
 
-        print('VAE Loaded!')
         # Prepare pipeline
         pipe = IG2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
 
-        print('Base model Loaded!')
         # Load scheduler if provided
         scheduler_class = None
         if scheduler == "ddpm":
@@ -138,14 +149,11 @@ class Step1X3DTexturePipeline:
             shift_scale=8.0,
             scheduler_class=scheduler_class,
         )
-        print('Scheduler Loaded!')
         pipe.init_custom_adapter(
             num_views=num_views,
             self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0,
         )
-        print(f'Load adapter from {adapter_path}/step1x-3d-ig2v.safetensors')
         pipe.load_custom_adapter(adapter_path, "step1x-3d-ig2v.safetensors")
-        print(f'Load adapter successed!')
         pipe.to(device=device, dtype=dtype)
         pipe.cond_encoder.to(device=device, dtype=dtype)
 
@@ -282,6 +290,7 @@ class Step1X3DTexturePipeline:
             negative_prompt=negative_prompt,
             cross_attention_kwargs={"scale": lora_scale},
             mesh=mesh_bp,
+            texture_sync_config=self.config.texture_sync_config,
             **pipe_kwargs,
         ).images
 
@@ -359,7 +368,7 @@ class Step1X3DTexturePipeline:
             width=768,
             num_inference_steps=self.config.num_inference_steps,
             guidance_scale=self.config.guidance_scale,
-            seed=seed if seed is not None else self.config.seed,
+            seed= seed if seed is not None else self.config.seed,
            lora_scale=self.config.lora_scale,
            reference_conditioning_scale=self.config.reference_conditioning_scale,
            negative_prompt=self.config.negative_prompt,
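
Note on the hunks above: texture synchronization is driven entirely by the new texture_sync_config dict on Step1X3DTextureConfig, which is forwarded to the multi-view pipeline together with mesh. A minimal sketch of overriding it before building the pipeline, assuming construction through the config object as in __init__(self, config) above; the overridden values are illustrative, only the field names come from this diff:

    config = Step1X3DTextureConfig()
    config.texture_sync_config["texture_sync_ratio"] = 0.6   # synchronize a larger share of the denoising steps
    config.texture_sync_config["camera_distance"] = 2.0      # hypothetical camera override
    texture_pipe = Step1X3DTexturePipeline(config)
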
step1x3d_texture/renderer/geometry.py DELETED
@@ -1,151 +0,0 @@
1
- import torch
2
- import pytorch3d
3
- import torch.nn.functional as F
4
-
5
- from pytorch3d.ops import interpolate_face_attributes
6
-
7
- from pytorch3d.renderer import (
8
- look_at_view_transform,
9
- FoVPerspectiveCameras,
10
- AmbientLights,
11
- PointLights,
12
- DirectionalLights,
13
- Materials,
14
- RasterizationSettings,
15
- MeshRenderer,
16
- MeshRasterizer,
17
- SoftPhongShader,
18
- SoftSilhouetteShader,
19
- HardPhongShader,
20
- TexturesVertex,
21
- TexturesUV,
22
- Materials,
23
- )
24
- from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
25
- from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
26
- from pytorch3d.renderer.mesh.shader import ShaderBase
27
-
28
-
29
- def get_cos_angle(points, normals, camera_position):
30
- """
31
- calculate cosine similarity between view->surface and surface normal.
32
- """
33
-
34
- if points.shape != normals.shape:
35
- msg = "Expected points and normals to have the same shape: got %r, %r"
36
- raise ValueError(msg % (points.shape, normals.shape))
37
-
38
- # Ensure all inputs have same batch dimension as points
39
- matched_tensors = convert_to_tensors_and_broadcast(
40
- points, camera_position, device=points.device
41
- )
42
- _, camera_position = matched_tensors
43
-
44
- # Reshape direction and color so they have all the arbitrary intermediate
45
- # dimensions as points. Assume first dim = batch dim and last dim = 3.
46
- points_dims = points.shape[1:-1]
47
- expand_dims = (-1,) + (1,) * len(points_dims)
48
-
49
- if camera_position.shape != normals.shape:
50
- camera_position = camera_position.view(expand_dims + (3,))
51
-
52
- normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
53
-
54
- # Calculate the cosine value.
55
- view_direction = camera_position - points
56
- view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
57
- cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True)
58
- cos_angle = cos_angle.clamp(0, 1)
59
-
60
- # Cosine of the angle between the reflected light ray and the viewer
61
- return cos_angle
62
-
63
-
64
- def _geometry_shading_with_pixels(
65
- meshes, fragments, lights, cameras, materials, texels
66
- ):
67
- """
68
- Render pixel space vertex position, normal(world), depth, and cos angle
69
-
70
- Args:
71
- meshes: Batch of meshes
72
- fragments: Fragments named tuple with the outputs of rasterization
73
- lights: Lights class containing a batch of lights
74
- cameras: Cameras class containing a batch of cameras
75
- materials: Materials class containing a batch of material properties
76
- texels: texture per pixel of shape (N, H, W, K, 3)
77
-
78
- Returns:
79
- colors: (N, H, W, K, 3)
80
- pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection.
81
- """
82
- verts = meshes.verts_packed() # (V, 3)
83
- faces = meshes.faces_packed() # (F, 3)
84
- vertex_normals = meshes.verts_normals_packed() # (V, 3)
85
- faces_verts = verts[faces]
86
- faces_normals = vertex_normals[faces]
87
- pixel_coords_in_camera = interpolate_face_attributes(
88
- fragments.pix_to_face, fragments.bary_coords, faces_verts
89
- )
90
- pixel_normals = interpolate_face_attributes(
91
- fragments.pix_to_face, fragments.bary_coords, faces_normals
92
- )
93
-
94
- cos_angles = get_cos_angle(
95
- pixel_coords_in_camera, pixel_normals, cameras.get_camera_center()
96
- )
97
-
98
- return pixel_coords_in_camera, pixel_normals, fragments.zbuf[..., None], cos_angles
99
-
100
-
101
- class HardGeometryShader(ShaderBase):
102
- """
103
- renders common geometric informations.
104
-
105
-
106
- """
107
-
108
- def forward(self, fragments, meshes, **kwargs):
109
- cameras = super()._get_cameras(**kwargs)
110
- texels = self.texel_from_uv(fragments, meshes)
111
-
112
- lights = kwargs.get("lights", self.lights)
113
- materials = kwargs.get("materials", self.materials)
114
- blend_params = kwargs.get("blend_params", self.blend_params)
115
- verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
116
- meshes=meshes,
117
- fragments=fragments,
118
- texels=texels,
119
- lights=lights,
120
- cameras=cameras,
121
- materials=materials,
122
- )
123
- texels = meshes.sample_textures(fragments)
124
- verts = hard_rgb_blend(verts, fragments, blend_params)
125
- normals = hard_rgb_blend(normals, fragments, blend_params)
126
- depths = hard_rgb_blend(depths, fragments, blend_params)
127
- cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
128
- from IPython import embed
129
-
130
- embed()
131
- texels = hard_rgb_blend(texels, fragments, blend_params)
132
- return verts, normals, depths, cos_angles, texels, fragments
133
-
134
- def texel_from_uv(self, fragments, meshes):
135
- texture_tmp = meshes.textures
136
- maps_tmp = texture_tmp.maps_padded()
137
- uv_color = [[[1, 0], [1, 1]], [[0, 0], [0, 1]]]
138
- uv_color = (
139
- torch.FloatTensor(uv_color).to(maps_tmp[0].device).type(maps_tmp[0].dtype)
140
- )
141
- uv_texture = TexturesUV(
142
- [uv_color.clone() for t in maps_tmp],
143
- texture_tmp.faces_uvs_padded(),
144
- texture_tmp.verts_uvs_padded(),
145
- sampling_mode="bilinear",
146
- )
147
- meshes.textures = uv_texture
148
- texels = meshes.sample_textures(fragments)
149
- meshes.textures = texture_tmp
150
- texels = torch.cat((texels, texels[..., -1:] * 0), dim=-1)
151
- return texels

step1x3d_texture/renderer/project.py DELETED
@@ -1,875 +0,0 @@
1
- import torch
2
- import pytorch3d
3
-
4
-
5
- from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj, IO
6
-
7
- from pytorch3d.structures import Meshes
8
- from pytorch3d.renderer import (
9
- look_at_view_transform,
10
- FoVPerspectiveCameras,
11
- FoVOrthographicCameras,
12
- AmbientLights,
13
- PointLights,
14
- DirectionalLights,
15
- Materials,
16
- RasterizationSettings,
17
- MeshRenderer,
18
- MeshRasterizer,
19
- TexturesUV,
20
- )
21
-
22
- from .geometry import HardGeometryShader
23
- from .shader import HardNChannelFlatShader
24
- from .voronoi import voronoi_solve
25
- import torch.nn.functional as F
26
- import open3d as o3d
27
- import pdb
28
- import kaolin as kal
29
- import numpy as np
30
-
31
-
32
- import torch
33
- from pytorch3d.renderer.cameras import FoVOrthographicCameras
34
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
35
- from pytorch3d.common.datatypes import Device
36
- import math
37
- import torch.nn.functional as F
38
- from trimesh import Trimesh
39
- from pytorch3d.structures import Meshes
40
- import os
41
-
42
- LIST_TYPE = Union[list, np.ndarray, torch.Tensor]
43
-
44
- _R = torch.eye(3)[None] # (1, 3, 3)
45
- _T = torch.zeros(1, 3) # (1, 3)
46
- _BatchFloatType = Union[float, Sequence[float], torch.Tensor]
47
-
48
-
49
- class CustomOrthographicCameras(FoVOrthographicCameras):
50
- def compute_projection_matrix(
51
- self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz
52
- ) -> torch.Tensor:
53
- """
54
- 自定义正交投影矩阵计算,继承并修改深度通道参数
55
- 参数维度说明:
56
- - znear/zfar: (N,)
57
- - max_x/min_x: (N,)
58
- - max_y/min_y: (N,)
59
- - scale_xyz: (N, 3)
60
- """
61
- K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
62
-
63
- ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
64
- # NOTE: OpenGL flips handedness of coordinate system between camera
65
- # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
66
- # right handed coordinate system throughout.
67
- z_sign = +1.0
68
-
69
- K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
70
- K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
71
- K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
72
- K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
73
- K[:, 3, 3] = ones
74
-
75
- # NOTE: This maps the z coordinate to the range [0, 1] and replaces the
76
- # the OpenGL z normalization to [-1, 1]
77
- K[:, 2, 2] = -2 * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
78
- K[:, 2, 3] = -(znear + zfar) / (zfar - znear)
79
-
80
- return K
81
-
82
- def __init__(
83
- self,
84
- znear: _BatchFloatType = 1.0,
85
- zfar: _BatchFloatType = 100.0,
86
- max_y: _BatchFloatType = 1.0,
87
- min_y: _BatchFloatType = -1.0,
88
- max_x: _BatchFloatType = 1.0,
89
- min_x: _BatchFloatType = -1.0,
90
- scale_xyz=((1.0, 1.0, 1.0),), # (N, 3)
91
- R: torch.Tensor = _R,
92
- T: torch.Tensor = _T,
93
- K: Optional[torch.Tensor] = None,
94
- device: Device = "cpu",
95
- ):
96
- # 继承父类初始化逻辑
97
- super().__init__(
98
- znear=znear,
99
- zfar=zfar,
100
- max_y=max_y,
101
- min_y=min_y,
102
- max_x=max_x,
103
- min_x=min_x,
104
- scale_xyz=scale_xyz,
105
- R=R,
106
- T=T,
107
- K=K,
108
- device=device,
109
- )
110
-
111
-
112
- def erode_torch_batch(binary_img_batch, kernel_size):
113
- pad = (kernel_size - 1) // 2
114
- bin_img = F.pad(
115
- binary_img_batch.unsqueeze(1), pad=[pad, pad, pad, pad], mode="reflect"
116
- )
117
- out = -F.max_pool2d(-bin_img, kernel_size=kernel_size, stride=1, padding=0)
118
- out = out.squeeze(1)
119
- return out
120
-
121
-
122
- def dilate_torch_batch(binary_img_batch, kernel_size):
123
- pad = (kernel_size - 1) // 2
124
- bin_img = F.pad(binary_img_batch, pad=[pad, pad, pad, pad], mode="reflect")
125
- out = F.max_pool2d(bin_img, kernel_size=kernel_size, stride=1, padding=0)
126
- out = out.squeeze()
127
- return out
128
-
129
-
130
- # Pytorch3D based renderering functions, managed in a class
131
- # Render size is recommended to be the same as your latent view size
132
- # DO NOT USE "bilinear" sampling when you are handling latents.
133
- # Stable Diffusion has 4 latent channels so use channels=4
134
-
135
-
136
- class UVProjection:
137
- def __init__(
138
- self,
139
- texture_size=96,
140
- render_size=64,
141
- sampling_mode="nearest",
142
- channels=3,
143
- device=None,
144
- ):
145
- self.channels = channels
146
- self.device = device or torch.device("cpu")
147
- self.lights = AmbientLights(
148
- ambient_color=((1.0,) * channels,), device=self.device
149
- )
150
- self.target_size = (texture_size, texture_size)
151
- self.render_size = render_size
152
- self.sampling_mode = sampling_mode
153
-
154
- # Load obj mesh, rescale the mesh to fit into the bounding box
155
- def load_mesh(self, mesh, scale_factor=2.0, auto_center=True, autouv=False):
156
- if isinstance(mesh, Trimesh):
157
- vertices = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
158
- faces = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
159
- mesh = Meshes(verts=[vertices], faces=[faces])
160
- verts = mesh.verts_packed()
161
- mesh = mesh.update_padded(verts[None, :, :])
162
- elif isinstance(mesh, str) and os.path.isfile(mesh):
163
- mesh = load_objs_as_meshes([mesh_path], device=self.device)
164
- if auto_center:
165
- verts = mesh.verts_packed()
166
- max_bb = (verts - 0).max(0)[0]
167
- min_bb = (verts - 0).min(0)[0]
168
- scale = (max_bb - min_bb).max() / 2
169
- center = (max_bb + min_bb) / 2
170
- mesh.offset_verts_(-center)
171
- mesh.scale_verts_((scale_factor / float(scale)))
172
- else:
173
- mesh.scale_verts_((scale_factor))
174
-
175
- if autouv or (mesh.textures is None):
176
- mesh = self.uv_unwrap(mesh)
177
- self.mesh = mesh
178
-
179
- def load_glb_mesh(
180
- self, mesh_path, trimesh, scale_factor=1.0, auto_center=True, autouv=False
181
- ):
182
- from pytorch3d.io.experimental_gltf_io import MeshGlbFormat
183
-
184
- io = IO()
185
- io.register_meshes_format(MeshGlbFormat())
186
- with open(mesh_path, "rb") as f:
187
- mesh = io.load_mesh(f, include_textures=True, device=self.device)
188
- if auto_center:
189
- verts = mesh.verts_packed()
190
-
191
- max_bb = (verts - 0).max(0)[0]
192
- min_bb = (verts - 0).min(0)[0]
193
- scale = (max_bb - min_bb).max() / 2
194
- center = (max_bb + min_bb) / 2
195
- mesh.offset_verts_(-center)
196
- mesh.scale_verts_((scale_factor / float(scale)))
197
- verts = mesh.verts_packed()
198
- # T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=verts.device, dtype=verts.dtype)
199
- # T = torch.tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]], device=verts.device, dtype=verts.dtype)
200
- # verts = verts @ T
201
- mesh = mesh.update_padded(verts[None, :, :])
202
- else:
203
- mesh.scale_verts_((scale_factor))
204
- if autouv or (mesh.textures is None):
205
- mesh = self.uv_unwrap(mesh)
206
- self.mesh = mesh
207
-
208
- # Save obj mesh
209
- def save_mesh(self, mesh_path, texture):
210
- save_obj(
211
- mesh_path,
212
- self.mesh.verts_list()[0],
213
- self.mesh.faces_list()[0],
214
- verts_uvs=self.mesh.textures.verts_uvs_list()[0],
215
- faces_uvs=self.mesh.textures.faces_uvs_list()[0],
216
- texture_map=texture,
217
- )
218
-
219
- # Code referred to TEXTure code (https://github.com/TEXTurePaper/TEXTurePaper.git)
220
- def uv_unwrap(self, mesh):
221
- verts_list = mesh.verts_list()[0]
222
- faces_list = mesh.faces_list()[0]
223
-
224
- import xatlas
225
- import numpy as np
226
-
227
- v_np = verts_list.cpu().numpy()
228
- f_np = faces_list.int().cpu().numpy()
229
- atlas = xatlas.Atlas()
230
- atlas.add_mesh(v_np, f_np)
231
- chart_options = xatlas.ChartOptions()
232
- chart_options.max_iterations = 4
233
- atlas.generate(chart_options=chart_options)
234
- vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
235
-
236
- vt = (
237
- torch.from_numpy(vt_np.astype(np.float32))
238
- .type(verts_list.dtype)
239
- .to(mesh.device)
240
- )
241
- ft = (
242
- torch.from_numpy(ft_np.astype(np.int64))
243
- .type(faces_list.dtype)
244
- .to(mesh.device)
245
- )
246
-
247
- new_map = torch.zeros(self.target_size + (self.channels,), device=mesh.device)
248
- new_tex = TexturesUV([new_map], [ft], [vt], sampling_mode=self.sampling_mode)
249
-
250
- mesh.textures = new_tex
251
- return mesh
252
-
253
- """
254
- A functions that disconnect faces in the mesh according to
255
- its UV seams. The number of vertices are made equal to the
256
- number of unique vertices its UV layout, while the faces list
257
- is intact.
258
- """
259
-
260
- def disconnect_faces(self):
261
- mesh = self.mesh
262
- verts_list = mesh.verts_list()
263
- faces_list = mesh.faces_list()
264
- verts_uvs_list = mesh.textures.verts_uvs_list()
265
- faces_uvs_list = mesh.textures.faces_uvs_list()
266
- packed_list = [v[f] for v, f in zip(verts_list, faces_list)]
267
- verts_disconnect_list = [
268
- torch.zeros(
269
- (verts_uvs_list[i].shape[0], 3),
270
- dtype=verts_list[0].dtype,
271
- device=verts_list[0].device,
272
- )
273
- for i in range(len(verts_list))
274
- ]
275
- for i in range(len(verts_list)):
276
- verts_disconnect_list[i][faces_uvs_list] = packed_list[i]
277
- assert not mesh.has_verts_normals(), "Not implemented for vertex normals"
278
- self.mesh_d = Meshes(verts_disconnect_list, faces_uvs_list, mesh.textures)
279
- return self.mesh_d
280
-
281
- """
282
- A function that construct a temp mesh for back-projection.
283
- Take a disconnected mesh and a rasterizer, the function calculates
284
- the projected faces as the UV, as use its original UV with pseudo
285
- z value as world space geometry.
286
- """
287
-
288
- def construct_uv_mesh(self):
289
- mesh = self.mesh_d
290
- verts_list = mesh.verts_list()
291
- verts_uvs_list = mesh.textures.verts_uvs_list()
292
- # faces_list = [torch.flip(faces, [-1]) for faces in mesh.faces_list()]
293
- new_verts_list = []
294
- for i, (verts, verts_uv) in enumerate(zip(verts_list, verts_uvs_list)):
295
- verts = verts.clone()
296
- verts_uv = verts_uv.clone()
297
- verts[..., 0:2] = verts_uv[..., :]
298
- verts = (verts - 0.5) * 2
299
- verts[..., 2] *= 1
300
- new_verts_list.append(verts)
301
- textures_uv = mesh.textures.clone()
302
- self.mesh_uv = Meshes(new_verts_list, mesh.faces_list(), textures_uv)
303
- return self.mesh_uv
304
-
305
- # Set texture for the current mesh.
306
- def set_texture_map(self, texture):
307
- new_map = texture.permute(1, 2, 0)
308
- new_map = new_map.to(self.device)
309
- new_tex = TexturesUV(
310
- [new_map],
311
- self.mesh.textures.faces_uvs_padded(),
312
- self.mesh.textures.verts_uvs_padded(),
313
- sampling_mode=self.sampling_mode,
314
- )
315
- self.mesh.textures = new_tex
316
-
317
- # Set the initial normal noise texture
318
- # No generator here for replication of the experiment result. Add one as you wish
319
- def set_noise_texture(self, channels=None):
320
- if not channels:
321
- channels = self.channels
322
- noise_texture = torch.normal(
323
- 0, 1, (channels,) + self.target_size, device=self.device
324
- )
325
- self.set_texture_map(noise_texture)
326
- return noise_texture
327
-
328
- # Set the cameras given the camera poses and centers
329
- def set_cameras(self, camera_poses, centers=None, camera_distance=2.7, scale=None):
330
- elev = torch.FloatTensor([pose[0] for pose in camera_poses])
331
- azim = torch.FloatTensor([pose[1] for pose in camera_poses])
332
- print("camera_distance:{}".format(camera_distance))
333
- R, T = look_at_view_transform(
334
- dist=camera_distance, elev=elev, azim=azim, at=centers or ((0, 0, 0),)
335
- )
336
- # flip_mat = torch.from_numpy(np.diag([-1.0, 1.0, -1.0]) ).type(torch.FloatTensor).to(R.device)
337
- # R = R@flip_mat
338
- # R = R.permute(0, 2, 1)
339
- # T = T*torch.from_numpy(np.array([-1.0, 1.0, -1.0])).type(torch.FloatTensor).to(R.device)
340
- # print("v R size:{}, v T size:{}".format(R.size(), T.size()))
341
- # c2w = self.get_c2w(elev, [camera_distance]*len(elev), azim)
342
- # w2c = torch.linalg.inv(c2w)
343
- # R, T= w2c[:, :3, :3], w2c[:, :3, 3]
344
- print("R size:{}, T size:{}".format(R.size(), T.size()))
345
- # self.cameras = CustomOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),), znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)
346
- self.cameras = FoVOrthographicCameras(
347
- device=self.device, R=R, T=T, scale_xyz=scale or ((1, 1, 1),)
348
- )
349
-
350
- # Set all necessary internal data for rendering and texture baking
351
- # Can be used to refresh after changing camera positions
352
- def set_cameras_and_render_settings(
353
- self,
354
- camera_poses,
355
- centers=None,
356
- camera_distance=2.7,
357
- render_size=None,
358
- scale=None,
359
- ):
360
- self.set_cameras(camera_poses, centers, camera_distance, scale=scale)
361
- if render_size is None:
362
- render_size = self.render_size
363
- if not hasattr(self, "renderer"):
364
- self.setup_renderer(size=render_size)
365
- if not hasattr(self, "mesh_d"):
366
- self.disconnect_faces()
367
- if not hasattr(self, "mesh_uv"):
368
- self.construct_uv_mesh()
369
- self.calculate_tex_gradient()
370
- self.calculate_visible_triangle_mask()
371
- _, _, _, cos_maps, _, _ = self.render_geometry()
372
- self.calculate_cos_angle_weights(cos_maps)
373
-
374
- # Setup renderers for rendering
375
- # max faces per bin set to 30000 to avoid overflow in many test cases.
376
- # You can use default value to let pytorch3d handle that for you.
377
- def setup_renderer(
378
- self,
379
- size=64,
380
- blur=0.0,
381
- face_per_pix=1,
382
- perspective_correct=False,
383
- channels=None,
384
- ):
385
- if not channels:
386
- channels = self.channels
387
-
388
- self.raster_settings = RasterizationSettings(
389
- image_size=size,
390
- blur_radius=blur,
391
- faces_per_pixel=face_per_pix,
392
- perspective_correct=perspective_correct,
393
- cull_backfaces=True,
394
- max_faces_per_bin=30000,
395
- )
396
-
397
- self.renderer = MeshRenderer(
398
- rasterizer=MeshRasterizer(
399
- cameras=self.cameras,
400
- raster_settings=self.raster_settings,
401
- ),
402
- shader=HardNChannelFlatShader(
403
- device=self.device,
404
- cameras=self.cameras,
405
- lights=self.lights,
406
- channels=channels,
407
- # materials=materials
408
- ),
409
- )
410
-
411
- # Bake screen-space cosine weights to UV space
412
- # May be able to reimplement using the generic "bake_texture" function, but it works so leave it here for now
413
- @torch.enable_grad()
414
- def calculate_cos_angle_weights(self, cos_angles, fill=True, channels=None):
415
- if not channels:
416
- channels = self.channels
417
- cos_maps = []
418
- tmp_mesh = self.mesh.clone()
419
- for i in range(len(self.cameras)):
420
-
421
- zero_map = torch.zeros(
422
- self.target_size + (channels,), device=self.device, requires_grad=True
423
- )
424
- optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
425
- optimizer.zero_grad()
426
- zero_tex = TexturesUV(
427
- [zero_map],
428
- self.mesh.textures.faces_uvs_padded(),
429
- self.mesh.textures.verts_uvs_padded(),
430
- sampling_mode=self.sampling_mode,
431
- )
432
- tmp_mesh.textures = zero_tex
433
-
434
- images_predicted = self.renderer(
435
- tmp_mesh, cameras=self.cameras[i], lights=self.lights
436
- )
437
-
438
- loss = torch.sum((cos_angles[i, :, :, 0:1] ** 1 - images_predicted) ** 2)
439
- loss.backward()
440
- optimizer.step()
441
-
442
- if fill:
443
- zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8)
444
- zero_map = voronoi_solve(
445
- zero_map, self.gradient_maps[i][..., 0], self.device
446
- )
447
- else:
448
- zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8)
449
- cos_maps.append(zero_map)
450
- self.cos_maps = cos_maps
451
-
452
- # Get geometric info from fragment shader
453
- # Can be used for generating conditioning image and cosine weights
454
- # Returns some information you may not need, remember to release them for memory saving
455
- @torch.no_grad()
456
- def render_geometry(self, image_size=None):
457
- if image_size:
458
- size = self.renderer.rasterizer.raster_settings.image_size
459
- self.renderer.rasterizer.raster_settings.image_size = image_size
460
- shader = self.renderer.shader
461
- self.renderer.shader = HardGeometryShader(
462
- device=self.device, cameras=self.cameras[0], lights=self.lights
463
- )
464
- tmp_mesh = self.mesh.clone()
465
-
466
- verts, normals, depths, cos_angles, texels, fragments = self.renderer(
467
- tmp_mesh.extend(len(self.cameras)), cameras=self.cameras, lights=self.lights
468
- )
469
- self.renderer.shader = shader
470
-
471
- if image_size:
472
- self.renderer.rasterizer.raster_settings.image_size = size
473
-
474
- return verts, normals, depths, cos_angles, texels, fragments
475
-
476
- # Project world normal to view space and normalize
477
- @torch.no_grad()
478
- def decode_view_normal(self, normals):
479
- w2v_mat = self.cameras.get_full_projection_transform()
480
- normals_view = torch.clone(normals)[:, :, :, 0:3]
481
- normals_view = normals_view.reshape(normals_view.shape[0], -1, 3)
482
- normals_view = w2v_mat.transform_normals(normals_view)
483
- normals_view = normals_view.reshape(normals.shape[0:3] + (3,))
484
- normals_view[:, :, :, 2] *= -1
485
- normals = (normals_view[..., 0:3] + 1) * normals[
486
- ..., 3:
487
- ] / 2 + torch.FloatTensor(((((0.5, 0.5, 1))))).to(self.device) * (
488
- 1 - normals[..., 3:]
489
- )
490
- # normals = torch.cat([normal for normal in normals], dim=1)
491
- normals = normals.clamp(0, 1)
492
- return normals
493
-
494
- # Normalize absolute depth to inverse depth
495
- @torch.no_grad()
496
- def decode_normalized_depth(self, depths, batched_norm=False):
497
- view_z, mask = depths.unbind(-1)
498
- view_z = view_z * mask + 100 * (1 - mask)
499
- inv_z = 1 / view_z
500
- inv_z_min = inv_z * mask + 100 * (1 - mask)
501
- if not batched_norm:
502
- max_ = torch.max(inv_z, 1, keepdim=True)
503
- max_ = torch.max(max_[0], 2, keepdim=True)[0]
504
-
505
- min_ = torch.min(inv_z_min, 1, keepdim=True)
506
- min_ = torch.min(min_[0], 2, keepdim=True)[0]
507
- else:
508
- max_ = torch.max(inv_z)
509
- min_ = torch.min(inv_z_min)
510
- inv_z = (inv_z - min_) / (max_ - min_)
511
- inv_z = inv_z.clamp(0, 1)
512
- inv_z = inv_z[..., None].repeat(1, 1, 1, 3)
513
-
514
- return inv_z
515
-
516
- # Multiple screen pixels could pass gradient to a same texel
517
- # We can precalculate this gradient strength and use it to normalize gradients when we bake textures
518
- @torch.enable_grad()
519
- def calculate_tex_gradient(self, channels=None):
520
- if not channels:
521
- channels = self.channels
522
- tmp_mesh = self.mesh.clone()
523
- gradient_maps = []
524
- for i in range(len(self.cameras)):
525
- zero_map = torch.zeros(
526
- self.target_size + (channels,), device=self.device, requires_grad=True
527
- )
528
- optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
529
- optimizer.zero_grad()
530
- zero_tex = TexturesUV(
531
- [zero_map],
532
- self.mesh.textures.faces_uvs_padded(),
533
- self.mesh.textures.verts_uvs_padded(),
534
- sampling_mode=self.sampling_mode,
535
- )
536
- tmp_mesh.textures = zero_tex
537
- images_predicted = self.renderer(
538
- tmp_mesh, cameras=self.cameras[i], lights=self.lights
539
- )
540
- loss = torch.sum((1 - images_predicted) ** 2)
541
- loss.backward()
542
- optimizer.step()
543
-
544
- gradient_maps.append(zero_map.detach())
545
-
546
- self.gradient_maps = gradient_maps
547
-
548
- # Get the UV space masks of triangles visible in each view
549
- # First get face ids from each view, then filter pixels on UV space to generate masks
550
-
551
- @torch.no_grad()
552
- def get_c2w(
553
- self,
554
- elevation_deg: LIST_TYPE,
555
- distance: LIST_TYPE,
556
- azimuth_deg: Optional[LIST_TYPE],
557
- num_views: Optional[int] = 1,
558
- device: Optional[str] = None,
559
- ) -> torch.FloatTensor:
560
- if azimuth_deg is None:
561
- assert (
562
- num_views is not None
563
- ), "num_views must be provided if azimuth_deg is None."
564
- azimuth_deg = torch.linspace(
565
- 0, 360, num_views + 1, dtype=torch.float32, device=device
566
- )[:-1]
567
- else:
568
- num_views = len(azimuth_deg)
569
-
570
- def list_to_pt(
571
- x: LIST_TYPE,
572
- dtype: Optional[torch.dtype] = None,
573
- device: Optional[str] = None,
574
- ) -> torch.Tensor:
575
- if isinstance(x, list) or isinstance(x, np.ndarray):
576
- return torch.tensor(x, dtype=dtype, device=device)
577
- return x.to(dtype=dtype)
578
-
579
- azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
580
- elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
581
- camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)
582
- elevation = elevation_deg * math.pi / 180
583
- azimuth = azimuth_deg * math.pi / 180
584
- camera_positions = torch.stack(
585
- [
586
- camera_distances * torch.cos(elevation) * torch.cos(azimuth),
587
- camera_distances * torch.cos(elevation) * torch.sin(azimuth),
588
- camera_distances * torch.sin(elevation),
589
- ],
590
- dim=-1,
591
- )
592
- center = torch.zeros_like(camera_positions)
593
- up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[
594
- None, :
595
- ].repeat(num_views, 1)
596
- lookat = F.normalize(center - camera_positions, dim=-1)
597
- right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
598
- up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
599
- c2w3x4 = torch.cat(
600
- [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
601
- dim=-1,
602
- )
603
- c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
604
- c2w[:, 3, 3] = 1.0
605
- return c2w
606
-
607
- @torch.no_grad()
608
- def calculate_visible_triangle_mask(self, channels=None, image_size=(512, 512)):
609
- if not channels:
610
- channels = self.channels
611
-
612
- pix2face_list = []
613
- for i in range(len(self.cameras)):
614
- self.renderer.rasterizer.raster_settings.image_size = image_size
615
- pix2face = self.renderer.rasterizer(
616
- self.mesh_d, cameras=self.cameras[i]
617
- ).pix_to_face
618
- self.renderer.rasterizer.raster_settings.image_size = self.render_size
619
- pix2face_list.append(pix2face)
620
-
621
- if not hasattr(self, "mesh_uv"):
622
- self.construct_uv_mesh()
623
-
624
- raster_settings = RasterizationSettings(
625
- image_size=self.target_size,
626
- blur_radius=0,
627
- faces_per_pixel=1,
628
- perspective_correct=False,
629
- cull_backfaces=False,
630
- max_faces_per_bin=30000,
631
- )
632
-
633
- R, T = look_at_view_transform(dist=2, elev=0, azim=0)
634
- # flip_mat = torch.from_numpy(np.diag([-1.0, 1.0, -1.0]) ).type(torch.FloatTensor).to(R.device)
635
- # R = R@flip_mat
636
- # T = T*torch.tensor(np.array([-1.0, 1.0, -1.0])).type(torch.FloatTensor).to(R.device)
637
- # c2w = self.get_c2w([0], [1.8], [0])
638
- # w2c = torch.linalg.inv(c2w)[:, :3,:]
639
- # R, T= w2c[:, :3,:3], w2c[:, :3, 3]
640
- # print("R size:{}, T size:{}".format(R.size(), T.size()))
641
- cameras = FoVOrthographicCameras(device=self.device, R=R, T=T)
642
- # cameras = CustomOrthographicCameras(device=self.device, R=R, T=T)
643
-
644
- # cameras = CustomOrthographicCameras(device=self.device, R=R, T=T, znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)
645
-
646
- rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
647
- uv_pix2face = rasterizer(self.mesh_uv).pix_to_face
648
-
649
- visible_triangles = []
650
- for i in range(len(pix2face_list)):
651
- valid_faceid = torch.unique(pix2face_list[i])
652
- valid_faceid = valid_faceid[1:] if valid_faceid[0] == -1 else valid_faceid
653
- mask = torch.isin(uv_pix2face[0], valid_faceid, assume_unique=False)
654
- # uv_pix2face[0][~mask] = -1
655
- triangle_mask = torch.ones(self.target_size + (1,), device=self.device)
656
- triangle_mask[~mask] = 0
657
-
658
- triangle_mask[:, 1:][triangle_mask[:, :-1] > 0] = 1
659
- triangle_mask[:, :-1][triangle_mask[:, 1:] > 0] = 1
660
- triangle_mask[1:, :][triangle_mask[:-1, :] > 0] = 1
661
- triangle_mask[:-1, :][triangle_mask[1:, :] > 0] = 1
662
- visible_triangles.append(triangle_mask)
663
-
664
- self.visible_triangles = visible_triangles
665
-
666
- # Render the current mesh and texture from current cameras
667
- def render_textured_views(self):
668
- meshes = self.mesh.extend(len(self.cameras))
669
- images_predicted = self.renderer(
670
- meshes, cameras=self.cameras, lights=self.lights
671
- )
672
-
673
- return [image.permute(2, 0, 1) for image in images_predicted]
674
-
675
- @torch.no_grad()
676
- def get_point_validation_by_o3d(
677
- self, points, eye_position, hidden_point_removal_radius=200
678
- ):
679
- point_visibility = torch.zeros((points.shape[0]), device=points.device).bool()
680
-
681
- pcd = o3d.geometry.PointCloud(
682
- points=o3d.utility.Vector3dVector(points.cpu().numpy())
683
- )
684
- camera_pose = (
685
- eye_position.get_camera_center().squeeze().cpu().numpy().astype(np.float64)
686
- )
687
- # o3d_camera = [0, 0, diameter]
688
- diameter = np.linalg.norm(
689
- np.asarray(pcd.get_max_bound()) - np.asarray(pcd.get_min_bound())
690
- )
691
- radius = diameter * 200 # The radius of the sperical projection
692
- _, pt_map = pcd.hidden_point_removal(camera_pose, radius)
693
-
694
- visible_point_ids = np.array(pt_map)
695
-
696
- point_visibility[visible_point_ids] = True
697
- return point_visibility
698
-
699
- @torch.no_grad()
700
- def hidden_judge(self, camera, texture_dim):
701
- mesh = self.mesh
702
-
703
- verts = mesh.verts_packed()
704
- faces = mesh.faces_packed()
705
- verts_uv = mesh.textures.verts_uvs_padded()[0]  # padded UV coordinates (V, 2)
706
- faces_uv = mesh.textures.faces_uvs_padded()[0]
707
- uv_face_attr = torch.index_select(
708
- verts_uv, 0, faces_uv.view(-1)
709
- )  # select the UV coordinates of the corresponding vertices
710
- uv_face_attr = uv_face_attr.view(
711
- faces.shape[0], faces_uv.shape[1], 2
712
- ).unsqueeze(0)
713
- x, y, z = verts[:, 0], verts[:, 1], verts[:, 2]
714
- mesh_out_of_range = False
715
- if (
716
- x.min() < -1
717
- or x.max() > 1
718
- or y.min() < -1
719
- or y.max() > 1
720
- or z.min() < -1
721
- or z.max() > 1
722
- ):
723
- mesh_out_of_range = True
724
- face_vertices_world = kal.ops.mesh.index_vertices_by_faces(
725
- verts.unsqueeze(0), faces
726
- )
727
- face_vertices_z = torch.zeros_like(
728
- face_vertices_world[:, :, :, -1], device=verts.device
729
- )
730
- uv_position, face_idx = kal.render.mesh.rasterize(
731
- texture_dim,
732
- texture_dim,
733
- face_vertices_z,
734
- uv_face_attr * 2 - 1,
735
- face_features=face_vertices_world,
736
- )
737
- uv_position = torch.clamp(uv_position, -1, 1)
738
- uv_position[face_idx == -1] = 0
739
-
740
- points = uv_position.reshape(-1, 3)
741
- mask = points[:, 0] != 0
742
- valid_points = points[mask]
743
- # np.save("tmp/pcd.npy", valid_points.cpu().numpy())
744
- # print(camera.get_camera_center())
745
-
746
- points_visibility = self.get_point_validation_by_o3d(
747
- valid_points, camera
748
- ).float()
749
- visibility_map = torch.zeros((texture_dim * texture_dim,)).to(self.device)
750
- visibility_map[mask] = points_visibility
751
- visibility_map = visibility_map.reshape((texture_dim, texture_dim))
752
- return visibility_map
753
-
754
- @torch.enable_grad()
755
- def bake_texture(
756
- self,
757
- views=None,
758
- main_views=[],
759
- cos_weighted=True,
760
- channels=None,
761
- exp=None,
762
- noisy=False,
763
- generator=None,
764
- smooth_colorize=False,
765
- ):
766
- if not exp:
767
- exp = 1
768
- if not channels:
769
- channels = self.channels
770
- views = [view.permute(1, 2, 0) for view in views]
771
-
772
- tmp_mesh = self.mesh
773
- bake_maps = [
774
- torch.zeros(
775
- self.target_size + (views[0].shape[2],),
776
- device=self.device,
777
- requires_grad=True,
778
- )
779
- for view in views
780
- ]
781
- optimizer = torch.optim.SGD(bake_maps, lr=1, momentum=0)
782
- optimizer.zero_grad()
783
- loss = 0
784
- for i in range(len(self.cameras)):
785
- bake_tex = TexturesUV(
786
- [bake_maps[i]],
787
- tmp_mesh.textures.faces_uvs_padded(),
788
- tmp_mesh.textures.verts_uvs_padded(),
789
- sampling_mode=self.sampling_mode,
790
- )
791
- tmp_mesh.textures = bake_tex
792
- images_predicted = self.renderer(
793
- tmp_mesh,
794
- cameras=self.cameras[i],
795
- lights=self.lights,
796
- device=self.device,
797
- )
798
- predicted_rgb = images_predicted[..., :-1]
799
- loss += (((predicted_rgb[...] - views[i])) ** 2).sum()
800
- loss.backward(retain_graph=False)
801
- optimizer.step()
802
-
803
- total_weights = 0
804
- baked = 0
805
- for i in range(len(bake_maps)):
806
- normalized_baked_map = bake_maps[i].detach() / (
807
- self.gradient_maps[i] + 1e-8
808
- )
809
- bake_map = voronoi_solve(
810
- normalized_baked_map, self.gradient_maps[i][..., 0], self.device
811
- )
812
- # bake_map = voronoi_solve(normalized_baked_map, self.visible_triangles[i].squeeze())
813
-
814
- weight = self.visible_triangles[i] * (self.cos_maps[i]) ** exp
815
- if smooth_colorize:
816
- visibility_map = self.hidden_judge(
817
- self.cameras[i], self.target_size[0]
818
- ).unsqueeze(-1)
819
- weight *= visibility_map
820
- if noisy:
821
- noise = (
822
- torch.rand(weight.shape[:-1] + (1,), generator=generator)
823
- .type(weight.dtype)
824
- .to(weight.device)
825
- )
826
- weight *= noise
827
- total_weights += weight
828
-
829
- baked += bake_map * weight
830
- baked /= total_weights + 1e-8
831
-
832
- whole_visible_mask = None
833
- if not smooth_colorize:
834
- baked = voronoi_solve(baked, total_weights[..., 0], self.device)
835
- tmp_mesh.textures = TexturesUV(
836
- [baked],
837
- tmp_mesh.textures.faces_uvs_padded(),
838
- tmp_mesh.textures.verts_uvs_padded(),
839
- sampling_mode=self.sampling_mode,
840
- )
841
- else: # smooth colorize
842
- baked = voronoi_solve(baked, total_weights[..., 0], self.device)
843
- whole_visible_mask = self.visible_triangles[0].to(torch.int32)
844
- for tensor in self.visible_triangles[1:]:
845
- whole_visible_mask = torch.bitwise_or(
846
- whole_visible_mask, tensor.to(torch.int32)
847
- )
848
-
849
- baked *= whole_visible_mask
850
- tmp_mesh.textures = TexturesUV(
851
- [baked],
852
- tmp_mesh.textures.faces_uvs_padded(),
853
- tmp_mesh.textures.verts_uvs_padded(),
854
- sampling_mode=self.sampling_mode,
855
- )
856
-
857
- extended_mesh = tmp_mesh.extend(len(self.cameras))
858
- images_predicted = self.renderer(
859
- extended_mesh, cameras=self.cameras, lights=self.lights
860
- )
861
- learned_views = [image.permute(2, 0, 1) for image in images_predicted]
862
-
863
- return learned_views, baked.permute(2, 0, 1), total_weights.permute(2, 0, 1)
864
-
865
- # Move the internel data to a specific device
866
- def to(self, device):
867
- for mesh_name in ["mesh", "mesh_d", "mesh_uv"]:
868
- if hasattr(self, mesh_name):
869
- mesh = getattr(self, mesh_name)
870
- setattr(self, mesh_name, mesh.to(device))
871
- for list_name in ["visible_triangles", "visibility_maps", "cos_maps"]:
872
- if hasattr(self, list_name):
873
- map_list = getattr(self, list_name)
874
- for i in range(len(map_list)):
875
- map_list[i] = map_list[i].to(device)

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
step1x3d_texture/renderer/shader.py DELETED
@@ -1,127 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- import pytorch3d
5
-
6
-
7
- from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
8
- from pytorch3d.ops import interpolate_face_attributes
9
-
10
- from pytorch3d.structures import Meshes
11
- from pytorch3d.renderer import (
12
- look_at_view_transform,
13
- FoVPerspectiveCameras,
14
- AmbientLights,
15
- PointLights,
16
- DirectionalLights,
17
- Materials,
18
- RasterizationSettings,
19
- MeshRenderer,
20
- MeshRasterizer,
21
- SoftPhongShader,
22
- SoftSilhouetteShader,
23
- HardPhongShader,
24
- TexturesVertex,
25
- TexturesUV,
26
- Materials,
27
- )
28
- from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
29
- from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
30
-
31
- from pytorch3d.renderer.lighting import AmbientLights
32
- from pytorch3d.renderer.materials import Materials
33
- from pytorch3d.renderer.mesh.shader import ShaderBase
34
- from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
35
- from pytorch3d.renderer.mesh.rasterizer import Fragments
36
-
37
-
38
- """
39
- Customized the original pytorch3d hard flat shader to support N channel flat shading
40
- """
41
-
42
-
43
- class HardNChannelFlatShader(ShaderBase):
44
- """
45
- Per face lighting - the lighting model is applied using the average face
46
- position and the face normal. The blending function hard assigns
47
- the color of the closest face for each pixel.
48
-
49
- To use the default values, simply initialize the shader with the desired
50
- device e.g.
51
-
52
- .. code-block::
53
-
54
- shader = HardFlatShader(device=torch.device("cuda:0"))
55
- """
56
-
57
- def __init__(
58
- self,
59
- device="cpu",
60
- cameras: Optional[TensorProperties] = None,
61
- lights: Optional[TensorProperties] = None,
62
- materials: Optional[Materials] = None,
63
- blend_params: Optional[BlendParams] = None,
64
- channels: int = 3,
65
- ):
66
- self.channels = channels
67
- ones = ((1.0,) * channels,)
68
- zeros = ((0.0,) * channels,)
69
-
70
- if (
71
- not isinstance(lights, AmbientLights)
72
- or not lights.ambient_color.shape[-1] == channels
73
- ):
74
- lights = AmbientLights(
75
- ambient_color=ones,
76
- device=device,
77
- )
78
-
79
- if not materials or not materials.ambient_color.shape[-1] == channels:
80
- materials = Materials(
81
- device=device,
82
- diffuse_color=zeros,
83
- ambient_color=ones,
84
- specular_color=zeros,
85
- shininess=0.0,
86
- )
87
-
88
- blend_params_new = BlendParams(background_color=(1.0,) * channels)
89
- if not isinstance(blend_params, BlendParams):
90
- blend_params = blend_params_new
91
- else:
92
- background_color_ = blend_params.background_color
93
- if (
94
- isinstance(background_color_, Sequence[float])
95
- and not len(background_color_) == channels
96
- ):
97
- blend_params = blend_params_new
98
- if (
99
- isinstance(background_color_, torch.Tensor)
100
- and not background_color_.shape[-1] == channels
101
- ):
102
- blend_params = blend_params_new
103
-
104
- super().__init__(
105
- device,
106
- cameras,
107
- lights,
108
- materials,
109
- blend_params,
110
- )
111
-
112
- def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
113
- cameras = super()._get_cameras(**kwargs)
114
- texels = meshes.sample_textures(fragments)
115
- lights = kwargs.get("lights", self.lights)
116
- materials = kwargs.get("materials", self.materials)
117
- blend_params = kwargs.get("blend_params", self.blend_params)
118
- colors = flat_shading(
119
- meshes=meshes,
120
- fragments=fragments,
121
- texels=texels,
122
- lights=lights,
123
- cameras=cameras,
124
- materials=materials,
125
- )
126
- images = hard_rgb_blend(colors, fragments, blend_params)
127
- return images

step1x3d_texture/{renderer → texture_sync}/__init__.py RENAMED
File without changes
step1x3d_texture/texture_sync/geometry.py ADDED
@@ -0,0 +1,141 @@
+ import torch
+ import pytorch3d
+ import torch.nn.functional as F
+
+ from pytorch3d.ops import interpolate_face_attributes
+
+ from pytorch3d.renderer import (
+     look_at_view_transform,
+     FoVPerspectiveCameras,
+     AmbientLights,
+     PointLights,
+     DirectionalLights,
+     Materials,
+     RasterizationSettings,
+     MeshRenderer,
+     MeshRasterizer,
+     SoftPhongShader,
+     SoftSilhouetteShader,
+     HardPhongShader,
+     TexturesVertex,
+     TexturesUV,
+     Materials,
+
+ )
+ from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
+ from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
+ from pytorch3d.renderer.mesh.shader import ShaderBase
+
+
+ def get_cos_angle(
+     points, normals, camera_position
+ ):
+     '''
+     Calculate the cosine similarity between the view->surface direction and the surface normal.
+     '''
+
+     if points.shape != normals.shape:
+         msg = "Expected points and normals to have the same shape: got %r, %r"
+         raise ValueError(msg % (points.shape, normals.shape))
+
+     # Ensure all inputs have same batch dimension as points
+     matched_tensors = convert_to_tensors_and_broadcast(
+         points, camera_position, device=points.device
+     )
+     _, camera_position = matched_tensors
+
+     # Reshape direction and color so they have all the arbitrary intermediate
+     # dimensions as points. Assume first dim = batch dim and last dim = 3.
+     points_dims = points.shape[1:-1]
+     expand_dims = (-1,) + (1,) * len(points_dims)
+
+     if camera_position.shape != normals.shape:
+         camera_position = camera_position.view(expand_dims + (3,))
+
+     normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
+
+     # Calculate the cosine value.
+     view_direction = camera_position - points
+     view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
+     cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True)
+     cos_angle = cos_angle.clamp(0, 1)
+
+     # Cosine of the angle between the reflected light ray and the viewer
+     return cos_angle
+
+
+ def _geometry_shading_with_pixels(
+     meshes, fragments, lights, cameras, materials, texels
+ ):
+     """
+     Render pixel-space vertex position, normal (world), depth, and cos angle.
+
+     Args:
+         meshes: Batch of meshes
+         fragments: Fragments named tuple with the outputs of rasterization
+         lights: Lights class containing a batch of lights
+         cameras: Cameras class containing a batch of cameras
+         materials: Materials class containing a batch of material properties
+         texels: texture per pixel of shape (N, H, W, K, 3)
+
+     Returns:
+         colors: (N, H, W, K, 3)
+         pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection.
+     """
+     verts = meshes.verts_packed()  # (V, 3)
+     faces = meshes.faces_packed()  # (F, 3)
+     vertex_normals = meshes.verts_normals_packed()  # (V, 3)
+     faces_verts = verts[faces]
+     faces_normals = vertex_normals[faces]
+     pixel_coords_in_camera = interpolate_face_attributes(
+         fragments.pix_to_face, fragments.bary_coords, faces_verts
+     )
+     pixel_normals = interpolate_face_attributes(
+         fragments.pix_to_face, fragments.bary_coords, faces_normals
+     )
+
+     cos_angles = get_cos_angle(pixel_coords_in_camera, pixel_normals, cameras.get_camera_center())
+
+     return pixel_coords_in_camera, pixel_normals, fragments.zbuf[...,None], cos_angles
+
+
+ class HardGeometryShader(ShaderBase):
+     """
+     Renders common geometric information.
+
+
+     """
+
+     def forward(self, fragments, meshes, **kwargs):
+         cameras = super()._get_cameras(**kwargs)
+         texels = self.texel_from_uv(fragments, meshes)
+
+         lights = kwargs.get("lights", self.lights)
+         materials = kwargs.get("materials", self.materials)
+         blend_params = kwargs.get("blend_params", self.blend_params)
+         verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
+             meshes=meshes,
+             fragments=fragments,
+             texels=texels,
+             lights=lights,
+             cameras=cameras,
+             materials=materials,
+         )
+         verts = hard_rgb_blend(verts, fragments, blend_params)
+         normals = hard_rgb_blend(normals, fragments, blend_params)
+         depths = hard_rgb_blend(depths, fragments, blend_params)
+         cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
+         texels = hard_rgb_blend(texels, fragments, blend_params)
+         return verts, normals, depths, cos_angles, texels, fragments
+
+     def texel_from_uv(self, fragments, meshes):
+         texture_tmp = meshes.textures
+         maps_tmp = texture_tmp.maps_padded()
+         uv_color = [[[1,0],[1,1]],[[0,0],[0,1]]]
+         uv_color = torch.FloatTensor(uv_color).to(maps_tmp[0].device).type(maps_tmp[0].dtype)
+         uv_texture = TexturesUV([uv_color.clone() for t in maps_tmp], texture_tmp.faces_uvs_padded(), texture_tmp.verts_uvs_padded(), sampling_mode="bilinear")
+         meshes.textures = uv_texture
+         texels = meshes.sample_textures(fragments)
+         meshes.textures = texture_tmp
+         texels = torch.cat((texels, texels[...,-1:]*0), dim=-1)
+         return texels
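
For orientation, here is a minimal, hedged sketch of how HardGeometryShader is meant to be used: it is temporarily swapped into a pytorch3d MeshRenderer to rasterize per-pixel geometry buffers, which is the same pattern UVProjection.render_geometry follows later in this commit. The renderer, mesh, cameras, lights and device names below are assumed to exist already and are illustrative only.

# Hedged usage sketch (assumes an existing MeshRenderer `renderer`, a textured
# Meshes object `mesh`, pytorch3d `cameras`, AmbientLights `lights`, and `device`).
geometry_shader = HardGeometryShader(device=device, cameras=cameras[0], lights=lights)
flat_shader = renderer.shader          # remember the normal shader
renderer.shader = geometry_shader      # swap in the geometry shader
verts, normals, depths, cos_angles, texels, fragments = renderer(
    mesh.extend(len(cameras)), cameras=cameras, lights=lights
)
renderer.shader = flat_shader          # restore the original shader afterwards
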
step1x3d_texture/texture_sync/project.py ADDED
@@ -0,0 +1,521 @@
+ import torch
+ import pytorch3d
+ import os
+
+ from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj, IO
+
+ from pytorch3d.structures import Meshes
+ from pytorch3d.renderer import (
+     look_at_view_transform,
+     FoVPerspectiveCameras,
+     FoVOrthographicCameras,
+     AmbientLights,
+     PointLights,
+     DirectionalLights,
+     Materials,
+     RasterizationSettings,
+     MeshRenderer,
+     MeshRasterizer,
+     TexturesUV
+ )
+
+ from .geometry import HardGeometryShader
+ from .shader import HardNChannelFlatShader
+ from .voronoi import voronoi_solve
+ from trimesh import Trimesh
+
+ # Pytorch3D based rendering functions, managed in a class
+ # Render size is recommended to be the same as your latent view size
+ # DO NOT USE "bilinear" sampling when you are handling latents.
+ # Stable Diffusion has 4 latent channels so use channels=4
+
+ class UVProjection():
+     def __init__(self, texture_size=96, render_size=64, sampling_mode="nearest", channels=3, device=None):
+         self.channels = channels
+         self.device = device or torch.device("cpu")
+         self.lights = AmbientLights(ambient_color=((1.0,)*channels,), device=self.device)
+         self.target_size = (texture_size, texture_size)
+         self.render_size = render_size
+         self.sampling_mode = sampling_mode
+
+
+     # # Load obj mesh, rescale the mesh to fit into the bounding box
+     # def load_mesh(self, mesh_path, scale_factor=2.0, auto_center=True, autouv=False):
+     #     mesh = load_objs_as_meshes([mesh_path], device=self.device)
+     #     if auto_center:
+     #         verts = mesh.verts_packed()
+     #         max_bb = (verts - 0).max(0)[0]
+     #         min_bb = (verts - 0).min(0)[0]
+     #         scale = (max_bb - min_bb).max()/2
+     #         center = (max_bb+min_bb) /2
+     #         mesh.offset_verts_(-center)
+     #         mesh.scale_verts_((scale_factor / float(scale)))
+     #     else:
+     #         mesh.scale_verts_((scale_factor))
+
+     #     if autouv or (mesh.textures is None):
+     #         mesh = self.uv_unwrap(mesh)
+     #     self.mesh = mesh
+     # Load obj mesh, rescale the mesh to fit into the bounding box
+     def load_mesh(self, mesh, scale_factor=2.0, auto_center=True, autouv=False, normals=None):
+         if isinstance(mesh, Trimesh):
+             vertices = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
+             faces = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
+             if faces.ndim == 1:
+                 faces = faces.unsqueeze(0)
+             mesh = Meshes(
+                 verts=[vertices],
+                 faces=[faces]
+             )
+             verts = mesh.verts_packed()
+             mesh = mesh.update_padded(verts[None, :, :])
+             # from pytorch3d.renderer.mesh.textures import TexturesVertex
+             # if normals is None:
+             #     normals = mesh.verts_normals_packed()
+             # # set normals as vertex colors
+             # mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
+         elif isinstance(mesh, str) and os.path.isfile(mesh):
+             mesh = load_objs_as_meshes([mesh], device=self.device)
+         if auto_center:
+             verts = mesh.verts_packed()
+             max_bb = (verts - 0).max(0)[0]
+             min_bb = (verts - 0).min(0)[0]
+             scale = (max_bb - min_bb).max()/2
+             center = (max_bb+min_bb) /2
+             mesh.offset_verts_(-center)
+             mesh.scale_verts_((scale_factor / float(scale)))
+         else:
+             mesh.scale_verts_((scale_factor))
+
+         if autouv or (mesh.textures is None):
+             mesh = self.uv_unwrap(mesh)
+         self.mesh = mesh
+
+     def load_glb_mesh(self, mesh_path, scale_factor=2.0, auto_center=True, autouv=False):
+         from pytorch3d.io.experimental_gltf_io import MeshGlbFormat
+         io = IO()
+         io.register_meshes_format(MeshGlbFormat())
+         with open(mesh_path, "rb") as f:
+             mesh = io.load_mesh(f, include_textures=True, device=self.device)
+         if auto_center:
+             verts = mesh.verts_packed()
+             max_bb = (verts - 0).max(0)[0]
+             min_bb = (verts - 0).min(0)[0]
+             scale = (max_bb - min_bb).max()/2
+             center = (max_bb+min_bb) /2
+             mesh.offset_verts_(-center)
+             mesh.scale_verts_((scale_factor / float(scale)))
+         else:
+             mesh.scale_verts_((scale_factor))
+         if autouv or (mesh.textures is None):
+             mesh = self.uv_unwrap(mesh)
+         self.mesh = mesh
+
+
+     # Save obj mesh
+     def save_mesh(self, mesh_path, texture):
+         save_obj(mesh_path,
+             self.mesh.verts_list()[0],
+             self.mesh.faces_list()[0],
+             verts_uvs=self.mesh.textures.verts_uvs_list()[0],
+             faces_uvs=self.mesh.textures.faces_uvs_list()[0],
+             texture_map=texture)
+
+     # Code referred to TEXTure code (https://github.com/TEXTurePaper/TEXTurePaper.git)
+     def uv_unwrap(self, mesh):
+         verts_list = mesh.verts_list()[0]
+         faces_list = mesh.faces_list()[0]
+
+
+         import xatlas
+         import numpy as np
+         v_np = verts_list.cpu().numpy()
+         f_np = faces_list.int().cpu().numpy()
+         atlas = xatlas.Atlas()
+         atlas.add_mesh(v_np, f_np)
+         chart_options = xatlas.ChartOptions()
+         chart_options.max_iterations = 4
+         atlas.generate(chart_options=chart_options)
+         vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]
+
+         vt = torch.from_numpy(vt_np.astype(np.float32)).type(verts_list.dtype).to(mesh.device)
+         ft = torch.from_numpy(ft_np.astype(np.int64)).type(faces_list.dtype).to(mesh.device)
+
+         new_map = torch.zeros(self.target_size+(self.channels,), device=mesh.device)
+         new_tex = TexturesUV(
+             [new_map],
+             [ft],
+             [vt],
+             sampling_mode=self.sampling_mode
+         )
+
+         mesh.textures = new_tex
+         return mesh
+
+
+     '''
+     A function that disconnects faces in the mesh according to
+     its UV seams. The number of vertices is made equal to the
+     number of unique vertices in its UV layout, while the faces
+     list is kept intact.
+     '''
+     def disconnect_faces(self):
+         mesh = self.mesh
+         verts_list = mesh.verts_list()
+         faces_list = mesh.faces_list()
+         verts_uvs_list = mesh.textures.verts_uvs_list()
+         faces_uvs_list = mesh.textures.faces_uvs_list()
+         packed_list = [v[f] for v, f in zip(verts_list, faces_list)]
+         verts_disconnect_list = [
+             torch.zeros(
+                 (verts_uvs_list[i].shape[0], 3),
+                 dtype=verts_list[0].dtype,
+                 device=verts_list[0].device
+             )
+             for i in range(len(verts_list))]
+         for i in range(len(verts_list)):
+             verts_disconnect_list[i][faces_uvs_list[i]] = packed_list[i]
+         assert not mesh.has_verts_normals(), "Not implemented for vertex normals"
+         self.mesh_d = Meshes(verts_disconnect_list, faces_uvs_list, mesh.textures)
+         return self.mesh_d
+
+
+     '''
+     A function that constructs a temp mesh for back-projection.
+     Given the disconnected mesh and a rasterizer, it uses the original
+     UVs (with a pseudo z value) as world-space geometry, so the projected
+     faces can serve as the UV layout.
+     '''
+     def construct_uv_mesh(self):
+         mesh = self.mesh_d
+         verts_list = mesh.verts_list()
+         verts_uvs_list = mesh.textures.verts_uvs_list()
+         # faces_list = [torch.flip(faces, [-1]) for faces in mesh.faces_list()]
+         new_verts_list = []
+         for i, (verts, verts_uv) in enumerate(zip(verts_list, verts_uvs_list)):
+             verts = verts.clone()
+             verts_uv = verts_uv.clone()
+             verts[...,0:2] = verts_uv[...,:]
+             verts = (verts - 0.5) * 2
+             verts[...,2] *= 1
+             new_verts_list.append(verts)
+         textures_uv = mesh.textures.clone()
+         self.mesh_uv = Meshes(new_verts_list, mesh.faces_list(), textures_uv)
+         return self.mesh_uv
+
+
+     # Set texture for the current mesh.
+     def set_texture_map(self, texture):
+         new_map = texture.permute(1, 2, 0)
+         new_map = new_map.to(self.device)
+         new_tex = TexturesUV(
+             [new_map],
+             self.mesh.textures.faces_uvs_padded(),
+             self.mesh.textures.verts_uvs_padded(),
+             sampling_mode=self.sampling_mode
+         )
+         self.mesh.textures = new_tex
+
+
+     # Set the initial normal noise texture
+     # No generator here for replication of the experiment result. Add one as you wish
+     def set_noise_texture(self, channels=None):
+         if not channels:
+             channels = self.channels
+         noise_texture = torch.normal(0, 1, (channels,) + self.target_size, device=self.device)
+         self.set_texture_map(noise_texture)
+         return noise_texture
+
+
+     # Set the cameras given the camera poses and centers
+     def set_cameras(self, camera_poses, centers=None, camera_distance=2.7, scale=None):
+         elev = torch.FloatTensor([pose[0] for pose in camera_poses])
+         azim = torch.FloatTensor([pose[1] for pose in camera_poses])
+         R, T = look_at_view_transform(dist=camera_distance, elev=elev, azim=azim, at=centers or ((0,0,0),))
+         # self.cameras = FoVOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),))
+         self.cameras = FoVOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),), znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)
+
+     # Set all necessary internal data for rendering and texture baking
+     # Can be used to refresh after changing camera positions
+     def set_cameras_and_render_settings(self, camera_poses, centers=None, camera_distance=2.7, render_size=None, scale=None):
+         self.set_cameras(camera_poses, centers, camera_distance, scale=scale)
+         if render_size is None:
+             render_size = self.render_size
+         if not hasattr(self, "renderer"):
+             self.setup_renderer(size=render_size)
+         if not hasattr(self, "mesh_d"):
+             self.disconnect_faces()
+         if not hasattr(self, "mesh_uv"):
+             self.construct_uv_mesh()
+         self.calculate_tex_gradient()
+         self.calculate_visible_triangle_mask()
+         _, _, _, cos_maps, _, _ = self.render_geometry()
+         self.calculate_cos_angle_weights(cos_maps)
+
+
+     # Setup renderers for rendering
+     # max faces per bin set to 30000 to avoid overflow in many test cases.
+     # You can use default value to let pytorch3d handle that for you.
+     def setup_renderer(self, size=64, blur=0.0, face_per_pix=1, perspective_correct=False, channels=None):
+         if not channels:
+             channels = self.channels
+
+         self.raster_settings = RasterizationSettings(
+             image_size=size,
+             blur_radius=blur,
+             faces_per_pixel=face_per_pix,
+             perspective_correct=perspective_correct,
+             cull_backfaces=True,
+             max_faces_per_bin=30000,
+         )
+
+         self.renderer = MeshRenderer(
+             rasterizer=MeshRasterizer(
+                 cameras=self.cameras,
+                 raster_settings=self.raster_settings,
+
+             ),
+             shader=HardNChannelFlatShader(
+                 device=self.device,
+                 cameras=self.cameras,
+                 lights=self.lights,
+                 channels=channels
+                 # materials=materials
+             )
+         )
+
+
+     # Bake screen-space cosine weights to UV space
+     # May be able to reimplement using the generic "bake_texture" function, but it works so leave it here for now
+     @torch.enable_grad()
+     def calculate_cos_angle_weights(self, cos_angles, fill=True, channels=None):
+         if not channels:
+             channels = self.channels
+         cos_maps = []
+         tmp_mesh = self.mesh.clone()
+         for i in range(len(self.cameras)):
+
+             zero_map = torch.zeros(self.target_size+(channels,), device=self.device, requires_grad=True)
+             optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
+             optimizer.zero_grad()
+             zero_tex = TexturesUV([zero_map], self.mesh.textures.faces_uvs_padded(), self.mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
+             tmp_mesh.textures = zero_tex
+
+             images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights)
+
+             loss = torch.sum((cos_angles[i,:,:,0:1]**1 - images_predicted)**2)
+             loss.backward()
+             optimizer.step()
+
+             if fill:
+                 zero_map = zero_map.detach() / (self.gradient_maps[i] + 1E-8)
+                 zero_map = voronoi_solve(zero_map, self.gradient_maps[i][...,0])
+             else:
+                 zero_map = zero_map.detach() / (self.gradient_maps[i]+1E-8)
+             cos_maps.append(zero_map)
+         self.cos_maps = cos_maps
+
+
+     # Get geometric info from fragment shader
+     # Can be used for generating conditioning image and cosine weights
+     # Returns some information you may not need, remember to release them for memory saving
+     @torch.no_grad()
+     def render_geometry(self, image_size=None):
+         if image_size:
+             size = self.renderer.rasterizer.raster_settings.image_size
+             self.renderer.rasterizer.raster_settings.image_size = image_size
+         shader = self.renderer.shader
+         self.renderer.shader = HardGeometryShader(device=self.device, cameras=self.cameras[0], lights=self.lights)
+         tmp_mesh = self.mesh.clone()
+
+         verts, normals, depths, cos_angles, texels, fragments = self.renderer(tmp_mesh.extend(len(self.cameras)), cameras=self.cameras, lights=self.lights)
+         self.renderer.shader = shader
+
+         if image_size:
+             self.renderer.rasterizer.raster_settings.image_size = size
+
+         return verts, normals, depths, cos_angles, texels, fragments
+
+
+     # Project world normal to view space and normalize
+     @torch.no_grad()
+     def decode_view_normal(self, normals):
+         w2v_mat = self.cameras.get_full_projection_transform()
+         normals_view = torch.clone(normals)[:,:,:,0:3]
+         normals_view = normals_view.reshape(normals_view.shape[0], -1, 3)
+         normals_view = w2v_mat.transform_normals(normals_view)
+         normals_view = normals_view.reshape(normals.shape[0:3]+(3,))
+         normals_view[:,:,:,2] *= -1
+         normals = (normals_view[...,0:3]+1) * normals[...,3:] / 2 + torch.FloatTensor((0.5, 0.5, 1)).to(self.device) * (1 - normals[...,3:])
+         # normals = torch.cat([normal for normal in normals], dim=1)
+         normals = normals.clamp(0, 1)
+         return normals
+
+
+     # Normalize absolute depth to inverse depth
+     @torch.no_grad()
+     def decode_normalized_depth(self, depths, batched_norm=False):
+         view_z, mask = depths.unbind(-1)
+         view_z = view_z * mask + 100 * (1-mask)
+         inv_z = 1 / view_z
+         inv_z_min = inv_z * mask + 100 * (1-mask)
+         if not batched_norm:
+             max_ = torch.max(inv_z, 1, keepdim=True)
+             max_ = torch.max(max_[0], 2, keepdim=True)[0]
+
+             min_ = torch.min(inv_z_min, 1, keepdim=True)
+             min_ = torch.min(min_[0], 2, keepdim=True)[0]
+         else:
+             max_ = torch.max(inv_z)
+             min_ = torch.min(inv_z_min)
+         inv_z = (inv_z - min_) / (max_ - min_)
+         inv_z = inv_z.clamp(0,1)
+         inv_z = inv_z[...,None].repeat(1,1,1,3)
+
+         return inv_z
+
+
+     # Multiple screen pixels could pass gradient to a same texel
+     # We can precalculate this gradient strength and use it to normalize gradients when we bake textures
+     @torch.enable_grad()
+     def calculate_tex_gradient(self, channels=None):
+         if not channels:
+             channels = self.channels
+         tmp_mesh = self.mesh.clone()
+         gradient_maps = []
+         for i in range(len(self.cameras)):
+             zero_map = torch.zeros(self.target_size+(channels,), device=self.device, requires_grad=True)
+             optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
+             optimizer.zero_grad()
+             zero_tex = TexturesUV([zero_map], self.mesh.textures.faces_uvs_padded(), self.mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
+             tmp_mesh.textures = zero_tex
+             images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights)
+             loss = torch.sum((1 - images_predicted)**2)
+             loss.backward()
+             optimizer.step()
+
+             gradient_maps.append(zero_map.detach())
+
+         self.gradient_maps = gradient_maps
+
+
+     # Get the UV space masks of triangles visible in each view
+     # First get face ids from each view, then filter pixels on UV space to generate masks
+     @torch.no_grad()
+     def calculate_visible_triangle_mask(self, channels=None, image_size=(512,512)):
+         if not channels:
+             channels = self.channels
+
+         pix2face_list = []
+         for i in range(len(self.cameras)):
+             self.renderer.rasterizer.raster_settings.image_size = image_size
+             pix2face = self.renderer.rasterizer(self.mesh_d, cameras=self.cameras[i]).pix_to_face
+             self.renderer.rasterizer.raster_settings.image_size = self.render_size
+             pix2face_list.append(pix2face)
+
+         if not hasattr(self, "mesh_uv"):
+             self.construct_uv_mesh()
+
+         raster_settings = RasterizationSettings(
+             image_size=self.target_size,
+             blur_radius=0,
+             faces_per_pixel=1,
+             perspective_correct=False,
+             cull_backfaces=False,
+             max_faces_per_bin=30000,
+         )
+
+         R, T = look_at_view_transform(dist=2, elev=0, azim=0)
+         cameras = FoVOrthographicCameras(device=self.device, R=R, T=T)
+
+         rasterizer = MeshRasterizer(
+             cameras=cameras,
+             raster_settings=raster_settings
+         )
+         uv_pix2face = rasterizer(self.mesh_uv).pix_to_face
+
+         visible_triangles = []
+         for i in range(len(pix2face_list)):
+             valid_faceid = torch.unique(pix2face_list[i])
+             valid_faceid = valid_faceid[1:] if valid_faceid[0]==-1 else valid_faceid
+             mask = torch.isin(uv_pix2face[0], valid_faceid, assume_unique=False)
+             # uv_pix2face[0][~mask] = -1
+             triangle_mask = torch.ones(self.target_size+(1,), device=self.device)
+             triangle_mask[~mask] = 0
+
+             triangle_mask[:,1:][triangle_mask[:,:-1] > 0] = 1
+             triangle_mask[:,:-1][triangle_mask[:,1:] > 0] = 1
+             triangle_mask[1:,:][triangle_mask[:-1,:] > 0] = 1
+             triangle_mask[:-1,:][triangle_mask[1:,:] > 0] = 1
+             visible_triangles.append(triangle_mask)
+
+         self.visible_triangles = visible_triangles
+
+
+
+     # Render the current mesh and texture from current cameras
+     def render_textured_views(self):
+         meshes = self.mesh.extend(len(self.cameras))
+         images_predicted = self.renderer(meshes, cameras=self.cameras, lights=self.lights)
+
+         return [image.permute(2, 0, 1) for image in images_predicted]
+
+
+     # Bake views into a texture
+     # First bake into individual textures then combine based on cosine weight
+     @torch.enable_grad()
+     def bake_texture(self, views=None, main_views=[], cos_weighted=True, channels=None, exp=None, noisy=False, generator=None):
+         if not exp:
+             exp = 1
+         if not channels:
+             channels = self.channels
+         views = [view.permute(1, 2, 0) for view in views]
+
+         tmp_mesh = self.mesh
+         bake_maps = [torch.zeros(self.target_size+(views[0].shape[2],), device=self.device, requires_grad=True) for view in views]
+         optimizer = torch.optim.SGD(bake_maps, lr=1, momentum=0)
+         optimizer.zero_grad()
+         loss = 0
+         for i in range(len(self.cameras)):
+             bake_tex = TexturesUV([bake_maps[i]], tmp_mesh.textures.faces_uvs_padded(), tmp_mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
+             tmp_mesh.textures = bake_tex
+             images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights, device=self.device)
+             predicted_rgb = images_predicted[..., :-1]
+             loss += (((predicted_rgb[...] - views[i]))**2).sum()
+         loss.backward(retain_graph=False)
+         optimizer.step()
+
+         total_weights = 0
+         baked = 0
+         for i in range(len(bake_maps)):
+             normalized_baked_map = bake_maps[i].detach() / (self.gradient_maps[i] + 1E-8)
+             bake_map = voronoi_solve(normalized_baked_map, self.gradient_maps[i][...,0])
+             weight = self.visible_triangles[i] * (self.cos_maps[i]) ** exp
+             if noisy:
+                 noise = torch.rand(weight.shape[:-1]+(1,), generator=generator).type(weight.dtype).to(weight.device)
+                 weight *= noise
+             total_weights += weight
+             baked += bake_map * weight
+         baked /= total_weights + 1E-8
+         baked = voronoi_solve(baked, total_weights[...,0])
+
+         bake_tex = TexturesUV([baked], tmp_mesh.textures.faces_uvs_padded(), tmp_mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
+         tmp_mesh.textures = bake_tex
+         extended_mesh = tmp_mesh.extend(len(self.cameras))
+         images_predicted = self.renderer(extended_mesh, cameras=self.cameras, lights=self.lights)
+         learned_views = [image.permute(2, 0, 1) for image in images_predicted]
+
+         return learned_views, baked.permute(2, 0, 1), total_weights.permute(2, 0, 1)
+
+
+     # Move the internal data to a specific device
+     def to(self, device):
+         for mesh_name in ["mesh", "mesh_d", "mesh_uv"]:
+             if hasattr(self, mesh_name):
+                 mesh = getattr(self, mesh_name)
+                 setattr(self, mesh_name, mesh.to(device))
+         for list_name in ["visible_triangles", "visibility_maps", "cos_maps"]:
+             if hasattr(self, list_name):
+                 map_list = getattr(self, list_name)
+                 for i in range(len(map_list)):
+                     map_list[i] = map_list[i].to(device)
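
To make the intended call pattern concrete, here is a hedged usage sketch of UVProjection for latent-space texture synchronization, following the comments above the class (nearest sampling, channels=4 for Stable Diffusion latents, render size equal to the latent view size). The mesh object, camera poses and sizes below are placeholder example values, not the pipeline's actual configuration.

# Hedged usage sketch; `tri_mesh` is assumed to be an existing trimesh.Trimesh.
device = torch.device("cuda")
uvp = UVProjection(texture_size=512, render_size=64, sampling_mode="nearest", channels=4, device=device)
uvp.load_mesh(tri_mesh, autouv=True)                      # unwraps UVs with xatlas if the mesh has none
uvp.set_cameras_and_render_settings(
    camera_poses=[(0, 0), (0, 90), (0, 180), (0, 270)],   # (elev, azim) pairs, example values
    camera_distance=2.7,
)
noise_tex = uvp.set_noise_texture()                       # (4, 512, 512) Gaussian noise, also set as the mesh texture
latent_views = uvp.render_textured_views()                # one (channels + 1, H, W) tensor per view; last channel is the mask
views, baked_tex, weights = uvp.bake_texture(views=[v[:-1] for v in latent_views], exp=1)
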
step1x3d_texture/texture_sync/shader.py ADDED
@@ -0,0 +1,118 @@
+ from typing import Optional
+ from collections.abc import Sequence
+ import torch
+ import pytorch3d
+
+
+ from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
+ from pytorch3d.ops import interpolate_face_attributes
+
+ from pytorch3d.structures import Meshes
+ from pytorch3d.renderer import (
+     look_at_view_transform,
+     FoVPerspectiveCameras,
+     AmbientLights,
+     PointLights,
+     DirectionalLights,
+     Materials,
+     RasterizationSettings,
+     MeshRenderer,
+     MeshRasterizer,
+     SoftPhongShader,
+     SoftSilhouetteShader,
+     HardPhongShader,
+     TexturesVertex,
+     TexturesUV,
+     Materials,
+
+ )
+ from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
+ from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
+
+ from pytorch3d.renderer.lighting import AmbientLights
+ from pytorch3d.renderer.materials import Materials
+ from pytorch3d.renderer.mesh.shader import ShaderBase
+ from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
+ from pytorch3d.renderer.mesh.rasterizer import Fragments
+
+
+ '''
+ Customized the original pytorch3d hard flat shader to support N channel flat shading
+ '''
+ class HardNChannelFlatShader(ShaderBase):
+     """
+     Per face lighting - the lighting model is applied using the average face
+     position and the face normal. The blending function hard assigns
+     the color of the closest face for each pixel.
+
+     To use the default values, simply initialize the shader with the desired
+     device e.g.
+
+     .. code-block::
+
+         shader = HardFlatShader(device=torch.device("cuda:0"))
+     """
+
+     def __init__(
+         self,
+         device="cpu",
+         cameras: Optional[TensorProperties] = None,
+         lights: Optional[TensorProperties] = None,
+         materials: Optional[Materials] = None,
+         blend_params: Optional[BlendParams] = None,
+         channels: int = 3,
+     ):
+         self.channels = channels
+         ones = ((1.0,)*channels,)
+         zeros = ((0.0,)*channels,)
+
+         if not isinstance(lights, AmbientLights) or not lights.ambient_color.shape[-1] == channels:
+             lights = AmbientLights(
+                 ambient_color=ones,
+                 device=device,
+             )
+
+         if not materials or not materials.ambient_color.shape[-1] == channels:
+             materials = Materials(
+                 device=device,
+                 diffuse_color=zeros,
+                 ambient_color=ones,
+                 specular_color=zeros,
+                 shininess=0.0,
+             )
+
+         blend_params_new = BlendParams(background_color=(1.0,)*channels)
+         if not isinstance(blend_params, BlendParams):
+             blend_params = blend_params_new
+         else:
+             background_color_ = blend_params.background_color
+             if isinstance(background_color_, Sequence) and not len(background_color_) == channels:
+                 blend_params = blend_params_new
+             if isinstance(background_color_, torch.Tensor) and not background_color_.shape[-1] == channels:
+                 blend_params = blend_params_new
+
+         super().__init__(
+             device,
+             cameras,
+             lights,
+             materials,
+             blend_params,
+         )
+
+
+     def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
+         cameras = super()._get_cameras(**kwargs)
+         texels = meshes.sample_textures(fragments)
+         lights = kwargs.get("lights", self.lights)
+         materials = kwargs.get("materials", self.materials)
+         blend_params = kwargs.get("blend_params", self.blend_params)
+         colors = flat_shading(
+             meshes=meshes,
+             fragments=fragments,
+             texels=texels,
+             lights=lights,
+             cameras=cameras,
+             materials=materials,
+         )
+         images = hard_rgb_blend(colors, fragments, blend_params)
+         return images
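
As a quick illustration, the shader's N-channel support is what lets UVProjection render 4-channel latent textures. A hedged sketch of that instantiation (cameras and device are assumed to exist) could look like this, mirroring UVProjection.setup_renderer above:

# Hedged sketch: a 4-channel flat shader for rendering latent-space textures.
lights = AmbientLights(ambient_color=((1.0,) * 4,), device=device)
shader = HardNChannelFlatShader(device=device, cameras=cameras, lights=lights, channels=4)
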
step1x3d_texture/texture_sync/step_sync.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ from diffusers.utils.torch_utils import randn_tensor
+
+ '''
+
+ Customized Step Function
+ step on texture
+ '''
+ @torch.no_grad()
+ def step_tex_sync(
+     scheduler,
+     uvp,
+     model_output: torch.FloatTensor,
+     timestep: int,
+     sample: torch.FloatTensor,
+     texture=None,
+     generator=None,
+     return_dict: bool = True,
+     guidance_scale=1,
+     main_views=[],
+     hires_original_views=True,
+     exp=None,
+     cos_weighted=True
+ ):
+     t = timestep
+
+     prev_t = scheduler.previous_timestep(t)
+
+     if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+         model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+     else:
+         predicted_variance = None
+
+     # 1. compute alphas, betas
+     alpha_prod_t = scheduler.alphas_cumprod[t]
+     alpha_prod_t_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
+     beta_prod_t = 1 - alpha_prod_t
+     beta_prod_t_prev = 1 - alpha_prod_t_prev
+     current_alpha_t = alpha_prod_t / alpha_prod_t_prev
+     current_beta_t = 1 - current_alpha_t
+
+     # 2. compute predicted original sample from predicted noise also called
+     # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+     if scheduler.config.prediction_type == "epsilon":
+         pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+     elif scheduler.config.prediction_type == "sample":
+         pred_original_sample = model_output
+     elif scheduler.config.prediction_type == "v_prediction":
+         pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+     else:
+         raise ValueError(
+             f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
+             " `v_prediction` for the DDPMScheduler."
+         )
+     # 3. Clip or threshold "predicted x_0"
+     if scheduler.config.thresholding:
+         pred_original_sample = scheduler._threshold_sample(pred_original_sample)
+     elif scheduler.config.clip_sample:
+         pred_original_sample = pred_original_sample.clamp(
+             -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
+         )
+
+     # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+     # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+     pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
+     current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+     '''
+     Add multidiffusion here
+     '''
+
+     if texture is None:
+         sample_views = [view for view in sample]
+         sample_views, texture, _ = uvp.bake_texture(views=sample_views, main_views=main_views, exp=exp)
+         sample_views = torch.stack(sample_views, axis=0)[:,:-1,...]
+
+
+     original_views = [view for view in pred_original_sample]
+     original_views, original_tex, visibility_weights = uvp.bake_texture(views=original_views, main_views=main_views, exp=exp)
+     uvp.set_texture_map(original_tex)
+     original_views = uvp.render_textured_views()
+     original_views = torch.stack(original_views, axis=0)[:,:-1,...]
+
+     # 5. Compute predicted previous sample µ_t
+     # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+     # pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+     prev_tex = pred_original_sample_coeff * original_tex + current_sample_coeff * texture
+
+     # 6. Add noise
+     variance = 0
+
+     if predicted_variance is not None:
+         variance_views = [view for view in predicted_variance]
+         variance_views, variance_tex, visibility_weights = uvp.bake_texture(views=variance_views, main_views=main_views, cos_weighted=cos_weighted, exp=exp)
+         variance_views = torch.stack(variance_views, axis=0)[:,:-1,...]
+     else:
+         variance_tex = None
+
+     if t > 0:
+         device = texture.device
+         variance_noise = randn_tensor(
+             texture.shape, generator=generator, device=device, dtype=texture.dtype
+         )
+         if scheduler.variance_type == "fixed_small_log":
+             variance = scheduler._get_variance(t, predicted_variance=variance_tex) * variance_noise
+         elif scheduler.variance_type == "learned_range":
+             variance = scheduler._get_variance(t, predicted_variance=variance_tex)
+             variance = torch.exp(0.5 * variance) * variance_noise
+         else:
+             variance = (scheduler._get_variance(t, predicted_variance=variance_tex) ** 0.5) * variance_noise
+         prev_tex = prev_tex + variance
+
+     uvp.set_texture_map(prev_tex)
+     prev_views = uvp.render_textured_views()
+     pred_prev_sample = torch.clone(sample)
+     for i, view in enumerate(prev_views):
+         pred_prev_sample[i] = view[:-1]
+     masks = [view[-1:] for view in prev_views]
+
+     return {"prev_sample": pred_prev_sample, "pred_original_sample": pred_original_sample, "prev_tex": prev_tex}
+
+     if not return_dict:
+         return pred_prev_sample, pred_original_sample
+     pass
+
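
Since step_tex_sync mirrors the signature of a DDPM scheduler step but carries the denoising state on the baked UV texture, a schematic and deliberately simplified sketch of where it sits in a denoising loop is shown below; unet_forward, latents and uvp are placeholders, not the pipeline's exact wiring.

# Schematic sketch only: step_tex_sync takes the place of the usual scheduler.step call.
texture = None
for t in scheduler.timesteps:
    noise_pred = unet_forward(latents, t)     # placeholder for the guided multi-view UNet prediction
    out = step_tex_sync(scheduler, uvp, noise_pred, t, latents, texture=texture)
    latents, texture = out["prev_sample"], out["prev_tex"]
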
step1x3d_texture/{renderer → texture_sync}/voronoi.py RENAMED
File without changes