Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
·
b195df1
1
Parent(s):
ace1938
update
Browse files- test_ddgan.py +3 -0
test_ddgan.py
CHANGED
|
@@ -390,9 +390,12 @@ class ObjectFromDict:
|
|
| 390 |
def load_model(config, path, device="cpu"):
|
| 391 |
config = ObjectFromDict(config)
|
| 392 |
text_encoder = build_encoder(name=config.text_encoder, masked_mean=config.masked_mean)
|
|
|
|
| 393 |
config.cond_size = text_encoder.output_size
|
| 394 |
netG = NCSNpp(config)
|
|
|
|
| 395 |
ckpt = torch.load(path, map_location="cpu")
|
|
|
|
| 396 |
for key in list(ckpt.keys()):
|
| 397 |
if key.startswith("module"):
|
| 398 |
ckpt[key[7:]] = ckpt.pop(key)
|
|
|
|
| 390 |
def load_model(config, path, device="cpu"):
|
| 391 |
config = ObjectFromDict(config)
|
| 392 |
text_encoder = build_encoder(name=config.text_encoder, masked_mean=config.masked_mean)
|
| 393 |
+
print(text_encoder)
|
| 394 |
config.cond_size = text_encoder.output_size
|
| 395 |
netG = NCSNpp(config)
|
| 396 |
+
print(netG)
|
| 397 |
ckpt = torch.load(path, map_location="cpu")
|
| 398 |
+
print("CK", ckpt)
|
| 399 |
for key in list(ckpt.keys()):
|
| 400 |
if key.startswith("module"):
|
| 401 |
ckpt[key[7:]] = ckpt.pop(key)
|