dschandra commited on
Commit
84197fd
·
verified ·
1 Parent(s): 547cda9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -0
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # Import necessary libraries
2
  import numpy as np
 
3
  import gradio as gr
4
  import torch
5
  import torch.nn as nn
@@ -53,6 +54,18 @@ iface = gr.Interface(
53
  description="Upload an image of a handwritten digit, and the model will predict the digit."
54
  )
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Launch the Gradio interface
57
  if __name__ == '__main__':
58
  iface.launch()
 
1
  # Import necessary libraries
2
  import numpy as np
3
+ import os
4
  import gradio as gr
5
  import torch
6
  import torch.nn as nn
 
54
  description="Upload an image of a handwritten digit, and the model will predict the digit."
55
  )
56
 
57
+ # Check if the file exists
58
+ if not os.path.isfile('mnist_model.pth'):
59
+ raise FileNotFoundError("The model file 'mnist_model.pth' was not found.")
60
+ else:
61
+ print("Model file found, proceeding with loading.")
62
+
63
+ # Load the model state dict
64
+ model.load_state_dict(torch.load('mnist_model.pth'))
65
+
66
+ model.load_state_dict(torch.load('mnist_model.pth', weights_only=True))
67
+
68
+
69
  # Launch the Gradio interface
70
  if __name__ == '__main__':
71
  iface.launch()