dhruv2842 commited on
Commit
0ac5b89
·
verified ·
1 Parent(s): c98790c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -22
app.py CHANGED
@@ -38,30 +38,13 @@ def init_db():
38
 
39
 
40
  init_db()
41
-
42
  # 1️⃣ Instantiate the model
43
- model = models.densenet169(pretrained=False)
44
- model.classifier = nn.Linear(model.classifier.in_features, 3)
45
-
46
- # 2️⃣ Load checkpoint
47
- checkpoint = torch.load('densenet169_seed40_best2.pt', map_location='cpu')
48
- state_dict = checkpoint['state_dict']
49
-
50
- # 3️⃣ Fix state dict
51
- new_state_dict = {}
52
- for k, v in state_dict.items():
53
- # Check if it's prefixed with 'features.0.'
54
- if k.startswith('features.0.'):
55
- new_key = 'features.' + k[len('features.0.'):] # Remove the '0.' segment
56
- else:
57
- new_key = k
58
- new_state_dict[new_key] = v
59
-
60
- # 4️⃣ Load into the model
61
- model.load_state_dict(new_state_dict)
62
-
63
- # Done!
64
  model.eval()
 
65
  # ✅ Class Names
66
  CLASS_NAMES = ["Normal", "Early Glaucoma", "Advanced Glaucoma"]
67
 
 
38
 
39
 
40
  init_db()
41
+ from densenet_withglam import get_model_with_attention
42
  # 1️⃣ Instantiate the model
43
+ model = get_model_with_attention('densenet169', num_classes=3) # Will have GLAM
44
+ state_dict = torch.load('densenet169_seed40_best2.pt', map_location='cpu')
45
+ model.load_state_dict(state_dict) # Should now match
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  model.eval()
47
+
48
  # ✅ Class Names
49
  CLASS_NAMES = ["Normal", "Early Glaucoma", "Advanced Glaucoma"]
50