Regino
new commit
138a538
raw
history blame
4 kB
import os
import random

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import DataLoader
from torchvision import datasets
# --- Sidebar navigation ---
st.sidebar.title("Navigation")
page = st.sidebar.radio(
    "Go to",
    ["Dataset", "Visualizations", "Model Metrics", "Disease Predictor"],
)

# --- Dataset path ---
# Root folder that holds one sub-directory per class.
DATASET_PATH = "dataset/train"  # Update if needed

# Class names are taken from the folder structure.
# os.listdir() returns entries in arbitrary order and also lists plain files
# (e.g. ".DS_Store"), so the names are restricted to directories and sorted to
# give a deterministic index -> name mapping.
# NOTE(review): alphabetical order is what torchvision's ImageFolder uses;
# presumably the model was trained that way -- confirm against the training code.
CLASS_NAMES = sorted(
    entry
    for entry in os.listdir(DATASET_PATH)
    if os.path.isdir(os.path.join(DATASET_PATH, entry))
)
# --- Load model ---
@st.cache_resource
def load_model():
    """Build the MobileNetV2 classifier and load its trained weights.

    The network is created without pretrained weights, its final classifier
    layer is replaced by a linear layer with one output per entry in
    ``CLASS_NAMES``, and the parameters are read from
    ``plant_disease_model.pth`` onto the CPU. The result is wrapped in
    ``st.cache_resource`` so Streamlit can reuse it.

    Returns:
        The model, switched to evaluation mode.
    """
    # `weights=None` replaces the deprecated `pretrained=False` argument.
    net = models.mobilenet_v2(weights=None)

    # Swap the classification head so its output size matches our classes.
    head_inputs = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_inputs, len(CLASS_NAMES))

    state_dict = torch.load(
        "plant_disease_model.pth", map_location=torch.device("cpu")
    )
    net.load_state_dict(state_dict)
    net.eval()
    return net
model = load_model()

# --- Dataset page: show sample images ---
if page == "Dataset":
    st.title("📊 Dataset Preview")
    st.write(f"### Classes: {CLASS_NAMES}")

    # Show one random sample image for each of the first four classes.
    cols = st.columns(4)
    for col, class_name in zip(cols, CLASS_NAMES[:4]):
        class_path = os.path.join(DATASET_PATH, class_name)
        image_names = os.listdir(class_path)
        if not image_names:
            # random.choice() raises IndexError on an empty sequence.
            col.warning(f"No images found for {class_name}")
            continue
        image_path = os.path.join(class_path, random.choice(image_names))
        # Context manager releases the file handle once the image is rendered.
        with Image.open(image_path) as image:
            col.image(image, caption=class_name, use_column_width=True)

# --- Visualizations page: show class distribution ---
elif page == "Visualizations":
    st.title("📈 Dataset Visualizations")

    # Count images per class.
    class_counts = {
        cls: len(os.listdir(os.path.join(DATASET_PATH, cls)))
        for cls in CLASS_NAMES
    }
    # Materialise as lists: the plotting calls expect sequences, not dict views.
    labels = list(class_counts)
    counts = list(class_counts.values())

    # Pie chart
    st.write("### Disease Distribution")
    fig, ax = plt.subplots()
    # Sample the colormap evenly across its range; taking the first N entries
    # of its 256-colour table would give N almost identical shades.
    n_classes = len(labels)
    pie_colors = plt.cm.viridis(
        [i / max(n_classes - 1, 1) for i in range(n_classes)]
    )
    ax.pie(counts, labels=labels, autopct='%1.1f%%', colors=pie_colors)
    st.pyplot(fig)
    # Figures made via pyplot stay registered until closed; Streamlit re-runs
    # the script on every interaction, so close each one after rendering.
    plt.close(fig)

    # Bar chart
    st.write("### Class Count")
    fig, ax = plt.subplots()
    sns.barplot(x=labels, y=counts, palette="viridis", ax=ax)
    ax.tick_params(axis="x", labelrotation=45)
    st.pyplot(fig)
    plt.close(fig)

# --- Model metrics page ---
elif page == "Model Metrics":
    st.title("📊 Model Performance")

    # Load true labels and predictions saved beforehand.
    y_true = torch.load("y_true.pth")
    y_pred = torch.load("y_pred.pth")

    # Accuracy
    accuracy = accuracy_score(y_true, y_pred)
    st.write(f"### ✅ Accuracy: {accuracy:.2f}")

    # Classification report, shown as a table with one row per class.
    st.write("### 📋 Classification Report")
    report = classification_report(
        y_true, y_pred, target_names=CLASS_NAMES, output_dict=True
    )
    st.write(pd.DataFrame(report).T)

    # Confusion matrix
    st.write("### 🔀 Confusion Matrix")
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots()
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=CLASS_NAMES,
        yticklabels=CLASS_NAMES,
        ax=ax,
    )
    st.pyplot(fig)
    plt.close(fig)

# --- Disease predictor page ---
elif page == "Disease Predictor":
    st.title("🌿 Plant Disease Classifier")

    # File upload
    uploaded_file = st.file_uploader(
        "Upload a plant leaf image", type=["jpg", "png", "jpeg"]
    )
    if uploaded_file is not None:
        # Force three channels: PNG uploads may be RGBA, palette or grayscale,
        # which the model's 3-channel input layer cannot accept.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Transform image
        # NOTE(review): size and normalisation presumably mirror the training
        # pipeline -- confirm against the training code.
        transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        # Add a leading batch dimension of size 1.
        image_tensor = transform(image).unsqueeze(0)

        # Predict disease; gradients are not needed for inference.
        with torch.no_grad():
            output = model(image_tensor)
            predicted_class = torch.argmax(output, dim=1).item()
        st.write(f"### ✅ Prediction: {CLASS_NAMES[predicted_class]}")