# Rejeno's picture
# Update app.py
# fb1c47e verified
import streamlit as st
import os
import zipfile
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from PIL import Image
from torchvision import datasets
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import random
# Dataset bootstrap: the pages below read training images from ./train/.
# If that folder is missing but a train.zip archive sits next to the app,
# unpack the archive into the current working directory (presumably the
# archive's top-level entry is train/ — confirm against the shipped zip).
DATASET_PATH = "train"
ZIP_FILE = "train.zip"
if not os.path.exists(DATASET_PATH) and os.path.exists(ZIP_FILE):
    with zipfile.ZipFile(ZIP_FILE, "r") as archive:
        archive.extractall(".")
# Class labels, indexed the same way as the model's output logits (the
# predictor looks up CLASS_NAMES[argmax(output)]).
# Only sub-directories of train/ count as classes, sorted by name. This mirrors
# torchvision's ImageFolder class discovery (presumably how the model was
# trained — TODO confirm), so a stray file such as .DS_Store or a README in
# train/ can no longer shift every label or change the class count.
# isdir (not exists) also avoids NotADirectoryError if "train" is a plain file.
if os.path.isdir(DATASET_PATH):
    with os.scandir(DATASET_PATH) as entries:
        CLASS_NAMES = sorted(entry.name for entry in entries if entry.is_dir())
else:
    CLASS_NAMES = []
# Model: MobileNetV2 with a classification head sized to the trained checkpoint.
@st.cache_resource
def load_model():
    """Build the MobileNetV2 classifier and load the trained weights.

    Decorated with ``st.cache_resource``, so the checkpoint is loaded once and
    the same model object is reused on later Streamlit reruns.

    Returns:
        An ``nn.Module`` on CPU, switched to eval mode.

    Raises:
        FileNotFoundError: if ``plant_disease_model.pth`` does not exist.
        ValueError: if ``train/`` lists a different number of classes than
            the checkpoint's classifier head was trained with.
    """
    # The file is loaded straight into load_state_dict, so it is a plain
    # state_dict; weights_only=True restricts unpickling to tensors/containers
    # instead of executing arbitrary pickled objects.
    state_dict = torch.load(
        "plant_disease_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )

    # "classifier.1.weight" is the parameter of model.classifier[1] (the head
    # replaced below); its row count is the number of trained classes. Sizing
    # from the checkpoint lets the model load even when train/ is unavailable
    # and CLASS_NAMES is empty.
    num_classes = state_dict["classifier.1.weight"].shape[0]
    if CLASS_NAMES and len(CLASS_NAMES) != num_classes:
        raise ValueError(
            f"Dataset folder lists {len(CLASS_NAMES)} classes but the "
            f"checkpoint was trained on {num_classes}."
        )

    # weights=None replaces the deprecated pretrained=False: random init, and
    # every parameter is overwritten by load_state_dict below.
    model = models.mobilenet_v2(weights=None)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for dropout / batch-norm layers
    return model


model = load_model()
# Sidebar navigation: the radio choice stored in `page` selects which of the
# page sections below is rendered on this run.
with st.sidebar:
    st.title("Navigation")
    page = st.radio(
        "Go to",
        ["Dataset", "Visualizations", "Model Metrics", "Disease Predictor"],
    )
# Metadata about the training dataset, rendered on the "Dataset" page.
DATASET_NAME, DATASET_SOURCE = "PlantVillage", "Kaggle"
DATASET_LINK = "https://www.kaggle.com/datasets/emmarex/plantdisease"
# Dataset page: dataset summary plus one random sample image per class.
if page == "Dataset":
    st.title("πŸ“Š Dataset Preview")

    # Summary card for the dataset in use.
    st.markdown(f"""
**🌱 Dataset: {DATASET_NAME}**
- πŸ“Œ **Source:** [{DATASET_SOURCE}]({DATASET_LINK})
- 🏷️ **Total Classes:** {len(CLASS_NAMES)}
- πŸ“‚ **Description:** This dataset contains images of healthy and diseased leaves for various plants,
helping in plant disease classification.
""")

    # Sample grid: class i goes into column i % num_columns. A skipped class
    # still consumes its slot, keeping the layout stable.
    num_columns = 3
    cols = st.columns(num_columns)
    for i, class_name in enumerate(CLASS_NAMES):
        class_path = os.path.join(DATASET_PATH, class_name)
        # isdir (not exists): a stray file in train/ would make listdir raise.
        if not os.path.isdir(class_path):
            continue
        image_names = os.listdir(class_path)  # list once, reuse for the pick
        if not image_names:
            continue
        image_path = os.path.join(class_path, random.choice(image_names))
        try:
            # The context manager closes the file handle; st.image consumes the
            # image within the call, so closing afterwards is safe.
            with Image.open(image_path) as image:
                cols[i % num_columns].image(image, caption=class_name, use_container_width=True)
        except OSError:
            # Non-image files (e.g. Thumbs.db) raise UnidentifiedImageError, an
            # OSError subclass, as does a sub-directory; skip this tile rather
            # than crash the whole page.
            cols[i % num_columns].warning(f"Could not read a sample image for {class_name}.")
# βœ… Visualizations Page – Show Class Distribution
elif page == "Visualizations":
st.title("πŸ“ˆ Dataset Visualizations")
if CLASS_NAMES:
class_counts = {cls: len(os.listdir(os.path.join(DATASET_PATH, cls))) for cls in CLASS_NAMES}
# Pie Chart with Proper Colors
st.write("### Disease Distribution")
fig, ax = plt.subplots()
colors = sns.color_palette("husl", len(CLASS_NAMES)) # Generate unique colors
ax.pie(class_counts.values(), labels=class_counts.keys(), autopct='%1.1f%%', colors=colors)
st.pyplot(fig)
# Bar Chart
st.write("### Class Count")
fig, ax = plt.subplots()
sns.barplot(x=list(class_counts.keys()), y=list(class_counts.values()), palette="husl")
plt.xticks(rotation=45)
st.pyplot(fig)
# βœ… Model Metrics Page
elif page == "Model Metrics":
st.title("πŸ“Š Model Performance")
try:
# Load True Labels and Predictions
y_true = torch.load("y_true.pth", weights_only=False)
y_pred = torch.load("y_pred.pth", weights_only=False)
# Accuracy
accuracy = accuracy_score(y_true, y_pred)
st.write(f"### βœ… Accuracy: {accuracy:.2f}")
# Classification Report
st.write("### πŸ“‹ Classification Report")
report = classification_report(y_true, y_pred, target_names=CLASS_NAMES, output_dict=True)
st.write(pd.DataFrame(report).T)
# Confusion Matrix
st.write("### πŸ”€ Confusion Matrix")
cm = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
st.pyplot(fig)
except:
st.error("🚨 Model metrics files (`y_true.pth` and `y_pred.pth`) not found!")
# βœ… Disease Predictor Page
elif page == "Disease Predictor":
st.title("🌿 Plant Disease Classifier")
# βœ… App Overview
st.write("""
This app uses a MobileNet V2 model to classify plant diseases based on leaf images.
Simply upload an image of a plant leaf, and the model will predict the disease or identify if the plant is healthy.
### 🏷️ Supported Plant Diseases:
#### 🌢️ **Pepper Bell**
- Bacterial Spot
- Healthy
#### πŸ₯” **Potato**
- Early Blight
- Late Blight
- Healthy
#### πŸ… **Tomato**
- Bacterial Spot
- Early Blight
- Late Blight
- Leaf Mold
- Septoria Leaf Spot
- Spider Mites (Two-Spotted Spider Mite)
- Target Spot
- Tomato Yellow Leaf Curl Virus
- Tomato Mosaic Virus
- Healthy
""")
# βœ… File Upload
uploaded_file = st.file_uploader("Upload a plant leaf image", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
image = Image.open(uploaded_file)
st.image(image, caption="Uploaded Image", use_container_width=True)
# βœ… Transform Image
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
image_tensor = transform(image).unsqueeze(0)
# βœ… Predict Disease
with torch.no_grad():
output = model(image_tensor)
predicted_class = torch.argmax(output, dim=1).item()
st.write(f"### βœ… Prediction: **{CLASS_NAMES[predicted_class]}**")
st.success("βœ” The prediction is based on a trained deep learning model.")