Spaces:
Running
Running
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Define the CNN
class SimpleCNN(nn.Module):
    """Small two-convolution classifier for 3x32x32 images (e.g. CIFAR-10).

    Architecture:
        conv(3->16, 3x3, pad 1) -> ReLU -> 2x2 max-pool
        -> conv(16->32, 3x3, pad 1) -> ReLU -> 2x2 max-pool
        -> fc(32*8*8 -> 128) -> ReLU -> fc(128 -> num_classes)

    The fully connected head is sized for 32x32 inputs. The padded 3x3
    convolutions keep the spatial size, and the two 2x2 poolings reduce it to
    8x8.

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, num_classes).

        Args:
            x: Image batch of shape (batch, 3, 32, 32).
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 32*8*8), this keeps the batch
        # dimension intact. A wrongly-sized input then fails loudly in fc1
        # instead of being silently re-split into a different batch size.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Function to train the model
CIFAR10_CLASSES = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


def _load_cifar10(batch_size=64):
    """Download CIFAR-10 into ./data if needed; return (trainloader, testloader)."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Maps each channel from [0, 1] to [-1, 1]; undone in _show_predictions.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader


def _train(net, trainloader, num_epochs, device):
    """Train net in place with SGD; return the mean training loss of each epoch."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    loss_values = []
    net.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(trainloader)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}, Loss: {epoch_loss:.3f}')
    return loss_values


def _plot_losses(loss_values):
    """Render the per-epoch training-loss curve in the Streamlit app."""
    # Use a dedicated figure so later plots cannot draw onto these axes.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(loss_values) + 1), loss_values, marker='o')
    ax.set_title('Training Loss over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    st.pyplot(fig)
    # pyplot keeps every open figure alive; close it so reruns don't accumulate them.
    plt.close(fig)


def _evaluate(net, testloader, device):
    """Return (correct, total): top-1 hits and sample count over testloader."""
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def _show_predictions(net, testloader, device, num_images=8):
    """Show the first num_images test images with predicted and actual labels."""
    images, labels = next(iter(testloader))
    # Slice so the displayed grid matches the printed labels, even when the
    # batch holds fewer than num_images images.
    images, labels = images[:num_images], labels[:num_images]
    net.eval()
    with torch.no_grad():
        predicted = net(images.to(device)).argmax(dim=1).cpu()

    st.write('Predicted: ', ' '.join(f'{CIFAR10_CLASSES[p]:5s}' for p in predicted.tolist()))
    st.write('Actual: ', ' '.join(f'{CIFAR10_CLASSES[t]:5s}' for t in labels.tolist()))

    # Undo Normalize((0.5,)*3, (0.5,)*3). Clamp away float round-off so imshow doesn't warn.
    grid = (torchvision.utils.make_grid(images) / 2 + 0.5).clamp(0, 1)
    fig, ax = plt.subplots()
    ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)


def train_model(num_epochs):
    """Train SimpleCNN on CIFAR-10 and report the results in the Streamlit app.

    Downloads CIFAR-10 into ./data if needed, then trains with SGD
    (lr=0.01, momentum=0.9, batch size 64) for num_epochs epochs. It writes
    the per-epoch loss, a loss plot, test-set accuracy and predictions for
    the first 8 test images. Uses CUDA when available, otherwise the CPU.

    Args:
        num_epochs: Number of passes over the training set.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainloader, testloader = _load_cifar10()
    net = SimpleCNN().to(device)

    st.write("Training the model...")
    loss_values = _train(net, trainloader, num_epochs, device)
    st.write('Finished Training')
    _plot_losses(loss_values)

    correct, total = _evaluate(net, testloader, device)
    st.write(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%')

    _show_predictions(net, testloader, device)
# Streamlit interface
st.title('CIFAR-10 Classification with PyTorch')

# Epoch count chosen by the user, bounded to 1..100 with a default of 10.
num_epochs = st.number_input(
    'Enter number of epochs:',
    min_value=1,
    max_value=100,
    value=10,
)

# Training runs only on the script run in which the Run button was clicked.
run_clicked = st.button('Run')
if run_clicked:
    train_model(num_epochs)