Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -44,7 +44,7 @@ class Model:
|
|
| 44 |
print("image tokenizer is loaded")
|
| 45 |
return vq_model
|
| 46 |
|
| 47 |
-
def load_gpt(self, condition_type='
|
| 48 |
gpt_ckpt = models[condition_type]
|
| 49 |
# precision = torch.bfloat16
|
| 50 |
precision = torch.float32
|
|
@@ -57,7 +57,7 @@ class Model:
|
|
| 57 |
adapter_size='base',
|
| 58 |
).to(device='cpu', dtype=precision)
|
| 59 |
model_weight = load_file(gpt_ckpt)
|
| 60 |
-
gpt_model.load_state_dict(model_weight, strict=
|
| 61 |
gpt_model.eval()
|
| 62 |
print("gpt model is loaded")
|
| 63 |
return gpt_model
|
|
|
|
| 44 |
print("image tokenizer is loaded")
|
| 45 |
return vq_model
|
| 46 |
|
| 47 |
+
def load_gpt(self, condition_type='edge'):
|
| 48 |
gpt_ckpt = models[condition_type]
|
| 49 |
# precision = torch.bfloat16
|
| 50 |
precision = torch.float32
|
|
|
|
| 57 |
adapter_size='base',
|
| 58 |
).to(device='cpu', dtype=precision)
|
| 59 |
model_weight = load_file(gpt_ckpt)
|
| 60 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
| 61 |
gpt_model.eval()
|
| 62 |
print("gpt model is loaded")
|
| 63 |
return gpt_model
|