cpu
Browse files- configs/demo.yaml +1 -0
- util.py +7 -6
configs/demo.yaml
CHANGED
|
@@ -24,6 +24,7 @@ dual_conditioner: False
|
|
| 24 |
steps: 50
|
| 25 |
init_step: 0
|
| 26 |
num_workers: 0
|
|
|
|
| 27 |
gpu: 0
|
| 28 |
max_iter: 100
|
| 29 |
|
|
|
|
| 24 |
steps: 50
|
| 25 |
init_step: 0
|
| 26 |
num_workers: 0
|
| 27 |
+
use_gpu: False
|
| 28 |
gpu: 0
|
| 29 |
max_iter: 100
|
| 30 |
|
util.py
CHANGED
|
@@ -32,18 +32,19 @@ SD_XL_BASE_RATIOS = {
|
|
| 32 |
"3.0": (1728, 576),
|
| 33 |
}
|
| 34 |
|
| 35 |
-
def init_model(
|
| 36 |
|
| 37 |
-
model_cfg = OmegaConf.load(
|
| 38 |
-
ckpt =
|
| 39 |
|
| 40 |
model = instantiate_from_config(model_cfg.model)
|
| 41 |
model.init_from_ckpt(ckpt)
|
| 42 |
|
| 43 |
-
if
|
| 44 |
model.train()
|
| 45 |
else:
|
| 46 |
-
|
|
|
|
| 47 |
model.eval()
|
| 48 |
model.freeze()
|
| 49 |
|
|
@@ -108,7 +109,7 @@ def deep_copy(batch):
|
|
| 108 |
def prepare_batch(cfgs, batch):
|
| 109 |
|
| 110 |
for key in batch:
|
| 111 |
-
if isinstance(batch[key], torch.Tensor):
|
| 112 |
batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
|
| 113 |
|
| 114 |
if not cfgs.dual_conditioner:
|
|
|
|
| 32 |
"3.0": (1728, 576),
|
| 33 |
}
|
| 34 |
|
| 35 |
+
def init_model(cfgs):
|
| 36 |
|
| 37 |
+
model_cfg = OmegaConf.load(cfgs.model_cfg_path)
|
| 38 |
+
ckpt = cfgs.load_ckpt_path
|
| 39 |
|
| 40 |
model = instantiate_from_config(model_cfg.model)
|
| 41 |
model.init_from_ckpt(ckpt)
|
| 42 |
|
| 43 |
+
if cfgs.type == "train":
|
| 44 |
model.train()
|
| 45 |
else:
|
| 46 |
+
if cfgs.use_gpu:
|
| 47 |
+
model.to(torch.device("cuda", index=cfgs.gpu))
|
| 48 |
model.eval()
|
| 49 |
model.freeze()
|
| 50 |
|
|
|
|
| 109 |
def prepare_batch(cfgs, batch):
|
| 110 |
|
| 111 |
for key in batch:
|
| 112 |
+
if isinstance(batch[key], torch.Tensor) and cfgs.use_gpu:
|
| 113 |
batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
|
| 114 |
|
| 115 |
if not cfgs.dual_conditioner:
|