leynessa commited on
Commit
6be4ed9
·
verified ·
1 Parent(s): 7fb82fb

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +8 -2
streamlit_app.py CHANGED
@@ -61,8 +61,14 @@ def load_model():
61
  model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
  num_classes = len(class_names)
63
 
64
- # Detect model from classifier shape
65
- classifier_input = model_state_dict['classifier.weight'].shape[1]
 
 
 
 
 
 
66
  feature_map = {
67
  1280: 'efficientnet_b0',
68
  1408: 'efficientnet_b1',
 
61
  model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
  num_classes = len(class_names)
63
 
64
+ # Try to detect classifier shape; fallback if key missing
65
+ classifier_key = [k for k in model_state_dict.keys() if 'classifier.weight' in k or 'head.fc.weight' in k]
66
+ if not classifier_key:
67
+ st.error("Could not detect classifier layer in checkpoint.")
68
+ return None
69
+
70
+ classifier_input = model_state_dict[classifier_key[0]].shape[1]
71
+
72
  feature_map = {
73
  1280: 'efficientnet_b0',
74
  1408: 'efficientnet_b1',