leynessa commited on
Commit
a233bbd
·
verified ·
1 Parent(s): af316a5

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +14 -1
streamlit_app.py CHANGED
@@ -25,12 +25,25 @@ except:
25
 
26
  @st.cache_resource
27
  def load_model():
 
 
 
 
 
 
 
 
 
 
 
 
28
  model = models.resnet18(pretrained=False)
29
  model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
30
- model.load_state_dict(torch.load("butterfly_classifier.pth", map_location="cpu"))
31
  model.eval()
32
  return model
33
 
 
34
  model = load_model()
35
 
36
  transform = transforms.Compose([
 
25
 
26
  @st.cache_resource
27
  def load_model():
28
+ import os
29
+ import urllib.request
30
+
31
+ MODEL_URL = "https://huggingface.co/your-username/your-repo/resolve/main/butterfly_classifier.pth"
32
+ MODEL_PATH = "butterfly_classifier.pth"
33
+
34
+ # Download if not already present
35
+ if not os.path.exists(MODEL_PATH):
36
+ st.info("Downloading model...")
37
+ urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
38
+
39
+ # Load the model
40
  model = models.resnet18(pretrained=False)
41
  model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
42
+ model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
43
  model.eval()
44
  return model
45
 
46
+
47
  model = load_model()
48
 
49
  transform = transforms.Compose([