leynessa commited on
Commit
7fb82fb
·
verified ·
1 Parent(s): 06966cd

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +5 -2
streamlit_app.py CHANGED
@@ -61,7 +61,7 @@ def load_model():
61
  model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
  num_classes = len(class_names)
63
 
64
- # Detect model size via known feature dimensions
65
  classifier_input = model_state_dict['classifier.weight'].shape[1]
66
  feature_map = {
67
  1280: 'efficientnet_b0',
@@ -74,7 +74,7 @@ def load_model():
74
  2560: 'efficientnet_b7'
75
  }
76
  model_name = feature_map.get(classifier_input, 'efficientnet_b3')
77
- print(f"Auto-detected model architecture: {model_name}")
78
 
79
  model = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
80
  model.load_state_dict(model_state_dict, strict=False)
@@ -84,9 +84,12 @@ def load_model():
84
  st.error(f"Error loading model: {str(e)}")
85
  return None
86
 
 
87
 
88
  def predict_butterfly(image, threshold=0.5):
89
  try:
 
 
90
  if image is None:
91
  return None, None
92
  if isinstance(image, np.ndarray):
 
61
  model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
  num_classes = len(class_names)
63
 
64
+ # Detect model from classifier shape
65
  classifier_input = model_state_dict['classifier.weight'].shape[1]
66
  feature_map = {
67
  1280: 'efficientnet_b0',
 
74
  2560: 'efficientnet_b7'
75
  }
76
  model_name = feature_map.get(classifier_input, 'efficientnet_b3')
77
+ st.info(f"Detected model architecture: {model_name}")
78
 
79
  model = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
80
  model.load_state_dict(model_state_dict, strict=False)
 
84
  st.error(f"Error loading model: {str(e)}")
85
  return None
86
 
87
+ model = load_model()
88
 
89
  def predict_butterfly(image, threshold=0.5):
90
  try:
91
+ if model is None:
92
+ raise ValueError("Model is not loaded.")
93
  if image is None:
94
  return None, None
95
  if isinstance(image, np.ndarray):