Spaces:
Sleeping
Sleeping
Update streamlit_app.py
Browse files- streamlit_app.py +14 -1
streamlit_app.py
CHANGED
@@ -25,12 +25,25 @@ except:
|
|
25 |
|
26 |
@st.cache_resource
|
27 |
def load_model():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
model = models.resnet18(pretrained=False)
|
29 |
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
|
30 |
-
model.load_state_dict(torch.load(
|
31 |
model.eval()
|
32 |
return model
|
33 |
|
|
|
34 |
model = load_model()
|
35 |
|
36 |
transform = transforms.Compose([
|
|
|
25 |
|
26 |
@st.cache_resource
|
27 |
def load_model():
|
28 |
+
import os
|
29 |
+
import urllib.request
|
30 |
+
|
31 |
+
MODEL_URL = "https://huggingface.co/your-username/your-repo/resolve/main/butterfly_classifier.pth"
|
32 |
+
MODEL_PATH = "butterfly_classifier.pth"
|
33 |
+
|
34 |
+
# Download if not already present
|
35 |
+
if not os.path.exists(MODEL_PATH):
|
36 |
+
st.info("Downloading model...")
|
37 |
+
urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
|
38 |
+
|
39 |
+
# Load the model
|
40 |
model = models.resnet18(pretrained=False)
|
41 |
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
|
42 |
+
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
|
43 |
model.eval()
|
44 |
return model
|
45 |
|
46 |
+
|
47 |
model = load_model()
|
48 |
|
49 |
transform = transforms.Compose([
|