Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -32,7 +32,19 @@ spaBERT_model.load_state_dict(b_model.state_dict(), strict = False)
|
|
| 32 |
|
| 33 |
pre_trained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
|
|
|
|
| 32 |
|
| 33 |
pre_trained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))
|
| 34 |
|
| 35 |
+
model_keys = spaBERT_model.state_dict()
|
| 36 |
+
cnt_layers = 0
|
| 37 |
+
for key in model_keys
|
| 38 |
+
if key in pre_trained_model:
|
| 39 |
+
model_keys[key] = pre_trained_model[key]
|
| 40 |
+
cnt_layers += 1
|
| 41 |
+
else:
|
| 42 |
+
print("No weight for", key)
|
| 43 |
+
print(cnt_layers, 'layers loaded')
|
| 44 |
|
| 45 |
+
spaBERT_model.load_state_dict(model_keys)
|
| 46 |
+
spaBERT_model.to(device)
|
| 47 |
+
spaBERT_model.eval()
|
| 48 |
|
| 49 |
|
| 50 |
|