Spaces:
Running
on
Zero
Running
on
Zero
Update trellis/pipelines/trellis_text_to_3d.py
Browse files
trellis/pipelines/trellis_text_to_3d.py
CHANGED
@@ -3,7 +3,7 @@ import torch
|
|
3 |
import torch.nn as nn
|
4 |
import numpy as np
|
5 |
from transformers import CLIPTextModel, AutoTokenizer
|
6 |
-
import open3d as o3d
|
7 |
from .base import Pipeline
|
8 |
from . import samplers
|
9 |
from ..modules import sparse as sp
|
@@ -225,16 +225,9 @@ class TrellisTextTo3DPipeline(Pipeline):
|
|
225 |
coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
|
226 |
slat = self.sample_slat(cond, coords, slat_sampler_params)
|
227 |
return self.decode_slat(slat, formats)
|
228 |
-
|
229 |
def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
|
230 |
-
"""
|
231 |
-
Voxelize a mesh.
|
232 |
|
233 |
-
Args:
|
234 |
-
mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.
|
235 |
-
sha256 (str): The SHA256 hash of the mesh.
|
236 |
-
output_dir (str): The output directory.
|
237 |
-
"""
|
238 |
vertices = np.asarray(mesh.vertices)
|
239 |
aabb = np.stack([vertices.min(0), vertices.max(0)])
|
240 |
center = (aabb[0] + aabb[1]) / 2
|
@@ -256,17 +249,7 @@ class TrellisTextTo3DPipeline(Pipeline):
|
|
256 |
slat_sampler_params: dict = {},
|
257 |
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
|
258 |
) -> dict:
|
259 |
-
"""
|
260 |
-
Run the pipeline for making variants of an asset.
|
261 |
|
262 |
-
Args:
|
263 |
-
mesh (o3d.geometry.TriangleMesh): The base mesh.
|
264 |
-
prompt (str): The text prompt.
|
265 |
-
num_samples (int): The number of samples to generate.
|
266 |
-
seed (int): The random seed
|
267 |
-
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
|
268 |
-
formats (List[str]): The formats to decode the structured latent to.
|
269 |
-
"""
|
270 |
cond = self.get_cond([prompt])
|
271 |
coords = self.voxelize(mesh)
|
272 |
coords = torch.cat([
|
@@ -276,3 +259,4 @@ class TrellisTextTo3DPipeline(Pipeline):
|
|
276 |
torch.manual_seed(seed)
|
277 |
slat = self.sample_slat(cond, coords, slat_sampler_params)
|
278 |
return self.decode_slat(slat, formats)
|
|
|
|
3 |
import torch.nn as nn
|
4 |
import numpy as np
|
5 |
from transformers import CLIPTextModel, AutoTokenizer
|
6 |
+
#import open3d as o3d
|
7 |
from .base import Pipeline
|
8 |
from . import samplers
|
9 |
from ..modules import sparse as sp
|
|
|
225 |
coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
|
226 |
slat = self.sample_slat(cond, coords, slat_sampler_params)
|
227 |
return self.decode_slat(slat, formats)
|
228 |
+
"""
|
229 |
def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
|
|
|
|
|
230 |
|
|
|
|
|
|
|
|
|
|
|
231 |
vertices = np.asarray(mesh.vertices)
|
232 |
aabb = np.stack([vertices.min(0), vertices.max(0)])
|
233 |
center = (aabb[0] + aabb[1]) / 2
|
|
|
249 |
slat_sampler_params: dict = {},
|
250 |
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
|
251 |
) -> dict:
|
|
|
|
|
252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
cond = self.get_cond([prompt])
|
254 |
coords = self.voxelize(mesh)
|
255 |
coords = torch.cat([
|
|
|
259 |
torch.manual_seed(seed)
|
260 |
slat = self.sample_slat(cond, coords, slat_sampler_params)
|
261 |
return self.decode_slat(slat, formats)
|
262 |
+
"""
|