Regino commited on
Commit
0e83c47
Β·
1 Parent(s): b90ceb7
Files changed (4) hide show
  1. app.py +154 -0
  2. data.zip +3 -0
  3. model.pth +3 -0
  4. trainmodel.py +53 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ import torchvision.models as models
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ import pandas as pd
9
+ import random
10
+ from PIL import Image
11
+ from torchvision import datasets
12
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
13
+
14
+ # CIFAR-10 Class Names
15
+ CLASS_NAMES = [
16
+ "Airplane", "Automobile", "Bird", "Cat", "Deer",
17
+ "Dog", "Frog", "Horse", "Ship", "Truck"
18
+ ]
19
+
20
+ # Load CIFAR-10 Dataset for Visualization
21
+ transform = transforms.Compose([transforms.ToTensor()])
22
+ dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
23
+
24
+ # Load Trained Model
25
+ @st.cache_resource
26
+ def load_model():
27
+ model = models.resnet18(pretrained=False)
28
+ model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
29
+ model.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
30
+ model.eval()
31
+ return model
32
+
33
+ model = load_model()
34
+
35
+ # Sidebar Navigation
36
+ st.sidebar.title("Navigation")
37
+ page = st.sidebar.radio("Go to", ["Dataset", "Visualizations", "Model Metrics", "Predictor"])
38
+
39
+ # πŸ“Œ Dataset Preview Page
40
+ if page == "Dataset":
41
+ st.title("πŸ“Š CIFAR-10 Dataset Preview")
42
+
43
+ # Dataset Information
44
+ st.markdown("""
45
+ ## πŸ“ About CIFAR-10
46
+ The **CIFAR-10 dataset** is widely used in image classification research.
47
+ - πŸ“ **Created by**: Alex Krizhevsky, Vinod Nair, Geoffrey Hinton
48
+ - πŸ› **From**: University of Toronto
49
+ - πŸ“Έ **Images**: 60,000 color images (**32Γ—32 pixels**)
50
+ - 🏷 **Classes (10)**:
51
+ - πŸ›« Airplane
52
+ - πŸš— Automobile
53
+ - 🐦 Bird
54
+ - 🐱 Cat
55
+ - 🦌 Deer
56
+ - 🐢 Dog
57
+ - 🐸 Frog
58
+ - 🐴 Horse
59
+ - 🚒 Ship
60
+ - πŸš› Truck
61
+ - πŸ”— **[Dataset Link](https://www.cs.toronto.edu/~kriz/cifar.html)**
62
+ """)
63
+
64
+ # Show 10 Random Images
65
+ st.subheader("πŸ” Random CIFAR-10 Images")
66
+ cols = st.columns(5) # Display in 5 columns
67
+ for i in range(10):
68
+ index = random.randint(0, len(dataset) - 1)
69
+ image, label = dataset[index]
70
+ image = transforms.ToPILImage()(image) # Convert tensor to image
71
+ cols[i % 5].image(image, caption=CLASS_NAMES[label], use_container_width=True)
72
+
73
+ # πŸ“ˆ Visualization Page
74
+ elif page == "Visualizations":
75
+ st.title("πŸ“Š Dataset Visualizations")
76
+
77
+ # Count class occurrences
78
+ class_counts = {cls: 0 for cls in CLASS_NAMES}
79
+ for _, label in dataset:
80
+ class_counts[CLASS_NAMES[label]] += 1
81
+
82
+ # Pie Chart
83
+ st.subheader("πŸ“Œ Class Distribution (Pie Chart)")
84
+ fig, ax = plt.subplots()
85
+ colors = sns.color_palette("husl", len(CLASS_NAMES))
86
+ ax.pie(class_counts.values(), labels=class_counts.keys(), autopct='%1.1f%%', colors=colors)
87
+ st.pyplot(fig)
88
+
89
+ # Bar Chart
90
+ st.subheader("πŸ“Š Class Distribution (Bar Chart)")
91
+ fig, ax = plt.subplots()
92
+ sns.barplot(x=list(class_counts.keys()), y=list(class_counts.values()), palette="husl")
93
+ plt.xticks(rotation=45)
94
+ st.pyplot(fig)
95
+
96
+ # πŸ“Š Model Metrics Page
97
+ elif page == "Model Metrics":
98
+ st.title("πŸ“ˆ Model Performance")
99
+
100
+ try:
101
+ y_true = torch.load("y_true.pth")
102
+ y_pred = torch.load("y_pred.pth")
103
+
104
+ # Display Accuracy
105
+ st.write(f"### βœ… Accuracy: **{accuracy_score(y_true, y_pred):.2f}**")
106
+
107
+ # Classification Report
108
+ report = classification_report(y_true, y_pred, target_names=CLASS_NAMES, output_dict=True)
109
+ st.write(pd.DataFrame(report).T)
110
+
111
+ # Confusion Matrix
112
+ st.subheader("πŸ”„ Confusion Matrix")
113
+ cm = confusion_matrix(y_true, y_pred)
114
+ fig, ax = plt.subplots(figsize=(8, 6))
115
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
116
+ st.pyplot(fig)
117
+
118
+ except:
119
+ st.error("🚨 Model metrics files not found!")
120
+
121
+ # πŸ” Prediction Page
122
+ elif page == "Predictor":
123
+ st.title("πŸ” CIFAR-10 Image Classifier")
124
+
125
+ # About the Classifier
126
+ st.markdown("""
127
+ ## πŸ“ About This App
128
+ This app is a **deep learning image classifier** trained on the **CIFAR-10 dataset**.
129
+ It can recognize **10 different objects/animals**:
130
+ - πŸ›« Airplane, πŸš— Automobile, 🐦 Bird, 🐱 Cat, 🦌 Deer
131
+ - 🐢 Dog, 🐸 Frog, 🐴 Horse, 🚒 Ship, πŸš› Truck
132
+ """)
133
+
134
+ # Upload Image
135
+ uploaded_file = st.file_uploader("πŸ“€ Upload an image", type=["jpg", "png", "jpeg"])
136
+ if uploaded_file is not None:
137
+ image = Image.open(uploaded_file)
138
+ st.image(image, caption="πŸ–Ό Uploaded Image", use_container_width=True)
139
+
140
+ # Transform image for model
141
+ transform = transforms.Compose([
142
+ transforms.Resize((224, 224)),
143
+ transforms.ToTensor(),
144
+ transforms.Normalize([0.5], [0.5])
145
+ ])
146
+ image_tensor = transform(image).unsqueeze(0)
147
+
148
+ # Make prediction
149
+ with torch.no_grad():
150
+ output = model(image_tensor)
151
+ predicted_class = torch.argmax(output, dim=1).item()
152
+
153
+ # Display Prediction
154
+ st.success(f"### βœ… Prediction: **{CLASS_NAMES[predicted_class]}**")
data.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:606dc5044505a1048b1a6527d230d6dd9172ab373ffb638519c3b2edc1fb1cd4
3
+ size 340726644
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0911fbb81fcf760f000c8e6b5eef931a7ca7077f0b702738011b1956b11294a
3
+ size 44796930
trainmodel.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision
5
+ import torchvision.transforms as transforms
6
+ import tqdm
7
+
8
+ # Define transformations
9
+ transform = transforms.Compose([
10
+ transforms.Resize((224, 224)), # Resize images for ResNet
11
+ transforms.ToTensor(),
12
+ transforms.Normalize((0.5,), (0.5,))
13
+ ])
14
+
15
+ # Load CIFAR-10 Dataset
16
+ trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
17
+ trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
18
+
19
+ testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
20
+ testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
21
+
22
+ # Define Model (ResNet-18)
23
+ model = torchvision.models.resnet18(pretrained=True)
24
+ model.fc = nn.Linear(model.fc.in_features, 10) # Adjust for 10 CIFAR-10 classes
25
+
26
+ # Define Loss and Optimizer
27
+ criterion = nn.CrossEntropyLoss()
28
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
29
+
30
+ # Train the Model
31
+ num_epochs = 5
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ model.to(device)
34
+
35
+ for epoch in range(num_epochs):
36
+ model.train()
37
+ running_loss = 0.0
38
+ for images, labels in tqdm.tqdm(trainloader):
39
+ images, labels = images.to(device), labels.to(device)
40
+
41
+ optimizer.zero_grad()
42
+ outputs = model(images)
43
+ loss = criterion(outputs, labels)
44
+ loss.backward()
45
+ optimizer.step()
46
+
47
+ running_loss += loss.item()
48
+
49
+ print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader)}")
50
+
51
+ # Save the Trained Model
52
+ torch.save(model.state_dict(), "model.pth")
53
+ print("Model training complete and saved as model.pth!")