leynessa commited on
Commit
06966cd
·
verified ·
1 Parent(s): 56ef809

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +17 -2
streamlit_app.py CHANGED
@@ -60,7 +60,23 @@ def load_model():
60
  checkpoint = torch.load(MODEL_PATH, map_location='cpu')
61
  model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
  num_classes = len(class_names)
63
- model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  model.load_state_dict(model_state_dict, strict=False)
65
  model.eval()
66
  return model
@@ -68,7 +84,6 @@ def load_model():
68
  st.error(f"Error loading model: {str(e)}")
69
  return None
70
 
71
- model = load_model()
72
 
73
  def predict_butterfly(image, threshold=0.5):
74
  try:
 
60
  checkpoint = torch.load(MODEL_PATH, map_location='cpu')
61
  model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
  num_classes = len(class_names)
63
+
64
+ # Detect model size via known feature dimensions
65
+ classifier_input = model_state_dict['classifier.weight'].shape[1]
66
+ feature_map = {
67
+ 1280: 'efficientnet_b0',
68
+ 1408: 'efficientnet_b1',
69
+ 1536: 'efficientnet_b2',
70
+ 1792: 'efficientnet_b3',
71
+ 1920: 'efficientnet_b4',
72
+ 2048: 'efficientnet_b5',
73
+ 2304: 'efficientnet_b6',
74
+ 2560: 'efficientnet_b7'
75
+ }
76
+ model_name = feature_map.get(classifier_input, 'efficientnet_b3')
77
+ print(f"Auto-detected model architecture: {model_name}")
78
+
79
+ model = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
80
  model.load_state_dict(model_state_dict, strict=False)
81
  model.eval()
82
  return model
 
84
  st.error(f"Error loading model: {str(e)}")
85
  return None
86
 
 
87
 
88
  def predict_butterfly(image, threshold=0.5):
89
  try: