leynessa commited on
Commit
aa86cbd
·
verified ·
1 Parent(s): 4df5740

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +18 -3
streamlit_app.py CHANGED
@@ -37,10 +37,25 @@ def load_model():
37
  st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
38
  return None
39
 
40
- # Use EfficientNet-B0 (same as training)
41
- model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=len(class_names))
42
- model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  model.eval()
 
44
  return model
45
 
46
  # Load the model
 
37
  st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
38
  return None
39
 
40
+ # Load the checkpoint first to check the actual number of classes
41
+ checkpoint = torch.load(MODEL_PATH, map_location="cpu")
42
+
43
+ # Get the number of classes from the saved model weights
44
+ if 'classifier.weight' in checkpoint:
45
+ num_classes_in_model = checkpoint['classifier.weight'].shape[0]
46
+ elif 'fc.weight' in checkpoint: # Alternative naming
47
+ num_classes_in_model = checkpoint['fc.weight'].shape[0]
48
+ else:
49
+ # Fallback: assume it matches class_names
50
+ num_classes_in_model = len(class_names)
51
+
52
+ st.info(f"Model has {num_classes_in_model} classes, class_names.txt has {len(class_names)} classes")
53
+
54
+ # Create model with the correct number of classes from the saved model
55
+ model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=num_classes_in_model)
56
+ model.load_state_dict(checkpoint)
57
  model.eval()
58
+
59
  return model
60
 
61
  # Load the model