Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -38,30 +38,13 @@ def init_db():
|
|
38 |
|
39 |
|
40 |
init_db()
|
41 |
-
|
42 |
# 1️⃣ Instantiate the model
|
43 |
-
model =
|
44 |
-
|
45 |
-
|
46 |
-
# 2️⃣ Load checkpoint
|
47 |
-
checkpoint = torch.load('densenet169_seed40_best2.pt', map_location='cpu')
|
48 |
-
state_dict = checkpoint['state_dict']
|
49 |
-
|
50 |
-
# 3️⃣ Fix state dict
|
51 |
-
new_state_dict = {}
|
52 |
-
for k, v in state_dict.items():
|
53 |
-
# Check if it's prefixed with 'features.0.'
|
54 |
-
if k.startswith('features.0.'):
|
55 |
-
new_key = 'features.' + k[len('features.0.'):] # Remove the '0.' segment
|
56 |
-
else:
|
57 |
-
new_key = k
|
58 |
-
new_state_dict[new_key] = v
|
59 |
-
|
60 |
-
# 4️⃣ Load into the model
|
61 |
-
model.load_state_dict(new_state_dict)
|
62 |
-
|
63 |
-
# Done!
|
64 |
model.eval()
|
|
|
65 |
# ✅ Class Names
|
66 |
CLASS_NAMES = ["Normal", "Early Glaucoma", "Advanced Glaucoma"]
|
67 |
|
|
|
38 |
|
39 |
|
40 |
init_db()
|
41 |
+
from densenet_withglam import get_model_with_attention
|
42 |
# 1️⃣ Instantiate the model
|
43 |
+
model = get_model_with_attention('densenet169', num_classes=3) # Will have GLAM
|
44 |
+
state_dict = torch.load('densenet169_seed40_best2.pt', map_location='cpu')
|
45 |
+
model.load_state_dict(state_dict) # Should now match
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
model.eval()
|
47 |
+
|
48 |
# ✅ Class Names
|
49 |
CLASS_NAMES = ["Normal", "Early Glaucoma", "Advanced Glaucoma"]
|
50 |
|