dschandra commited on
Commit
030f87e
·
verified ·
1 Parent(s): 3ed7773

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -17
app.py CHANGED
@@ -13,23 +13,32 @@ from PIL import Image
13
  class Net(nn.Module):
14
  def __init__(self):
15
  super(Net, self).__init__()
 
16
  self.fc1 = nn.Linear(28 * 28, 128)
17
  self.fc2 = nn.Linear(128, 64)
18
  self.fc3 = nn.Linear(64, 10)
19
 
20
  def forward(self, x):
21
- x = x.view(-1, 28 * 28) # Flatten the input
 
 
22
  x = F.relu(self.fc1(x))
23
  x = F.relu(self.fc2(x))
 
24
  x = self.fc3(x)
25
  return F.log_softmax(x, dim=1)
26
 
27
  # Load and preprocess the MNIST dataset
28
- transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
 
 
 
29
 
 
30
  train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
31
  train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
32
 
 
33
  test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
34
  test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
35
 
@@ -38,8 +47,10 @@ model = Net()
38
  criterion = nn.CrossEntropyLoss()
39
  optimizer = optim.Adam(model.parameters(), lr=0.001)
40
 
41
- # Load the model file
42
- model_path = 'mnist_model.pth' # Adjust this path if needed
 
 
43
  if not os.path.isfile(model_path):
44
  raise FileNotFoundError(f"The model file '{model_path}' was not found.")
45
 
@@ -49,29 +60,29 @@ model.eval()
49
 
50
  # Define the predict function
51
  def predict_image(img):
52
- # Preprocess the image
53
- img = img.convert('L')
54
- img = img.resize((28, 28))
55
- img = np.array(img).astype('float32') / 255.0
56
- img = (img - 0.5) / 0.5 # Normalize
57
  img = torch.tensor(img).unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
58
 
59
  # Make a prediction
60
  with torch.no_grad():
61
- output = model(img)
62
- predicted_digit = output.argmax(dim=1, keepdim=True).item()
63
 
64
  return predicted_digit
65
 
66
  # Create the Gradio interface
67
  iface = gr.Interface(
68
- fn=predict_image,
69
- inputs=gr.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False),
70
- outputs='label',
71
- live=True,
72
- description="Upload an image of a handwritten digit, and the model will predict the digit."
73
  )
74
 
75
- # Launch the interface
76
  if __name__ == '__main__':
77
  iface.launch()
 
13
  class Net(nn.Module):
14
  def __init__(self):
15
  super(Net, self).__init__()
16
+ # Define layers of the neural network
17
  self.fc1 = nn.Linear(28 * 28, 128)
18
  self.fc2 = nn.Linear(128, 64)
19
  self.fc3 = nn.Linear(64, 10)
20
 
21
  def forward(self, x):
22
+ # Flatten the input tensor
23
+ x = x.view(-1, 28 * 28)
24
+ # Apply ReLU activation function
25
  x = F.relu(self.fc1(x))
26
  x = F.relu(self.fc2(x))
27
+ # Output layer with log softmax activation
28
  x = self.fc3(x)
29
  return F.log_softmax(x, dim=1)
30
 
31
  # Load and preprocess the MNIST dataset
32
+ transform = transforms.Compose([
33
+ transforms.ToTensor(),
34
+ transforms.Normalize((0.5,), (0.5,))
35
+ ])
36
 
37
+ # Download and load training dataset
38
  train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
39
  train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
40
 
41
+ # Download and load test dataset
42
  test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
43
  test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
44
 
 
47
  criterion = nn.CrossEntropyLoss()
48
  optimizer = optim.Adam(model.parameters(), lr=0.001)
49
 
50
+ # Path to the model file
51
+ model_path = 'mnist_model.pth'
52
+
53
+ # Check if the model file exists
54
  if not os.path.isfile(model_path):
55
  raise FileNotFoundError(f"The model file '{model_path}' was not found.")
56
 
 
60
 
61
  # Define the predict function
62
  def predict_image(img):
63
+ # Preprocess the uploaded image
64
+ img = img.convert('L') # Convert image to grayscale
65
+ img = img.resize((28, 28)) # Resize image to 28x28 pixels
66
+ img = np.array(img).astype('float32') / 255.0 # Normalize pixel values
67
+ img = (img - 0.5) / 0.5 # Normalize to range [-1, 1]
68
  img = torch.tensor(img).unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
69
 
70
  # Make a prediction
71
  with torch.no_grad():
72
+ output = model(img) # Forward pass through the model
73
+ predicted_digit = output.argmax(dim=1, keepdim=True).item() # Get the predicted digit
74
 
75
  return predicted_digit
76
 
77
  # Create the Gradio interface
78
  iface = gr.Interface(
79
+ fn=predict_image, # Function to be called on image upload
80
+ inputs=gr.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False), # Input format
81
+ outputs='label', # Output format
82
+ live=True, # Live update
83
+ description="Upload an image of a handwritten digit, and the model will predict the digit." # Description of the interface
84
  )
85
 
86
+ # Launch the Gradio interface
87
  if __name__ == '__main__':
88
  iface.launch()