dschandra commited on
Commit
c0c0bc8
·
verified ·
1 Parent(s): 2827bbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -30
app.py CHANGED
@@ -1,10 +1,8 @@
1
- # train_model.py
2
-
3
  import torch
4
  import torch.nn as nn
5
- import torch.optim as optim
6
- from torchvision import datasets, transforms
7
- from torch.utils.data import DataLoader
8
 
9
  # Define the neural network model
10
  class Net(nn.Module):
@@ -21,33 +19,35 @@ class Net(nn.Module):
21
  x = self.fc3(x)
22
  return torch.log_softmax(x, dim=1)
23
 
24
- # Load and preprocess the MNIST dataset
 
 
 
 
 
25
  transform = transforms.Compose([
 
 
26
  transforms.ToTensor(),
27
  transforms.Normalize((0.5,), (0.5,))
28
  ])
29
 
30
- train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
31
- train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
32
-
33
- # Initialize the model, loss function, and optimizer
34
- model = Net()
35
- criterion = nn.CrossEntropyLoss()
36
- optimizer = optim.Adam(model.parameters(), lr=0.001)
37
-
38
- # Train the model
39
- def train(model, train_loader, criterion, optimizer, epochs=5):
40
- model.train()
41
- for epoch in range(epochs):
42
- for data, target in train_loader:
43
- optimizer.zero_grad()
44
- output = model(data)
45
- loss = criterion(output, target)
46
- loss.backward()
47
- optimizer.step()
48
-
49
- train(model, train_loader, criterion, optimizer)
50
-
51
- # Save the trained model
52
- torch.save(model.state_dict(), 'mnist_model.pth')
53
- print("Model saved as 'mnist_model.pth'")
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ import gradio as gr
5
+ from PIL import Image
6
 
7
  # Define the neural network model
8
  class Net(nn.Module):
 
19
  x = self.fc3(x)
20
  return torch.log_softmax(x, dim=1)
21
 
22
+ # Load the trained model
23
+ model = Net()
24
+ model.load_state_dict(torch.load('mnist_model.pth', map_location=torch.device('cpu')))
25
+ model.eval()
26
+
27
+ # Define the transform to preprocess the input image
28
  transform = transforms.Compose([
29
+ transforms.Grayscale(num_output_channels=1),
30
+ transforms.Resize((28, 28)),
31
  transforms.ToTensor(),
32
  transforms.Normalize((0.5,), (0.5,))
33
  ])
34
 
35
+ # Define the prediction function
36
+ def predict(image):
37
+ image = transform(image).unsqueeze(0) # Add batch dimension
38
+ with torch.no_grad():
39
+ output = model(image)
40
+ prediction = torch.argmax(output, dim=1).item()
41
+ return prediction
42
+
43
+ # Create the Gradio interface
44
+ iface = gr.Interface(
45
+ fn=predict,
46
+ inputs=gr.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False),
47
+ outputs="label",
48
+ live=True
49
+ )
50
+
51
+ # Launch the Gradio interface
52
+ if __name__ == "__main__":
53
+ iface.launch()