Spaces:
mashroo
/
Runtime error

YoussefAnso commited on
Commit
81d5d11
·
1 Parent(s): 71312a3

Refactor gen_image function in app.py to return base64 encoded GLB data instead of a file path. Update mesh export in inference.py to return .obj file extension. Modify Mesh class to streamline GLB export and enhance texture handling, ensuring proper integration of vertex attributes.

Browse files
Files changed (4) hide show
  1. app.py +8 -13
  2. inference.py +2 -2
  3. mesh.py +120 -128
  4. model/crm/model.py +7 -18
app.py CHANGED
@@ -96,6 +96,7 @@ def preprocess_image(image, background_choice, foreground_ratio, backgroud_color
96
 
97
 
98
  @spaces.GPU
 
99
  def gen_image(input_image, seed, scale, step):
100
  global pipeline, model, args
101
  pipeline.set_seed(seed)
@@ -107,19 +108,13 @@ def gen_image(input_image, seed, scale, step):
107
 
108
  glb_path = generate3d(model, np_imgs, np_xyzs, args.device)
109
 
110
- # Create a temporary file with a proper name for the GLB data
111
- import tempfile
112
- import shutil
113
-
114
- # Create a temporary file with a proper extension
115
- temp_glb = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
116
- temp_glb.close()
117
-
118
- # Copy the generated GLB file to our temporary file
119
- shutil.copy2(glb_path, temp_glb.name)
120
 
121
- # Return images and the path to the temporary GLB file
122
- return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), temp_glb.name
123
 
124
 
125
  parser = argparse.ArgumentParser()
@@ -220,7 +215,7 @@ with gr.Blocks() as demo:
220
  xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
221
 
222
  output_model = gr.Model3D(
223
- label="Output 3D Model (GLB)",
224
  interactive=False,
225
  )
226
  gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
 
96
 
97
 
98
  @spaces.GPU
99
+
100
  def gen_image(input_image, seed, scale, step):
101
  global pipeline, model, args
102
  pipeline.set_seed(seed)
 
108
 
109
  glb_path = generate3d(model, np_imgs, np_xyzs, args.device)
110
 
111
+ # Read the GLB file and encode it in base64
112
+ with open(glb_path, 'rb') as f:
113
+ glb_bytes = f.read()
114
+ encoded_glb = 'data:model/gltf-binary;base64,' + base64.b64encode(glb_bytes).decode('utf-8')
 
 
 
 
 
 
115
 
116
+ # Return images and the encoded GLB data
117
+ return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), encoded_glb
118
 
119
 
120
  parser = argparse.ArgumentParser()
 
215
  xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
216
 
217
  output_model = gr.Model3D(
218
+ label="Output OBJ",
219
  interactive=False,
220
  )
221
  gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
inference.py CHANGED
@@ -73,7 +73,7 @@ def generate3d(model, rgb, ccm, device):
73
 
74
  start_time = time.time()
75
  with torch.no_grad():
76
- mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False).name
77
  model.export_mesh(data_config, mesh_path_glb, tri_fea_2 = triplane_feature2)
78
 
79
  # glctx = dr.RasterizeGLContext()#dr.RasterizeCudaContext()
@@ -96,4 +96,4 @@ def generate3d(model, rgb, ccm, device):
96
  end_time = time.time()
97
  elapsed_time = end_time - start_time
98
  print(f"uv takes {elapsed_time}s")
99
- return mesh_path_glb
 
73
 
74
  start_time = time.time()
75
  with torch.no_grad():
76
+ mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
77
  model.export_mesh(data_config, mesh_path_glb, tri_fea_2 = triplane_feature2)
78
 
79
  # glctx = dr.RasterizeGLContext()#dr.RasterizeCudaContext()
 
96
  end_time = time.time()
97
  elapsed_time = end_time - start_time
98
  print(f"uv takes {elapsed_time}s")
99
+ return mesh_path_glb+".obj"
mesh.py CHANGED
@@ -10,7 +10,6 @@ from kiui.typing import *
10
  class Mesh:
11
  """
12
  A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
13
-
14
  Note:
15
  This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
16
  """
@@ -28,7 +27,6 @@ class Mesh:
28
  device: Optional[torch.device] = None,
29
  ):
30
  """Init a mesh directly using all attributes.
31
-
32
  Args:
33
  v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
34
  f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
@@ -62,7 +60,6 @@ class Mesh:
62
  @classmethod
63
  def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
64
  """load mesh from path.
65
-
66
  Args:
67
  path (str): path to mesh file, supports ply, obj, glb.
68
  clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
@@ -76,7 +73,6 @@ class Mesh:
76
  Note:
77
  a ``device`` keyword argument can be provided to specify the torch device.
78
  If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
79
-
80
  Returns:
81
  Mesh: the loaded Mesh object.
82
  """
@@ -140,7 +136,6 @@ class Mesh:
140
  @classmethod
141
  def load_obj(cls, path, albedo_path=None, device=None):
142
  """load an ``obj`` mesh.
143
-
144
  Args:
145
  path (str): path to mesh.
146
  albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
@@ -149,7 +144,6 @@ class Mesh:
149
  Note:
150
  We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
151
  The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
152
-
153
  Returns:
154
  Mesh: the loaded Mesh object.
155
  """
@@ -313,17 +307,13 @@ class Mesh:
313
  @classmethod
314
  def load_trimesh(cls, path, device=None):
315
  """load a mesh using ``trimesh.load()``.
316
-
317
  Can load various formats like ``glb`` and serves as a fallback.
318
-
319
  Note:
320
  We will try to merge all meshes if the glb contains more than one,
321
  but **this may cause the texture to lose**, since we only support one texture image!
322
-
323
  Args:
324
  path (str): path to the mesh file.
325
  device (torch.device, optional): torch device. Defaults to None.
326
-
327
  Returns:
328
  Mesh: the loaded Mesh object.
329
  """
@@ -423,10 +413,8 @@ class Mesh:
423
  # sample surface (using trimesh)
424
  def sample_surface(self, count: int):
425
  """sample points on the surface of the mesh.
426
-
427
  Args:
428
  count (int): number of points to sample.
429
-
430
  Returns:
431
  torch.Tensor: the sampled points, float [count, 3].
432
  """
@@ -438,7 +426,6 @@ class Mesh:
438
  # aabb
439
  def aabb(self):
440
  """get the axis-aligned bounding box of the mesh.
441
-
442
  Returns:
443
  Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
444
  """
@@ -448,7 +435,6 @@ class Mesh:
448
  @torch.no_grad()
449
  def auto_size(self, bound=0.9):
450
  """auto resize the mesh.
451
-
452
  Args:
453
  bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
454
  """
@@ -484,7 +470,6 @@ class Mesh:
484
 
485
  def auto_uv(self, cache_path=None, vmap=True):
486
  """auto calculate the uv coordinates.
487
-
488
  Args:
489
  cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
490
  vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
@@ -523,7 +508,6 @@ class Mesh:
523
 
524
  def align_v_to_vt(self, vmapping=None):
525
  """ remap v/f and vn/fn to vt/ft.
526
-
527
  Args:
528
  vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
529
  """
@@ -542,10 +526,8 @@ class Mesh:
542
 
543
  def to(self, device):
544
  """move all tensor attributes to device.
545
-
546
  Args:
547
  device (torch.device): target device.
548
-
549
  Returns:
550
  Mesh: self.
551
  """
@@ -558,7 +540,6 @@ class Mesh:
558
 
559
  def write(self, path):
560
  """write the mesh to a path.
561
-
562
  Args:
563
  path (str): path to write, supports ply, obj and glb.
564
  """
@@ -573,7 +554,6 @@ class Mesh:
573
 
574
  def write_ply(self, path):
575
  """write the mesh in ply format. Only for geometry!
576
-
577
  Args:
578
  path (str): path to write.
579
  """
@@ -591,16 +571,16 @@ class Mesh:
591
  def write_glb(self, path):
592
  """write the mesh in glb/gltf format.
593
  This will create a scene with a single mesh.
594
-
595
  Args:
596
  path (str): path to write.
597
  """
598
- import pygltflib
599
 
600
  # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
601
  if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
602
  self.align_v_to_vt()
603
 
 
 
604
  f_np = self.f.detach().cpu().numpy().astype(np.uint32)
605
  f_np_blob = f_np.flatten().tobytes()
606
 
@@ -610,104 +590,79 @@ class Mesh:
610
  blob = f_np_blob + v_np_blob
611
  byteOffset = len(blob)
612
 
613
- # Create attributes dictionary
614
- attributes = pygltflib.Attributes(POSITION=1)
615
- accessor_count = 2 # Start after position (0) and indices (1)
616
-
617
- # Initialize buffer views list
618
- buffer_views = [
619
- # triangles; as flatten (element) array
620
- pygltflib.BufferView(
621
- buffer=0,
622
- byteLength=len(f_np_blob),
623
- target=pygltflib.ELEMENT_ARRAY_BUFFER,
624
- ),
625
- # positions; as vec3 array
626
- pygltflib.BufferView(
627
- buffer=0,
628
- byteOffset=len(f_np_blob),
629
- byteLength=len(v_np_blob),
630
- byteStride=12, # vec3
631
- target=pygltflib.ARRAY_BUFFER,
632
- ),
633
- ]
634
-
635
- # Initialize accessors list
636
- accessors = [
637
- # 0 = triangles
638
- pygltflib.Accessor(
639
- bufferView=0,
640
- componentType=pygltflib.UNSIGNED_INT,
641
- count=f_np.size,
642
- type=pygltflib.SCALAR,
643
- max=[int(f_np.max())],
644
- min=[int(f_np.min())],
645
- ),
646
- # 1 = positions
647
- pygltflib.Accessor(
648
- bufferView=1,
649
- componentType=pygltflib.FLOAT,
650
- count=len(v_np),
651
- type=pygltflib.VEC3,
652
- max=v_np.max(axis=0).tolist(),
653
- min=v_np.min(axis=0).tolist(),
654
- ),
655
- ]
656
-
657
- # Add vertex colors if they exist
658
- if self.vc is not None:
659
- vc_np = self.vc.detach().cpu().numpy().astype(np.float32)
660
- vc_np_blob = vc_np.tobytes()
661
-
662
- # Add vertex color buffer view
663
- buffer_views.append(
664
- pygltflib.BufferView(
665
- buffer=0,
666
- byteOffset=byteOffset,
667
- byteLength=len(vc_np_blob),
668
- byteStride=12, # vec3
669
- target=pygltflib.ARRAY_BUFFER,
670
- )
671
- )
672
-
673
- # Add vertex color accessor
674
- accessors.append(
675
- pygltflib.Accessor(
676
- bufferView=accessor_count,
677
- componentType=pygltflib.FLOAT,
678
- count=len(vc_np),
679
- type=pygltflib.VEC3,
680
- max=vc_np.max(axis=0).tolist(),
681
- min=vc_np.min(axis=0).tolist(),
682
- )
683
- )
684
-
685
- # Add to attributes
686
- attributes.COLOR_0 = accessor_count
687
- accessor_count += 1
688
-
689
- blob += vc_np_blob
690
- byteOffset += len(vc_np_blob)
691
-
692
- # Create the GLTF object with all components
693
  gltf = pygltflib.GLTF2(
694
  scene=0,
695
  scenes=[pygltflib.Scene(nodes=[0])],
696
  nodes=[pygltflib.Node(mesh=0)],
697
  meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
698
- attributes=attributes,
 
 
 
699
  indices=0,
700
  )])],
701
- buffers=[pygltflib.Buffer(byteLength=byteOffset)],
702
- bufferViews=buffer_views,
703
- accessors=accessors,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
704
  )
705
 
706
- # Add material for vertex colors
707
- if self.vc is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
708
  gltf.materials.append(pygltflib.Material(
709
  pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
710
- baseColorFactor=[1.0, 1.0, 1.0, 1.0],
711
  metallicFactor=0.0,
712
  roughnessFactor=1.0,
713
  ),
@@ -715,28 +670,27 @@ class Mesh:
715
  alphaCutoff=None,
716
  doubleSided=True,
717
  ))
718
- gltf.meshes[0].primitives[0].material = 0
719
 
720
- # Handle textures if they exist
721
- if self.vt is not None:
722
- vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
723
- vt_np_blob = vt_np.tobytes()
724
 
725
- # Add texture coordinates buffer view
726
  gltf.bufferViews.append(
 
727
  pygltflib.BufferView(
728
  buffer=0,
729
  byteOffset=byteOffset,
730
  byteLength=len(vt_np_blob),
731
- byteStride=8, # vec2
732
  target=pygltflib.ARRAY_BUFFER,
733
  )
734
  )
735
 
736
- # Add texture coordinates accessor
737
  gltf.accessors.append(
 
738
  pygltflib.Accessor(
739
- bufferView=len(gltf.bufferViews) - 1,
740
  componentType=pygltflib.FLOAT,
741
  count=len(vt_np),
742
  type=pygltflib.VEC2,
@@ -745,25 +699,64 @@ class Mesh:
745
  )
746
  )
747
 
748
- # Add texture coordinates to attributes
749
- gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = len(gltf.accessors) - 1
750
-
751
- blob += vt_np_blob
752
  byteOffset += len(vt_np_blob)
753
 
754
- # Update buffer size
 
 
 
 
 
 
 
 
 
 
 
755
  gltf.buffers[0].byteLength = byteOffset
756
 
757
- # Set the binary blob
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
  gltf.set_binary_blob(blob)
759
 
760
- # Save the GLB file
761
  gltf.save(path)
762
 
763
 
764
  def write_obj(self, path):
765
  """write the mesh in obj format. Will also write the texture and mtl files.
766
-
767
  Args:
768
  path (str): path to write.
769
  """
@@ -826,5 +819,4 @@ class Mesh:
826
  metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
827
  metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
828
  cv2.imwrite(metallic_path, metallicRoughness[..., 2])
829
- cv2.imwrite(roughness_path, metallicRoughness[..., 1])
830
-
 
10
  class Mesh:
11
  """
12
  A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
 
13
  Note:
14
  This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
15
  """
 
27
  device: Optional[torch.device] = None,
28
  ):
29
  """Init a mesh directly using all attributes.
 
30
  Args:
31
  v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
32
  f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
 
60
  @classmethod
61
  def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
62
  """load mesh from path.
 
63
  Args:
64
  path (str): path to mesh file, supports ply, obj, glb.
65
  clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
 
73
  Note:
74
  a ``device`` keyword argument can be provided to specify the torch device.
75
  If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
 
76
  Returns:
77
  Mesh: the loaded Mesh object.
78
  """
 
136
  @classmethod
137
  def load_obj(cls, path, albedo_path=None, device=None):
138
  """load an ``obj`` mesh.
 
139
  Args:
140
  path (str): path to mesh.
141
  albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
 
144
  Note:
145
  We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
146
  The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
 
147
  Returns:
148
  Mesh: the loaded Mesh object.
149
  """
 
307
  @classmethod
308
  def load_trimesh(cls, path, device=None):
309
  """load a mesh using ``trimesh.load()``.
 
310
  Can load various formats like ``glb`` and serves as a fallback.
 
311
  Note:
312
  We will try to merge all meshes if the glb contains more than one,
313
  but **this may cause the texture to lose**, since we only support one texture image!
 
314
  Args:
315
  path (str): path to the mesh file.
316
  device (torch.device, optional): torch device. Defaults to None.
 
317
  Returns:
318
  Mesh: the loaded Mesh object.
319
  """
 
413
  # sample surface (using trimesh)
414
  def sample_surface(self, count: int):
415
  """sample points on the surface of the mesh.
 
416
  Args:
417
  count (int): number of points to sample.
 
418
  Returns:
419
  torch.Tensor: the sampled points, float [count, 3].
420
  """
 
426
  # aabb
427
  def aabb(self):
428
  """get the axis-aligned bounding box of the mesh.
 
429
  Returns:
430
  Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
431
  """
 
435
  @torch.no_grad()
436
  def auto_size(self, bound=0.9):
437
  """auto resize the mesh.
 
438
  Args:
439
  bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
440
  """
 
470
 
471
  def auto_uv(self, cache_path=None, vmap=True):
472
  """auto calculate the uv coordinates.
 
473
  Args:
474
  cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
475
  vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
 
508
 
509
  def align_v_to_vt(self, vmapping=None):
510
  """ remap v/f and vn/fn to vt/ft.
 
511
  Args:
512
  vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
513
  """
 
526
 
527
  def to(self, device):
528
  """move all tensor attributes to device.
 
529
  Args:
530
  device (torch.device): target device.
 
531
  Returns:
532
  Mesh: self.
533
  """
 
540
 
541
  def write(self, path):
542
  """write the mesh to a path.
 
543
  Args:
544
  path (str): path to write, supports ply, obj and glb.
545
  """
 
554
 
555
  def write_ply(self, path):
556
  """write the mesh in ply format. Only for geometry!
 
557
  Args:
558
  path (str): path to write.
559
  """
 
571
  def write_glb(self, path):
572
  """write the mesh in glb/gltf format.
573
  This will create a scene with a single mesh.
 
574
  Args:
575
  path (str): path to write.
576
  """
 
577
 
578
  # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
579
  if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
580
  self.align_v_to_vt()
581
 
582
+ import pygltflib
583
+
584
  f_np = self.f.detach().cpu().numpy().astype(np.uint32)
585
  f_np_blob = f_np.flatten().tobytes()
586
 
 
590
  blob = f_np_blob + v_np_blob
591
  byteOffset = len(blob)
592
 
593
+ # base mesh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  gltf = pygltflib.GLTF2(
595
  scene=0,
596
  scenes=[pygltflib.Scene(nodes=[0])],
597
  nodes=[pygltflib.Node(mesh=0)],
598
  meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
599
+ # indices to accessors (0 is triangles)
600
+ attributes=pygltflib.Attributes(
601
+ POSITION=1,
602
+ ),
603
  indices=0,
604
  )])],
605
+ buffers=[
606
+ pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
607
+ ],
608
+ # buffer view (based on dtype)
609
+ bufferViews=[
610
+ # triangles; as flatten (element) array
611
+ pygltflib.BufferView(
612
+ buffer=0,
613
+ byteLength=len(f_np_blob),
614
+ target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
615
+ ),
616
+ # positions; as vec3 array
617
+ pygltflib.BufferView(
618
+ buffer=0,
619
+ byteOffset=len(f_np_blob),
620
+ byteLength=len(v_np_blob),
621
+ byteStride=12, # vec3
622
+ target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
623
+ ),
624
+ ],
625
+ accessors=[
626
+ # 0 = triangles
627
+ pygltflib.Accessor(
628
+ bufferView=0,
629
+ componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
630
+ count=f_np.size,
631
+ type=pygltflib.SCALAR,
632
+ max=[int(f_np.max())],
633
+ min=[int(f_np.min())],
634
+ ),
635
+ # 1 = positions
636
+ pygltflib.Accessor(
637
+ bufferView=1,
638
+ componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
639
+ count=len(v_np),
640
+ type=pygltflib.VEC3,
641
+ max=v_np.max(axis=0).tolist(),
642
+ min=v_np.min(axis=0).tolist(),
643
+ ),
644
+ ],
645
  )
646
 
647
+ # append texture info
648
+ if self.vt is not None:
649
+
650
+ vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
651
+ vt_np_blob = vt_np.tobytes()
652
+
653
+ albedo = self.albedo.detach().cpu().numpy()
654
+ albedo = (albedo * 255).astype(np.uint8)
655
+ albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
656
+ albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
657
+
658
+ # update primitive
659
+ gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
660
+ gltf.meshes[0].primitives[0].material = 0
661
+
662
+ # update materials
663
  gltf.materials.append(pygltflib.Material(
664
  pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
665
+ baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
666
  metallicFactor=0.0,
667
  roughnessFactor=1.0,
668
  ),
 
670
  alphaCutoff=None,
671
  doubleSided=True,
672
  ))
 
673
 
674
+ gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
675
+ gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
676
+ gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))
 
677
 
678
+ # update buffers
679
  gltf.bufferViews.append(
680
+ # index = 2, texcoords; as vec2 array
681
  pygltflib.BufferView(
682
  buffer=0,
683
  byteOffset=byteOffset,
684
  byteLength=len(vt_np_blob),
685
+ byteStride=8, # vec2
686
  target=pygltflib.ARRAY_BUFFER,
687
  )
688
  )
689
 
 
690
  gltf.accessors.append(
691
+ # 2 = texcoords
692
  pygltflib.Accessor(
693
+ bufferView=2,
694
  componentType=pygltflib.FLOAT,
695
  count=len(vt_np),
696
  type=pygltflib.VEC2,
 
699
  )
700
  )
701
 
702
+ blob += vt_np_blob
 
 
 
703
  byteOffset += len(vt_np_blob)
704
 
705
+ gltf.bufferViews.append(
706
+ # index = 3, albedo texture; as none target
707
+ pygltflib.BufferView(
708
+ buffer=0,
709
+ byteOffset=byteOffset,
710
+ byteLength=len(albedo_blob),
711
+ )
712
+ )
713
+
714
+ blob += albedo_blob
715
+ byteOffset += len(albedo_blob)
716
+
717
  gltf.buffers[0].byteLength = byteOffset
718
 
719
+ # append metllic roughness
720
+ if self.metallicRoughness is not None:
721
+ metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
722
+ metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
723
+ metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
724
+ metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
725
+
726
+ # update texture definition
727
+ gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
728
+ gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
729
+ gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
730
+
731
+ gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
732
+ gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
733
+ gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
734
+
735
+ # update buffers
736
+ gltf.bufferViews.append(
737
+ # index = 4, metallicRoughness texture; as none target
738
+ pygltflib.BufferView(
739
+ buffer=0,
740
+ byteOffset=byteOffset,
741
+ byteLength=len(metallicRoughness_blob),
742
+ )
743
+ )
744
+
745
+ blob += metallicRoughness_blob
746
+ byteOffset += len(metallicRoughness_blob)
747
+
748
+ gltf.buffers[0].byteLength = byteOffset
749
+
750
+
751
+ # set actual data
752
  gltf.set_binary_blob(blob)
753
 
754
+ # glb = b"".join(gltf.save_to_bytes())
755
  gltf.save(path)
756
 
757
 
758
  def write_obj(self, path):
759
  """write the mesh in obj format. Will also write the texture and mtl files.
 
760
  Args:
761
  path (str): path to write.
762
  """
 
819
  metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
820
  metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
821
  cv2.imwrite(metallic_path, metallicRoughness[..., 2])
822
+ cv2.imwrite(roughness_path, metallicRoughness[..., 1])
 
model/crm/model.py CHANGED
@@ -98,28 +98,17 @@ class CRM(nn.Module):
98
  # Expect predicted colors value range from [-1, 1]
99
  colors = (colors * 0.5 + 0.5).clip(0, 1)
100
 
101
- # Transform vertices to match GLB coordinate system
102
- # GLB uses right-handed coordinate system with Y up
103
- verts = verts[..., [0, 2, 1]] # Swap Y and Z to get Y up
104
- verts[..., 0] *= -1 # Flip X to get right-handed
105
  verts = verts.squeeze().cpu().numpy()
106
-
107
- # Transform faces to maintain correct winding order
108
- faces = faces[..., [2, 1, 0]] # Reverse winding order
109
- faces = faces.squeeze().cpu().numpy()
110
 
111
  # export the final mesh
112
  with torch.no_grad():
113
- # Create a Mesh object with the data
114
- from mesh import Mesh
115
- mesh = Mesh(
116
- v=torch.from_numpy(verts).float(),
117
- f=torch.from_numpy(faces).int(),
118
- vc=torch.from_numpy(colors).float(),
119
- device='cpu'
120
- )
121
- # Write as GLB
122
- mesh.write(out_dir)
123
 
124
  def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):
125
 
 
98
  # Expect predicted colors value range from [-1, 1]
99
  colors = (colors * 0.5 + 0.5).clip(0, 1)
100
 
101
+ verts = verts[..., [0, 2, 1]]
102
+ verts[..., 0]*= -1
103
+ verts[..., 2]*= -1
 
104
  verts = verts.squeeze().cpu().numpy()
105
+ faces = faces[..., [2, 1, 0]][..., [0, 2, 1]]#[..., [1, 0, 2]]
106
+ faces = faces.squeeze().cpu().numpy()#faces[..., [2, 1, 0]].squeeze().cpu().numpy()
 
 
107
 
108
  # export the final mesh
109
  with torch.no_grad():
110
+ mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault...
111
+ mesh.export(f'{out_dir}.obj')
 
 
 
 
 
 
 
 
112
 
113
  def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):
114