leynessa commited on
Commit
fab62eb
·
verified ·
1 Parent(s): 2947776

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +17 -9
streamlit_app.py CHANGED
@@ -37,22 +37,30 @@ def load_model():
37
  st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
38
  return None
39
 
40
- # Load the checkpoint first to check the actual number of classes
41
  checkpoint = torch.load(MODEL_PATH, map_location="cpu")
42
 
43
- # Get the number of classes from the saved model weights
44
- if 'classifier.weight' in checkpoint:
45
- num_classes_in_model = checkpoint['classifier.weight'].shape[0]
46
- elif 'fc.weight' in checkpoint: # Alternative naming
47
- num_classes_in_model = checkpoint['fc.weight'].shape[0]
 
 
 
 
 
 
 
48
  else:
49
  # Fallback: assume it matches class_names
50
  num_classes_in_model = len(class_names)
51
 
52
-
53
- # Create model with the correct number of classes from the saved model
54
  model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=num_classes_in_model)
55
- model.load_state_dict(checkpoint)
 
 
56
  model.eval()
57
 
58
  return model
 
37
  st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
38
  return None
39
 
40
+ # Load the checkpoint
41
  checkpoint = torch.load(MODEL_PATH, map_location="cpu")
42
 
43
+ # Extract the actual model state dict
44
+ if 'model_state_dict' in checkpoint:
45
+ model_state_dict = checkpoint['model_state_dict']
46
+ else:
47
+ # If it's just the state dict directly
48
+ model_state_dict = checkpoint
49
+
50
+ # Get the number of classes from the model weights
51
+ if 'classifier.weight' in model_state_dict:
52
+ num_classes_in_model = model_state_dict['classifier.weight'].shape[0]
53
+ elif 'fc.weight' in model_state_dict:
54
+ num_classes_in_model = model_state_dict['fc.weight'].shape[0]
55
  else:
56
  # Fallback: assume it matches class_names
57
  num_classes_in_model = len(class_names)
58
 
59
+ # Create model with the correct number of classes
 
60
  model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=num_classes_in_model)
61
+
62
+ # Load the model state dict (not the entire checkpoint)
63
+ model.load_state_dict(model_state_dict)
64
  model.eval()
65
 
66
  return model