File size: 3,999 Bytes
138a538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import random

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import DataLoader
from torchvision import datasets

# ✅ Sidebar Navigation
# The selected radio entry decides which page the script renders below.
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Dataset", "Visualizations", "Model Metrics", "Disease Predictor"])

# ✅ Dataset Path
DATASET_PATH = "dataset/train"  # Update if needed

# Class names are the sub-directory names of DATASET_PATH.
# - Only directories are kept, so stray files (e.g. ".DS_Store") neither
#   inflate the class count nor show up as a bogus class.
# - The list is sorted because os.listdir() returns entries in arbitrary,
#   platform-dependent order, while the position of a name in this list is
#   used as the class index. torchvision's ImageFolder assigns indices in
#   sorted directory-name order; presumably the model was trained that way
#   (TODO confirm against the training script).
CLASS_NAMES = sorted(
    entry
    for entry in os.listdir(DATASET_PATH)
    if os.path.isdir(os.path.join(DATASET_PATH, entry))
)

# ✅ Load Model
@st.cache_resource
def load_model():
    """Build the MobileNetV2 classifier and restore its saved weights.

    The network is created without pretrained weights, its last classifier
    layer is swapped for a linear layer with one output per entry in
    CLASS_NAMES, and the state dict stored in "plant_disease_model.pth" is
    loaded onto the CPU. The model is put in eval mode before it is returned.

    Returns:
        The ready-to-use model, in evaluation mode.
    """
    net = models.mobilenet_v2(pretrained=False)

    # Replace the classification head so its width matches the class count.
    head_inputs = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_inputs, len(CLASS_NAMES))

    cpu = torch.device("cpu")
    state_dict = torch.load("plant_disease_model.pth", map_location=cpu)
    net.load_state_dict(state_dict)

    net.eval()
    return net


model = load_model()

# ✅ Dataset Page – Show Sample Images
if page == "Dataset":
    st.title("📊 Dataset Preview")
    st.write(f"### Classes: {CLASS_NAMES}")

    # Show one randomly chosen sample image for each of the first 4 classes,
    # one per column.
    cols = st.columns(4)
    for col, class_name in zip(cols, CLASS_NAMES[:4]):
        class_path = os.path.join(DATASET_PATH, class_name)
        candidates = os.listdir(class_path)
        if not candidates:
            # random.choice() raises IndexError on an empty sequence.
            col.caption(f"{class_name} (no images found)")
            continue
        image_path = os.path.join(class_path, random.choice(candidates))
        image = Image.open(image_path)
        col.image(image, caption=class_name, use_column_width=True)

# ✅ Visualizations Page – Show Class Distribution
elif page == "Visualizations":
    st.title("📈 Dataset Visualizations")

    # Count directory entries per class folder.
    class_counts = {cls: len(os.listdir(os.path.join(DATASET_PATH, cls))) for cls in CLASS_NAMES}
    labels = list(class_counts)
    counts = list(class_counts.values())

    # Pie Chart
    st.write("### Disease Distribution")
    # Sample the colormap evenly across its range. Passing the raw 256-entry
    # viridis colour list would hand the wedges only its first few entries,
    # which are nearly indistinguishable.
    n_classes = len(labels)
    wedge_colors = [plt.cm.viridis(i / max(n_classes - 1, 1)) for i in range(n_classes)]
    fig, ax = plt.subplots()
    ax.pie(counts, labels=labels, autopct='%1.1f%%', colors=wedge_colors)
    st.pyplot(fig)
    plt.close(fig)  # release the figure; the script reruns on every interaction

    # Bar Chart
    st.write("### Class Count")
    fig, ax = plt.subplots()
    # Draw on the axes created above rather than on the implicit current axes.
    sns.barplot(x=labels, y=counts, palette="viridis", ax=ax)
    ax.tick_params(axis="x", labelrotation=45)
    st.pyplot(fig)
    plt.close(fig)

# ✅ Model Metrics Page
elif page == "Model Metrics":
    st.title("📊 Model Performance")

    # Load True Labels and Predictions.
    # map_location keeps any stored tensors on the CPU, as load_model does.
    y_true = torch.load("y_true.pth", map_location=torch.device("cpu"))
    y_pred = torch.load("y_pred.pth", map_location=torch.device("cpu"))

    # Pin the label set so the report rows and the confusion matrix always
    # have one entry per class name, even if a class is absent from the data.
    # Assumes labels are integer indices into CLASS_NAMES - TODO confirm.
    class_ids = list(range(len(CLASS_NAMES)))

    # Accuracy
    accuracy = accuracy_score(y_true, y_pred)
    st.write(f"### ✅ Accuracy: {accuracy:.2f}")

    # Classification Report
    st.write("### 📋 Classification Report")
    report = classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=CLASS_NAMES,
        output_dict=True,
    )
    st.write(pd.DataFrame(report).T)

    # Confusion Matrix
    st.write("### 🔀 Confusion Matrix")
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    fig, ax = plt.subplots()
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=CLASS_NAMES,
        yticklabels=CLASS_NAMES,
        ax=ax,
    )
    st.pyplot(fig)
    plt.close(fig)

# ✅ Disease Predictor Page
elif page == "Disease Predictor":
    st.title("🌿 Plant Disease Classifier")

    # File Upload
    uploaded_file = st.file_uploader("Upload a plant leaf image", type=["jpg", "png", "jpeg"])

    if uploaded_file is not None:
        # Force three channels: PNGs may be RGBA or palette-based and JPEGs
        # may be greyscale, none of which match the model's 3-channel input.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Transform Image: resize, convert to a tensor, normalise.
        transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        # Add the batch dimension the model expects.
        image_tensor = transform(image).unsqueeze(0)

        # Predict Disease (inference only, so gradients are disabled).
        with torch.no_grad():
            output = model(image_tensor)
            predicted_class = torch.argmax(output, dim=1).item()

        st.write(f"### ✅ Prediction: {CLASS_NAMES[predicted_class]}")