dhruv2842 commited on
Commit
d6c8fb1
·
verified ·
1 Parent(s): cf2ad62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -40,16 +40,19 @@ def init_db():
40
  init_db()
41
 
42
  model = models.densenet169(pretrained=False)
43
-
44
- # IMPORTANT: Adjust final layer to match your number of classes (3 in your case)
45
  model.classifier = nn.Linear(model.classifier.in_features, 3)
46
 
47
- # Load state dict
48
- state_dict = torch.load('densenet169_seed40_best.pt', map_location=torch.device('cpu'))
49
- model.load_state_dict(state_dict)
 
 
50
 
51
- # Set eval mode
52
- model.eval()
 
 
 
53
 
54
  # ✅ Preprocess Image
55
  transform = transforms.Compose([
 
40
  init_db()
41
 
42
  model = models.densenet169(pretrained=False)
 
 
43
  model.classifier = nn.Linear(model.classifier.in_features, 3)
44
 
45
+ # 2️⃣ Load checkpoint
46
+ checkpoint = torch.load('densenet169_seed40_best.pt', map_location='cpu')
47
+
48
+ # ✅ If checkpoint contains state_dict directly
49
+ # model.load_state_dict(checkpoint)
50
 
51
+ # If checkpoint contains a dict
52
+ if 'model_state_dict' in checkpoint:
53
+ model.load_state_dict(checkpoint['model_state_dict'])
54
+ else:
55
+ model.load_state_dict(checkpoint)
56
 
57
  # ✅ Preprocess Image
58
  transform = transforms.Compose([