Update
Browse files
model.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import gc
|
| 2 |
import tempfile
|
| 3 |
|
| 4 |
import numpy as np
|
|
@@ -70,17 +69,15 @@ class Model:
|
|
| 70 |
'cuda' if torch.cuda.is_available() else 'cpu')
|
| 71 |
self.xm = load_model('transmitter', device=self.device)
|
| 72 |
self.diffusion = diffusion_from_config(load_config('diffusion'))
|
| 73 |
-
self.
|
| 74 |
-
self.
|
| 75 |
|
| 76 |
def load_model(self, model_name: str) -> None:
|
| 77 |
assert model_name in ['text300M', 'image300M']
|
| 78 |
-
if model_name == self.
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
gc.collect()
|
| 83 |
-
torch.cuda.empty_cache()
|
| 84 |
|
| 85 |
def to_glb(self, latent: torch.Tensor) -> str:
|
| 86 |
ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
|
|
@@ -109,7 +106,7 @@ class Model:
|
|
| 109 |
|
| 110 |
latents = sample_latents(
|
| 111 |
batch_size=1,
|
| 112 |
-
model=self.
|
| 113 |
diffusion=self.diffusion,
|
| 114 |
guidance_scale=guidance_scale,
|
| 115 |
model_kwargs=dict(texts=[prompt]),
|
|
@@ -135,7 +132,7 @@ class Model:
|
|
| 135 |
image = load_image(image_path)
|
| 136 |
latents = sample_latents(
|
| 137 |
batch_size=1,
|
| 138 |
-
model=self.
|
| 139 |
diffusion=self.diffusion,
|
| 140 |
guidance_scale=guidance_scale,
|
| 141 |
model_kwargs=dict(images=[image]),
|
|
|
|
|
|
|
| 1 |
import tempfile
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 69 |
'cuda' if torch.cuda.is_available() else 'cpu')
|
| 70 |
self.xm = load_model('transmitter', device=self.device)
|
| 71 |
self.diffusion = diffusion_from_config(load_config('diffusion'))
|
| 72 |
+
self.model_text = None
|
| 73 |
+
self.model_image = None
|
| 74 |
|
| 75 |
def load_model(self, model_name: str) -> None:
|
| 76 |
assert model_name in ['text300M', 'image300M']
|
| 77 |
+
if model_name == 'text300M' and self.model_text is None:
|
| 78 |
+
self.model_text = load_model(model_name, device=self.device)
|
| 79 |
+
elif model_name == 'image300M' and self.model_image is None:
|
| 80 |
+
self.model_image = load_model(model_name, device=self.device)
|
|
|
|
|
|
|
| 81 |
|
| 82 |
def to_glb(self, latent: torch.Tensor) -> str:
|
| 83 |
ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
|
|
|
|
| 106 |
|
| 107 |
latents = sample_latents(
|
| 108 |
batch_size=1,
|
| 109 |
+
model=self.model_text,
|
| 110 |
diffusion=self.diffusion,
|
| 111 |
guidance_scale=guidance_scale,
|
| 112 |
model_kwargs=dict(texts=[prompt]),
|
|
|
|
| 132 |
image = load_image(image_path)
|
| 133 |
latents = sample_latents(
|
| 134 |
batch_size=1,
|
| 135 |
+
model=self.model_image,
|
| 136 |
diffusion=self.diffusion,
|
| 137 |
guidance_scale=guidance_scale,
|
| 138 |
model_kwargs=dict(images=[image]),
|