"""Streamlit dashboard for a MobileNetV2 plant-disease classifier.

The app offers four pages, selected from the sidebar:

* Dataset         - one random sample image per class folder.
* Visualizations  - class distribution as a pie chart and a bar chart.
* Model Metrics   - accuracy, classification report and confusion matrix
                    computed from saved ``y_true.pth`` / ``y_pred.pth``.
* Disease Predictor - classify an uploaded leaf image.
"""

import os
import random
import zipfile

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import DataLoader
from torchvision import datasets

# ---------------------------------------------------------------------------
# Paths and dataset details
# ---------------------------------------------------------------------------
DATASET_PATH = "train"
ZIP_FILE = "train.zip"
MODEL_PATH = "plant_disease_model.pth"
Y_TRUE_PATH = "y_true.pth"
Y_PRED_PATH = "y_pred.pth"

DATASET_NAME = "PlantVillage"
DATASET_SOURCE = "Kaggle"
DATASET_LINK = "https://www.kaggle.com/datasets/emmarex/plantdisease"

# Automatically unzip `train.zip` if the `train/` folder is missing.
if not os.path.exists(DATASET_PATH) and os.path.exists(ZIP_FILE):
    with zipfile.ZipFile(ZIP_FILE, "r") as zip_ref:
        zip_ref.extractall(".")  # Extract to current directory


def _list_class_names(dataset_path: str) -> list[str]:
    """Return the sorted names of the class sub-directories of *dataset_path*.

    Only directories count as classes: stray files (``.DS_Store``, a README,
    ...) would otherwise inflate the class count used to size the classifier
    head and make later ``os.listdir(class_path)`` calls fail.

    Returns an empty list when *dataset_path* is not a directory.
    """
    if not os.path.isdir(dataset_path):
        return []
    return sorted(
        entry
        for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )


# Class index -> class name. The predictor maps the model's argmax index
# straight into this list.
# NOTE(review): presumably the model was trained with the same alphabetical
# folder ordering - confirm against the training script.
CLASS_NAMES = _list_class_names(DATASET_PATH)

DATASET_INFO_MD = f"""
**🌱 Dataset: {DATASET_NAME}**

- 📌 **Source:** [{DATASET_SOURCE}]({DATASET_LINK})
- 🏷️ **Total Classes:** {len(CLASS_NAMES)}
- 📂 **Description:** This dataset contains images of healthy and diseased leaves for various plants, helping in plant disease classification.
"""

APP_OVERVIEW_MD = """
This app uses a MobileNet V2 model to classify plant diseases based on leaf images.
Simply upload an image of a plant leaf, and the model will predict the disease or identify if the plant is healthy.

### 🏷️ Supported Plant Diseases:

#### 🌶️ **Pepper Bell**
- Bacterial Spot
- Healthy

#### 🥔 **Potato**
- Early Blight
- Late Blight
- Healthy

#### 🍅 **Tomato**
- Bacterial Spot
- Early Blight
- Late Blight
- Leaf Mold
- Septoria Leaf Spot
- Spider Mites (Two-Spotted Spider Mite)
- Target Spot
- Tomato Yellow Leaf Curl Virus
- Tomato Mosaic Virus
- Healthy
"""

# Preprocessing applied to an uploaded image before inference.
# NOTE(review): presumably mirrors the transform used at training time - confirm.
PREPROCESS = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


@st.cache_resource
def load_model():
    """Build MobileNetV2 with a ``len(CLASS_NAMES)``-way head and load its weights.

    The weights are read from ``plant_disease_model.pth`` onto the CPU and the
    model is switched to eval mode. Decorated with ``st.cache_resource`` so
    the result is reused instead of being rebuilt on every call.

    Raises:
        FileNotFoundError: if the weights file does not exist.
        RuntimeError: if the saved state dict does not match the architecture
            (for example a different number of classes).
    """
    # `weights=None` replaces the deprecated `pretrained=False`.
    model = models.mobilenet_v2(weights=None)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
    model.eval()
    return model


def _warn_missing_dataset() -> None:
    """Tell the user that no class folders were found under the dataset path."""
    st.warning(f"⚠️ No class folders found in `{DATASET_PATH}/`.")


def render_dataset_page() -> None:
    """Dataset page: dataset summary plus one random sample image per class."""
    st.title("📊 Dataset Preview")
    st.markdown(DATASET_INFO_MD)

    if not CLASS_NAMES:
        _warn_missing_dataset()
        return

    num_columns = 3  # Adjust columns for better layout
    cols = st.columns(num_columns)
    for i, class_name in enumerate(CLASS_NAMES):
        class_path = os.path.join(DATASET_PATH, class_name)
        image_names = os.listdir(class_path)
        if not image_names:
            continue  # empty class folder: nothing to preview
        image_path = os.path.join(class_path, random.choice(image_names))
        # Context manager releases the file handle once the image is rendered.
        with Image.open(image_path) as image:
            cols[i % num_columns].image(
                image, caption=class_name, use_container_width=True
            )


def render_visualizations_page() -> None:
    """Visualizations page: per-class file counts as a pie chart and a bar chart."""
    st.title("📈 Dataset Visualizations")

    if not CLASS_NAMES:
        _warn_missing_dataset()
        return

    class_counts = {
        cls: len(os.listdir(os.path.join(DATASET_PATH, cls))) for cls in CLASS_NAMES
    }
    labels = list(class_counts)
    counts = list(class_counts.values())

    # Pie chart with one distinct colour per class.
    st.write("### Disease Distribution")
    fig, ax = plt.subplots()
    colors = sns.color_palette("husl", len(CLASS_NAMES))
    ax.pie(counts, labels=labels, autopct="%1.1f%%", colors=colors)
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures alive until closed; avoid a leak per rerun

    # Bar chart.
    st.write("### Class Count")
    fig, ax = plt.subplots()
    sns.barplot(x=labels, y=counts, palette="husl", ax=ax)
    ax.tick_params(axis="x", labelrotation=45)
    st.pyplot(fig)
    plt.close(fig)


def render_metrics_page() -> None:
    """Model Metrics page: accuracy, classification report and confusion matrix.

    Reads the saved labels from ``y_true.pth`` and ``y_pred.pth``. Only a
    missing file is reported as "not found"; any other failure (for example a
    class-count mismatch in the report) is left to propagate so the real
    error is visible.
    """
    st.title("📊 Model Performance")

    try:
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # files produced by a trusted training run.
        y_true = torch.load(Y_TRUE_PATH, weights_only=False)
        y_pred = torch.load(Y_PRED_PATH, weights_only=False)
    except FileNotFoundError:
        st.error("🚨 Model metrics files (`y_true.pth` and `y_pred.pth`) not found!")
        return

    # Accuracy
    accuracy = accuracy_score(y_true, y_pred)
    st.write(f"### ✅ Accuracy: {accuracy:.2f}")

    # Classification report
    st.write("### 📋 Classification Report")
    report = classification_report(
        y_true, y_pred, target_names=CLASS_NAMES, output_dict=True
    )
    st.write(pd.DataFrame(report).T)

    # Confusion matrix
    st.write("### 🔀 Confusion Matrix")
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots()
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=CLASS_NAMES,
        yticklabels=CLASS_NAMES,
        ax=ax,
    )
    st.pyplot(fig)
    plt.close(fig)


def render_predictor_page() -> None:
    """Disease Predictor page: classify an uploaded leaf image."""
    st.title("🌿 Plant Disease Classifier")
    st.write(APP_OVERVIEW_MD)

    uploaded_file = st.file_uploader(
        "Upload a plant leaf image", type=["jpg", "png", "jpeg"]
    )
    if uploaded_file is None:
        return

    # Force 3 channels: RGBA/palette PNGs and grayscale JPEGs would otherwise
    # produce a tensor the network's first convolution rejects.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if not CLASS_NAMES:
        # Without class folders there is no index -> name mapping to report.
        _warn_missing_dataset()
        return

    # The model is loaded here rather than at import time so that the other
    # pages keep working when the weights file is absent.
    try:
        model = load_model()
    except FileNotFoundError:
        st.error(f"🚨 Model file (`{MODEL_PATH}`) not found!")
        return

    image_tensor = PREPROCESS(image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        output = model(image_tensor)
    predicted_class = torch.argmax(output, dim=1).item()

    st.write(f"### ✅ Prediction: **{CLASS_NAMES[predicted_class]}**")
    st.success("✔ The prediction is based on a trained deep learning model.")


# ---------------------------------------------------------------------------
# Sidebar navigation and page dispatch
# ---------------------------------------------------------------------------
PAGES = {
    "Dataset": render_dataset_page,
    "Visualizations": render_visualizations_page,
    "Model Metrics": render_metrics_page,
    "Disease Predictor": render_predictor_page,
}

st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", list(PAGES))
PAGES[page]()