# File size: 6,235 Bytes
import streamlit as st
import os
import zipfile
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from PIL import Image
from torchvision import datasets
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import random
# Automatically unzip `train.zip` if the `train/` folder is missing.
DATASET_PATH = "train"
ZIP_FILE = "train.zip"
if not os.path.exists(DATASET_PATH) and os.path.exists(ZIP_FILE):
    # NOTE(review): assumes `train.zip` is a trusted archive shipped with the app
    # (extractall does not sanitise member paths) and that it contains the
    # top-level `train/` folder — confirm.
    with zipfile.ZipFile(ZIP_FILE, "r") as zip_ref:
        zip_ref.extractall(".")  # extract into the current working directory
# Load class names from the dataset folder.
# Only sub-directories count as classes, sorted by name. This is the order in
# which torchvision's ImageFolder assigns label indices, and it stops stray
# files (e.g. `.DS_Store`) from shifting the index -> name mapping.
# NOTE(review): assumes the model was trained with ImageFolder on this same
# directory layout — confirm against the training script.
if os.path.isdir(DATASET_PATH):
    with os.scandir(DATASET_PATH) as entries:
        CLASS_NAMES = sorted(entry.name for entry in entries if entry.is_dir())
else:
    CLASS_NAMES = []
# Load the trained classifier. st.cache_resource shares the returned object
# across reruns and sessions, so the checkpoint is read only once.
@st.cache_resource
def load_model():
    """Build MobileNetV2 and load the fine-tuned weights from disk.

    Returns:
        The model on CPU, switched to eval mode.

    Raises:
        FileNotFoundError: if ``plant_disease_model.pth`` is missing.
        ValueError: if the checkpoint's class count disagrees with a
            non-empty CLASS_NAMES.
    """
    # `weights=None` replaces the deprecated `pretrained=False`: build the
    # architecture only; every weight comes from the checkpoint below.
    model = models.mobilenet_v2(weights=None)
    # The checkpoint is a plain state dict of tensors (it is fed straight to
    # load_state_dict), so the safe weights-only loader is sufficient.
    state_dict = torch.load(
        "plant_disease_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    # Size the replacement head from the checkpoint itself, so the model still
    # loads when the dataset folder (and therefore CLASS_NAMES) is absent.
    num_classes = state_dict["classifier.1.weight"].shape[0]
    if CLASS_NAMES and len(CLASS_NAMES) != num_classes:
        raise ValueError(
            f"Model predicts {num_classes} classes but the dataset folder "
            f"has {len(CLASS_NAMES)}."
        )
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()
# Sidebar navigation: the selected label decides which page is rendered below.
st.sidebar.title("Navigation")
page = st.sidebar.radio(
    "Go to", ["Dataset", "Visualizations", "Model Metrics", "Disease Predictor"]
)
# Dataset details shown on the "Dataset" page.
DATASET_NAME = "PlantVillage"
DATASET_SOURCE = "Kaggle"
DATASET_LINK = "https://www.kaggle.com/datasets/emmarex/plantdisease"
# Page rendering: one private function per sidebar entry; `page` (set by the
# sidebar radio) selects which one runs on this Streamlit rerun.

# Image extensions recognised when listing dataset files (the same set
# torchvision's ImageFolder accepts); compared case-insensitively so that
# e.g. `.JPG` also counts.
_IMAGE_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
)

# Preprocessing for uploaded images, built once instead of on every upload.
# NOTE(review): presumably mirrors the transform used during training — confirm.
_PREDICT_TRANSFORM = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # one value, broadcast over all channels
])


def _list_images(directory):
    """Return paths of the image files directly inside *directory*.

    Returns an empty list when *directory* does not exist or is not a folder,
    so stray files in the dataset root need no special handling by callers.
    """
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


def _render_dataset_page():
    """Show dataset metadata and one random sample image per class."""
    st.title("📂 Dataset Preview")

    # Dataset information
    st.markdown(f"""
**🌱 Dataset: {DATASET_NAME}**
- 🔗 **Source:** [{DATASET_SOURCE}]({DATASET_LINK})
- 🏷️ **Total Classes:** {len(CLASS_NAMES)}
- 📝 **Description:** This dataset contains images of healthy and diseased leaves for various plants,
helping in plant disease classification.
""")

    # One random image per class, laid out over a fixed number of columns.
    num_columns = 3
    cols = st.columns(num_columns)
    for i, class_name in enumerate(CLASS_NAMES):
        images = _list_images(os.path.join(DATASET_PATH, class_name))
        if not images:  # missing/empty folder: nothing to show for this class
            continue
        image = Image.open(random.choice(images))
        cols[i % num_columns].image(image, caption=class_name, use_container_width=True)


def _render_visualizations_page():
    """Plot the number of images per class as a pie chart and a bar chart."""
    st.title("📊 Dataset Visualizations")
    if not CLASS_NAMES:
        st.info(f"No class folders found under `{DATASET_PATH}`; nothing to plot.")
        return

    names = list(CLASS_NAMES)
    counts = [len(_list_images(os.path.join(DATASET_PATH, cls))) for cls in names]
    if sum(counts) == 0:  # a pie chart of all-zero wedges cannot be drawn
        st.info("The class folders contain no images; nothing to plot.")
        return
    colors = sns.color_palette("husl", len(names))  # one distinct colour per class

    st.write("### Disease Distribution")
    fig, ax = plt.subplots()
    # Plain lists on purpose: matplotlib coerces `x` with NumPy, which does not
    # treat a dict `.values()` view as a numeric sequence.
    ax.pie(counts, labels=names, autopct='%1.1f%%', colors=colors)
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures alive until closed; reruns would pile up

    st.write("### Class Count")
    fig, ax = plt.subplots()
    ax.bar(names, counts, color=colors)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    st.pyplot(fig)
    plt.close(fig)


def _render_metrics_page():
    """Show accuracy, per-class report and confusion matrix from saved predictions."""
    st.title("📈 Model Performance")
    try:
        # Local artifacts saved at evaluation time. weights_only=False lets
        # torch.load unpickle arbitrary objects, so only ever load trusted files.
        y_true = torch.load("y_true.pth", weights_only=False)
        y_pred = torch.load("y_pred.pth", weights_only=False)
    except FileNotFoundError:
        st.error("🚨 Model metrics files (`y_true.pth` and `y_pred.pth`) not found!")
        return

    if CLASS_NAMES:
        # Assumes y_true / y_pred hold integer class indices aligned with
        # CLASS_NAMES — TODO confirm. Passing `labels` keeps the report and the
        # matrix full-size even when some class never occurs in the saved data.
        labels = list(range(len(CLASS_NAMES)))
        names = CLASS_NAMES
    else:
        labels, names = None, None

    try:
        accuracy = accuracy_score(y_true, y_pred)
        st.write(f"### ✅ Accuracy: {accuracy:.2f}")

        st.write("### 📋 Classification Report")
        report = classification_report(
            y_true, y_pred, labels=labels, target_names=names,
            output_dict=True, zero_division=0,
        )
        st.write(pd.DataFrame(report).T)

        st.write("### 🔍 Confusion Matrix")
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        fig, ax = plt.subplots(figsize=(10, 8))
        tick_labels = names if names else "auto"
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
        # sklearn convention: rows are true labels, columns are predictions.
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        st.pyplot(fig)
        plt.close(fig)
    except Exception as err:  # UI boundary: report the real cause, not "not found"
        st.error(f"🚨 Could not compute model metrics: {err}")


def _render_predictor_page():
    """Classify an uploaded leaf image with the cached model."""
    st.title("🌿 Plant Disease Classifier")

    # App overview
    st.write("""
This app uses a MobileNet V2 model to classify plant diseases based on leaf images.
Simply upload an image of a plant leaf, and the model will predict the disease or identify if the plant is healthy.
### 🏷️ Supported Plant Diseases:
#### 🌶️ **Pepper Bell**
- Bacterial Spot
- Healthy
#### 🥔 **Potato**
- Early Blight
- Late Blight
- Healthy
#### 🍅 **Tomato**
- Bacterial Spot
- Early Blight
- Late Blight
- Leaf Mold
- Septoria Leaf Spot
- Spider Mites (Two-Spotted Spider Mite)
- Target Spot
- Tomato Yellow Leaf Curl Virus
- Tomato Mosaic Virus
- Healthy
""")

    # File upload
    uploaded_file = st.file_uploader("Upload a plant leaf image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Force 3-channel RGB for the model: PNG uploads may be RGBA,
        # greyscale or palette images, which ToTensor would turn into a 4- or
        # 1-channel tensor that the 3-channel network input rejects.
        rgb_image = image.convert("RGB")
    except OSError:  # PIL's UnidentifiedImageError subclasses OSError
        st.error("🚨 Could not read the uploaded file as an image.")
        return
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Predict disease
    image_tensor = _PREDICT_TRANSFORM(rgb_image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(image_tensor)
    predicted_class = torch.argmax(output, dim=1).item()

    # Fall back to the raw index when class names are unavailable.
    if predicted_class < len(CLASS_NAMES):
        label = CLASS_NAMES[predicted_class]
    else:
        label = f"Class {predicted_class}"
    st.write(f"### ✅ Prediction: **{label}**")
    st.success("✔ The prediction is based on a trained deep learning model.")


_PAGES = {
    "Dataset": _render_dataset_page,
    "Visualizations": _render_visualizations_page,
    "Model Metrics": _render_metrics_page,
    "Disease Predictor": _render_predictor_page,
}

_render_page = _PAGES.get(page)
if _render_page is not None:
    _render_page()