Spaces:
Sleeping
Sleeping
Update streamlit_app.py
Browse files- streamlit_app.py +17 -2
streamlit_app.py
CHANGED
@@ -60,7 +60,23 @@ def load_model():
|
|
60 |
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
|
61 |
model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
|
62 |
num_classes = len(class_names)
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
model.load_state_dict(model_state_dict, strict=False)
|
65 |
model.eval()
|
66 |
return model
|
@@ -68,7 +84,6 @@ def load_model():
|
|
68 |
st.error(f"Error loading model: {str(e)}")
|
69 |
return None
|
70 |
|
71 |
-
model = load_model()
|
72 |
|
73 |
def predict_butterfly(image, threshold=0.5):
|
74 |
try:
|
|
|
60 |
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
|
61 |
model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
|
62 |
num_classes = len(class_names)
|
63 |
+
|
64 |
+
# Detect model size via known feature dimensions
|
65 |
+
classifier_input = model_state_dict['classifier.weight'].shape[1]
|
66 |
+
feature_map = {
|
67 |
+
1280: 'efficientnet_b0',
|
68 |
+
1408: 'efficientnet_b1',
|
69 |
+
1536: 'efficientnet_b2',
|
70 |
+
1792: 'efficientnet_b3',
|
71 |
+
1920: 'efficientnet_b4',
|
72 |
+
2048: 'efficientnet_b5',
|
73 |
+
2304: 'efficientnet_b6',
|
74 |
+
2560: 'efficientnet_b7'
|
75 |
+
}
|
76 |
+
model_name = feature_map.get(classifier_input, 'efficientnet_b3')
|
77 |
+
print(f"Auto-detected model architecture: {model_name}")
|
78 |
+
|
79 |
+
model = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
|
80 |
model.load_state_dict(model_state_dict, strict=False)
|
81 |
model.eval()
|
82 |
return model
|
|
|
84 |
st.error(f"Error loading model: {str(e)}")
|
85 |
return None
|
86 |
|
|
|
87 |
|
88 |
def predict_butterfly(image, threshold=0.5):
|
89 |
try:
|