leynessa committed on
Commit
9625ec8
·
verified ·
1 Parent(s): fab62eb

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +33 -5
streamlit_app.py CHANGED
@@ -43,26 +43,48 @@ def load_model():
43
  # Extract the actual model state dict
44
  if 'model_state_dict' in checkpoint:
45
  model_state_dict = checkpoint['model_state_dict']
 
 
 
 
46
  else:
47
  # If it's just the state dict directly
48
  model_state_dict = checkpoint
 
49
 
50
  # Get the number of classes from the model weights
51
  if 'classifier.weight' in model_state_dict:
52
  num_classes_in_model = model_state_dict['classifier.weight'].shape[0]
 
 
53
  elif 'fc.weight' in model_state_dict:
54
  num_classes_in_model = model_state_dict['fc.weight'].shape[0]
55
  else:
56
  # Fallback: assume it matches class_names
57
  num_classes_in_model = len(class_names)
58
 
59
- # Create model with the correct number of classes
60
- model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=num_classes_in_model)
61
 
62
- # Load the model state dict (not the entire checkpoint)
63
- model.load_state_dict(model_state_dict)
64
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
 
66
  return model
67
 
68
  # Load the model
@@ -71,9 +93,11 @@ model = load_model()
71
  if model is None:
72
  st.stop()
73
 
 
74
  transform = transforms.Compose([
75
  transforms.Resize((224, 224)),
76
  transforms.ToTensor(),
 
77
  ])
78
 
79
  def predict_butterfly(image):
@@ -85,6 +109,10 @@ def predict_butterfly(image):
85
  if isinstance(image, np.ndarray):
86
  image = Image.fromarray(image)
87
 
 
 
 
 
88
  # Preprocess
89
  input_tensor = transform(image).unsqueeze(0)
90
 
 
43
  # Extract the actual model state dict
44
  if 'model_state_dict' in checkpoint:
45
  model_state_dict = checkpoint['model_state_dict']
46
+ # Get class names from checkpoint if available
47
+ if 'class_names' in checkpoint:
48
+ saved_class_names = checkpoint['class_names']
49
+ print(f"Loaded class names from checkpoint: {len(saved_class_names)} classes")
50
  else:
51
  # If it's just the state dict directly
52
  model_state_dict = checkpoint
53
+ saved_class_names = class_names
54
 
55
  # Get the number of classes from the model weights
56
  if 'classifier.weight' in model_state_dict:
57
  num_classes_in_model = model_state_dict['classifier.weight'].shape[0]
58
+ elif 'head.weight' in model_state_dict: # Alternative naming in some timm versions
59
+ num_classes_in_model = model_state_dict['head.weight'].shape[0]
60
  elif 'fc.weight' in model_state_dict:
61
  num_classes_in_model = model_state_dict['fc.weight'].shape[0]
62
  else:
63
  # Fallback: assume it matches class_names
64
  num_classes_in_model = len(class_names)
65
 
66
+ print(f"Creating model with {num_classes_in_model} classes")
 
67
 
68
+ # Create model exactly as in training - with dropout and drop_path
69
+ model = timm.create_model(
70
+ 'efficientnet_b0',
71
+ pretrained=False, # Don't load pretrained weights
72
+ num_classes=num_classes_in_model,
73
+ drop_rate=0.3, # Match training parameters
74
+ drop_path_rate=0.2 # Match training parameters
75
+ )
76
+
77
+ # Load the model state dict
78
+ try:
79
+ model.load_state_dict(model_state_dict, strict=True)
80
+ print("Model loaded successfully!")
81
+ except RuntimeError as e:
82
+ print(f"Error loading model: {e}")
83
+ # Try with strict=False as fallback
84
+ model.load_state_dict(model_state_dict, strict=False)
85
+ print("Model loaded with some missing/unexpected keys")
86
 
87
+ model.eval()
88
  return model
89
 
90
  # Load the model
 
93
  if model is None:
94
  st.stop()
95
 
96
+ # Use the exact same transforms as in training validation
97
  transform = transforms.Compose([
98
  transforms.Resize((224, 224)),
99
  transforms.ToTensor(),
100
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
101
  ])
102
 
103
  def predict_butterfly(image):
 
109
  if isinstance(image, np.ndarray):
110
  image = Image.fromarray(image)
111
 
112
+ # Ensure RGB format
113
+ if image.mode != 'RGB':
114
+ image = image.convert('RGB')
115
+
116
  # Preprocess
117
  input_tensor = transform(image).unsqueeze(0)
118