dschandra commited on
Commit
86a21c2
·
verified ·
1 Parent(s): 84197fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -31
app.py CHANGED
@@ -1,14 +1,15 @@
1
- # Import necessary libraries
2
- import numpy as np
3
  import os
 
4
  import gradio as gr
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
- from torchvision import transforms
 
 
9
  from PIL import Image
10
 
11
- # Define the neural network model using PyTorch
12
  class Net(nn.Module):
13
  def __init__(self):
14
  super(Net, self).__init__()
@@ -23,26 +24,43 @@ class Net(nn.Module):
23
  x = self.fc3(x)
24
  return F.log_softmax(x, dim=1)
25
 
26
- # Initialize the model and load the trained weights
 
 
 
 
 
 
 
 
 
27
  model = Net()
28
- model.load_state_dict(torch.load('mnist_model.pth'))
29
- model.eval()
30
 
31
- # Define the image transformations
32
- transform = transforms.Compose([
33
- transforms.Resize((28, 28)), # Resize image to 28x28
34
- transforms.Grayscale(), # Convert to grayscale
35
- transforms.ToTensor(), # Convert to tensor
36
- transforms.Normalize((0.5,), (0.5,)) # Normalize
37
- ])
 
38
 
39
- # Define the prediction function
40
  def predict_image(img):
41
- img = transform(img) # Apply transformations
42
- img = img.unsqueeze(0) # Add batch dimension
 
 
 
 
 
 
43
  with torch.no_grad():
44
  output = model(img)
45
- predicted_digit = output.argmax(dim=1).item()
 
46
  return predicted_digit
47
 
48
  # Create the Gradio interface
@@ -54,18 +72,6 @@ iface = gr.Interface(
54
  description="Upload an image of a handwritten digit, and the model will predict the digit."
55
  )
56
 
57
- # Check if the file exists
58
- if not os.path.isfile('mnist_model.pth'):
59
- raise FileNotFoundError("The model file 'mnist_model.pth' was not found.")
60
- else:
61
- print("Model file found, proceeding with loading.")
62
-
63
- # Load the model state dict
64
- model.load_state_dict(torch.load('mnist_model.pth'))
65
-
66
- model.load_state_dict(torch.load('mnist_model.pth', weights_only=True))
67
-
68
-
69
- # Launch the Gradio interface
70
  if __name__ == '__main__':
71
  iface.launch()
 
 
 
1
  import os
2
+ import numpy as np
3
  import gradio as gr
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ from torchvision import datasets, transforms
9
+ from torch.utils.data import DataLoader
10
  from PIL import Image
11
 
12
+ # Define the neural network model
13
  class Net(nn.Module):
14
  def __init__(self):
15
  super(Net, self).__init__()
 
24
  x = self.fc3(x)
25
  return F.log_softmax(x, dim=1)
26
 
27
+ # Load and preprocess the MNIST dataset
28
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
29
+
30
+ train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
31
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
32
+
33
+ test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
34
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
35
+
36
+ # Initialize the model, loss function, and optimizer
37
  model = Net()
38
+ criterion = nn.CrossEntropyLoss()
39
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
40
 
41
+ # Check if the model file exists
42
+ model_path = 'mnist_model.pth'
43
+ if not os.path.isfile(model_path):
44
+ raise FileNotFoundError(f"The model file '{model_path}' was not found.")
45
+
46
+ # Load the model state dict
47
+ model.load_state_dict(torch.load(model_path, weights_only=True))
48
+ model.eval()
49
 
50
+ # Define the predict function
51
  def predict_image(img):
52
+ # Preprocess the image
53
+ img = img.convert('L')
54
+ img = img.resize((28, 28))
55
+ img = np.array(img).astype('float32') / 255.0
56
+ img = (img - 0.5) / 0.5 # Normalize
57
+ img = torch.tensor(img).unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
58
+
59
+ # Make a prediction
60
  with torch.no_grad():
61
  output = model(img)
62
+ predicted_digit = output.argmax(dim=1, keepdim=True).item()
63
+
64
  return predicted_digit
65
 
66
  # Create the Gradio interface
 
72
  description="Upload an image of a handwritten digit, and the model will predict the digit."
73
  )
74
 
75
+ # Launch the interface
 
 
 
 
 
 
 
 
 
 
 
 
76
  if __name__ == '__main__':
77
  iface.launch()