Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
# Import necessary libraries
|
| 2 |
import numpy as np
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
|
@@ -53,6 +54,18 @@ iface = gr.Interface(
|
|
| 53 |
description="Upload an image of a handwritten digit, and the model will predict the digit."
|
| 54 |
)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
# Launch the Gradio interface
|
| 57 |
if __name__ == '__main__':
|
| 58 |
iface.launch()
|
|
|
|
| 1 |
# Import necessary libraries
|
| 2 |
import numpy as np
|
| 3 |
+
import os
|
| 4 |
import gradio as gr
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
|
|
| 54 |
description="Upload an image of a handwritten digit, and the model will predict the digit."
|
| 55 |
)
|
| 56 |
|
| 57 |
+
# Check if the file exists
|
| 58 |
+
if not os.path.isfile('mnist_model.pth'):
|
| 59 |
+
raise FileNotFoundError("The model file 'mnist_model.pth' was not found.")
|
| 60 |
+
else:
|
| 61 |
+
print("Model file found, proceeding with loading.")
|
| 62 |
+
|
| 63 |
+
# Load the model state dict
|
| 64 |
+
model.load_state_dict(torch.load('mnist_model.pth'))
|
| 65 |
+
|
| 66 |
+
model.load_state_dict(torch.load('mnist_model.pth', weights_only=True))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
# Launch the Gradio interface
|
| 70 |
if __name__ == '__main__':
|
| 71 |
iface.launch()
|