yejunliang23 commited on
Commit
5f74b1f
·
verified ·
1 Parent(s): 513ff13

Update trellis/pipelines/trellis_text_to_3d.py

Browse files
trellis/pipelines/trellis_text_to_3d.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import torch.nn as nn
4
  import numpy as np
5
  from transformers import CLIPTextModel, AutoTokenizer
6
- import open3d as o3d
7
  from .base import Pipeline
8
  from . import samplers
9
  from ..modules import sparse as sp
@@ -225,16 +225,9 @@ class TrellisTextTo3DPipeline(Pipeline):
225
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
226
  slat = self.sample_slat(cond, coords, slat_sampler_params)
227
  return self.decode_slat(slat, formats)
228
-
229
  def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
230
- """
231
- Voxelize a mesh.
232
 
233
- Args:
234
- mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.
235
- sha256 (str): The SHA256 hash of the mesh.
236
- output_dir (str): The output directory.
237
- """
238
  vertices = np.asarray(mesh.vertices)
239
  aabb = np.stack([vertices.min(0), vertices.max(0)])
240
  center = (aabb[0] + aabb[1]) / 2
@@ -256,17 +249,7 @@ class TrellisTextTo3DPipeline(Pipeline):
256
  slat_sampler_params: dict = {},
257
  formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
258
  ) -> dict:
259
- """
260
- Run the pipeline for making variants of an asset.
261
 
262
- Args:
263
- mesh (o3d.geometry.TriangleMesh): The base mesh.
264
- prompt (str): The text prompt.
265
- num_samples (int): The number of samples to generate.
266
- seed (int): The random seed
267
- slat_sampler_params (dict): Additional parameters for the structured latent sampler.
268
- formats (List[str]): The formats to decode the structured latent to.
269
- """
270
  cond = self.get_cond([prompt])
271
  coords = self.voxelize(mesh)
272
  coords = torch.cat([
@@ -276,3 +259,4 @@ class TrellisTextTo3DPipeline(Pipeline):
276
  torch.manual_seed(seed)
277
  slat = self.sample_slat(cond, coords, slat_sampler_params)
278
  return self.decode_slat(slat, formats)
 
 
3
  import torch.nn as nn
4
  import numpy as np
5
  from transformers import CLIPTextModel, AutoTokenizer
6
+ #import open3d as o3d
7
  from .base import Pipeline
8
  from . import samplers
9
  from ..modules import sparse as sp
 
225
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
226
  slat = self.sample_slat(cond, coords, slat_sampler_params)
227
  return self.decode_slat(slat, formats)
228
+ """
229
  def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
 
 
230
 
 
 
 
 
 
231
  vertices = np.asarray(mesh.vertices)
232
  aabb = np.stack([vertices.min(0), vertices.max(0)])
233
  center = (aabb[0] + aabb[1]) / 2
 
249
  slat_sampler_params: dict = {},
250
  formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
251
  ) -> dict:
 
 
252
 
 
 
 
 
 
 
 
 
253
  cond = self.get_cond([prompt])
254
  coords = self.voxelize(mesh)
255
  coords = torch.cat([
 
259
  torch.manual_seed(seed)
260
  slat = self.sample_slat(cond, coords, slat_sampler_params)
261
  return self.decode_slat(slat, formats)
262
+ """