dschandra commited on
Commit
fa9b186
·
verified ·
1 Parent(s): 75083da

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ # Import necessary libraries
4
+ import numpy as np
5
+ import gradio as gr
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ from torchvision import datasets, transforms
11
+ from torch.utils.data import DataLoader
12
+ from PIL import Image
13
+
14
+ # Define the neural network model
15
+ class Net(nn.Module):
16
+ def __init__(self):
17
+ super(Net, self).__init__()
18
+ self.fc1 = nn.Linear(28 * 28, 128)
19
+ self.fc2 = nn.Linear(128, 64)
20
+ self.fc3 = nn.Linear(64, 10)
21
+
22
+ def forward(self, x):
23
+ x = x.view(-1, 28 * 28) # Flatten the input
24
+ x = F.relu(self.fc1(x))
25
+ x = F.relu(self.fc2(x))
26
+ x = self.fc3(x)
27
+ return F.log_softmax(x, dim=1)
28
+
29
+ # Load and preprocess the MNIST dataset
30
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
31
+
32
+ train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
33
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
34
+
35
+ test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
36
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
37
+
38
+ # Initialize the model, loss function, and optimizer
39
+ model = Net()
40
+ criterion = nn.CrossEntropyLoss()
41
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
42
+
43
+ # Train the model
44
+ def train(model, train_loader, criterion, optimizer, epochs=5):
45
+ model.train()
46
+ for epoch in range(epochs):
47
+ for data, target in train_loader:
48
+ optimizer.zero_grad()
49
+ output = model(data)
50
+ loss = criterion(output, target)
51
+ loss.backward()
52
+ optimizer.step()
53
+
54
+ train(model, train_loader, criterion, optimizer)
55
+
56
+ # Save the trained model
57
+ torch.save(model.state_dict(), 'mnist_model.pth')
58
+
59
+ # Load the trained model
60
+ model.load_state_dict(torch.load('mnist_model.pth'))
61
+ model.eval()
62
+
63
+ # Define the predict function
64
+ def predict_image(img):
65
+ # Preprocess the image
66
+ img = img.convert('L')
67
+ img = img.resize((28, 28))
68
+ img = np.array(img).astype('float32') / 255.0
69
+ img = (img - 0.5) / 0.5 # Normalize
70
+ img = torch.tensor(img).unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
71
+
72
+ # Make a prediction
73
+ with torch.no_grad():
74
+ output = model(img)
75
+ predicted_digit = output.argmax(dim=1, keepdim=True).item()
76
+
77
+ return predicted_digit
78
+
79
+ # Create the Gradio interface
80
+ iface = gr.Interface(
81
+ fn=predict_image,
82
+ inputs=gr.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False),
83
+ outputs='label',
84
+ live=True,
85
+ description="Upload an image of a handwritten digit, and the model will predict the digit."
86
+ )
87
+
88
+ # Launch the interface
89
+ if __name__ == '__main__':
90
+ iface.launch()