leynessa commited on
Commit
9d53d92
·
verified ·
1 Parent(s): 6be4ed9

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +14 -9
streamlit_app.py CHANGED
@@ -61,13 +61,12 @@ def load_model():
61
  model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
  num_classes = len(class_names)
63
 
64
- # Try to detect classifier shape; fallback if key missing
65
- classifier_key = [k for k in model_state_dict.keys() if 'classifier.weight' in k or 'head.fc.weight' in k]
66
- if not classifier_key:
67
- st.error("Could not detect classifier layer in checkpoint.")
68
- return None
69
-
70
- classifier_input = model_state_dict[classifier_key[0]].shape[1]
71
 
72
  feature_map = {
73
  1280: 'efficientnet_b0',
@@ -79,8 +78,13 @@ def load_model():
79
  2304: 'efficientnet_b6',
80
  2560: 'efficientnet_b7'
81
  }
82
- model_name = feature_map.get(classifier_input, 'efficientnet_b3')
83
- st.info(f"Detected model architecture: {model_name}")
 
 
 
 
 
84
 
85
  model = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
86
  model.load_state_dict(model_state_dict, strict=False)
@@ -122,6 +126,7 @@ st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify b
122
 
123
  tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])
124
 
 
125
  with tab1:
126
  st.header("Kaamera jäädvustamine / Camera Capture")
127
  st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
 
61
  model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
  num_classes = len(class_names)
63
 
64
+ # Attempt to auto-detect model from batch norm layer dimensions
65
+ bn2_shape = None
66
+ for key in model_state_dict:
67
+ if key.endswith("bn2.weight"):
68
+ bn2_shape = model_state_dict[key].shape[0]
69
+ break
 
70
 
71
  feature_map = {
72
  1280: 'efficientnet_b0',
 
78
  2304: 'efficientnet_b6',
79
  2560: 'efficientnet_b7'
80
  }
81
+
82
+ if bn2_shape is None:
83
+ st.warning("Could not detect classifier or bn2 layer in checkpoint. Defaulting to efficientnet_b3")
84
+ model_name = 'efficientnet_b3'
85
+ else:
86
+ model_name = feature_map.get(bn2_shape, 'efficientnet_b3')
87
+ st.info(f"Detected model architecture: {model_name}")
88
 
89
  model = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
90
  model.load_state_dict(model_state_dict, strict=False)
 
126
 
127
  tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])
128
 
129
+
130
  with tab1:
131
  st.header("Kaamera jäädvustamine / Camera Capture")
132
  st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")