Regino commited on
Commit
2763fea
Β·
2 Parent(s): 62d9225 7db6979

jsdbjksdbvf

Browse files
Files changed (1) hide show
  1. app.py +68 -50
app.py CHANGED
@@ -1,24 +1,33 @@
1
  import streamlit as st
2
  import os
 
3
  import torch
4
  import torch.nn as nn
5
  import torchvision.transforms as transforms
6
  import torchvision.models as models
7
  import matplotlib.pyplot as plt
8
  import seaborn as sns
 
9
  from PIL import Image
10
  from torchvision import datasets
11
  from torch.utils.data import DataLoader
12
  from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
13
  import random
14
 
15
- # βœ… Sidebar Navigation
16
- st.sidebar.title("Navigation")
17
- page = st.sidebar.radio("Go to", ["Dataset", "Visualizations", "Model Metrics", "Disease Predictor"])
 
 
 
 
 
18
 
19
- # βœ… Dataset Path
20
- DATASET_PATH = "train" # Update if needed
21
- CLASS_NAMES = os.listdir(DATASET_PATH) # Get class names from folder structure
 
 
22
 
23
  # βœ… Load Model
24
  @st.cache_resource
@@ -31,71 +40,80 @@ def load_model():
31
 
32
  model = load_model()
33
 
 
 
 
 
34
  # βœ… Dataset Page – Show Sample Images
35
  if page == "Dataset":
36
  st.title("πŸ“Š Dataset Preview")
37
  st.write(f"### Classes: {CLASS_NAMES}")
38
 
39
- # Show sample images from each class
40
- cols = st.columns(4)
41
- for i, class_name in enumerate(CLASS_NAMES[:4]): # Show 4 classes
42
- class_path = os.path.join(DATASET_PATH, class_name)
43
- image_name = random.choice(os.listdir(class_path))
44
- image_path = os.path.join(class_path, image_name)
45
- image = Image.open(image_path)
46
- cols[i].image(image, caption=class_name, use_column_width=True)
 
47
 
48
  # βœ… Visualizations Page – Show Class Distribution
49
  elif page == "Visualizations":
50
  st.title("πŸ“ˆ Dataset Visualizations")
51
 
52
- # Count images per class
53
- class_counts = {cls: len(os.listdir(os.path.join(DATASET_PATH, cls))) for cls in CLASS_NAMES}
54
-
55
- # Pie Chart
56
- st.write("### Disease Distribution")
57
- fig, ax = plt.subplots()
58
- ax.pie(class_counts.values(), labels=class_counts.keys(), autopct='%1.1f%%', colors=plt.cm.viridis.colors)
59
- st.pyplot(fig)
60
-
61
- # Bar Chart
62
- st.write("### Class Count")
63
- fig, ax = plt.subplots()
64
- sns.barplot(x=list(class_counts.keys()), y=list(class_counts.values()), palette="viridis")
65
- plt.xticks(rotation=45)
66
- st.pyplot(fig)
 
67
 
68
  # βœ… Model Metrics Page
69
  elif page == "Model Metrics":
70
  st.title("πŸ“Š Model Performance")
71
 
72
- # Load True Labels and Predictions
73
- y_true = torch.load("y_true.pth")
74
- y_pred = torch.load("y_pred.pth")
75
-
76
- # Accuracy
77
- accuracy = accuracy_score(y_true, y_pred)
78
- st.write(f"### βœ… Accuracy: {accuracy:.2f}")
79
-
80
- # Classification Report
81
- st.write("### πŸ“‹ Classification Report")
82
- report = classification_report(y_true, y_pred, target_names=CLASS_NAMES, output_dict=True)
83
- st.write(pd.DataFrame(report).T)
84
-
85
- # Confusion Matrix
86
- st.write("### πŸ”€ Confusion Matrix")
87
- cm = confusion_matrix(y_true, y_pred)
88
- fig, ax = plt.subplots()
89
- sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
90
- st.pyplot(fig)
 
 
 
91
 
92
  # βœ… Disease Predictor Page
93
  elif page == "Disease Predictor":
94
  st.title("🌿 Plant Disease Classifier")
95
-
96
  # File Upload
97
  uploaded_file = st.file_uploader("Upload a plant leaf image", type=["jpg", "png", "jpeg"])
98
-
99
  if uploaded_file is not None:
100
  image = Image.open(uploaded_file)
101
  st.image(image, caption="Uploaded Image", use_column_width=True)
 
1
  import streamlit as st
2
  import os
3
+ import zipfile
4
  import torch
5
  import torch.nn as nn
6
  import torchvision.transforms as transforms
7
  import torchvision.models as models
8
  import matplotlib.pyplot as plt
9
  import seaborn as sns
10
+ import pandas as pd
11
  from PIL import Image
12
  from torchvision import datasets
13
  from torch.utils.data import DataLoader
14
  from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
15
  import random
16
 
17
+ # βœ… Automatically unzip `train.zip` if `train/` folder is missing
18
+ DATASET_PATH = "train"
19
+ ZIP_FILE = "train.zip"
20
+
21
+ if not os.path.exists(DATASET_PATH):
22
+ if os.path.exists(ZIP_FILE):
23
+ with zipfile.ZipFile(ZIP_FILE, 'r') as zip_ref:
24
+ zip_ref.extractall(".") # Extract to current directory
25
 
26
+ # βœ… Load Class Names from Dataset
27
+ if os.path.exists(DATASET_PATH):
28
+ CLASS_NAMES = sorted(os.listdir(DATASET_PATH))
29
+ else:
30
+ CLASS_NAMES = []
31
 
32
  # βœ… Load Model
33
  @st.cache_resource
 
40
 
41
  model = load_model()
42
 
43
+ # βœ… Sidebar Navigation
44
+ st.sidebar.title("Navigation")
45
+ page = st.sidebar.radio("Go to", ["Dataset", "Visualizations", "Model Metrics", "Disease Predictor"])
46
+
47
  # βœ… Dataset Page – Show Sample Images
48
  if page == "Dataset":
49
  st.title("πŸ“Š Dataset Preview")
50
  st.write(f"### Classes: {CLASS_NAMES}")
51
 
52
+ if CLASS_NAMES:
53
+ cols = st.columns(4)
54
+ for i, class_name in enumerate(CLASS_NAMES[:4]): # Show 4 classes
55
+ class_path = os.path.join(DATASET_PATH, class_name)
56
+ if os.path.exists(class_path):
57
+ image_name = random.choice(os.listdir(class_path))
58
+ image_path = os.path.join(class_path, image_name)
59
+ image = Image.open(image_path)
60
+ cols[i].image(image, caption=class_name, use_column_width=True)
61
 
62
  # βœ… Visualizations Page – Show Class Distribution
63
  elif page == "Visualizations":
64
  st.title("πŸ“ˆ Dataset Visualizations")
65
 
66
+ if CLASS_NAMES:
67
+ class_counts = {cls: len(os.listdir(os.path.join(DATASET_PATH, cls))) for cls in CLASS_NAMES}
68
+
69
+ # Pie Chart with Proper Colors
70
+ st.write("### Disease Distribution")
71
+ fig, ax = plt.subplots()
72
+ colors = sns.color_palette("husl", len(CLASS_NAMES)) # Generate unique colors
73
+ ax.pie(class_counts.values(), labels=class_counts.keys(), autopct='%1.1f%%', colors=colors)
74
+ st.pyplot(fig)
75
+
76
+ # Bar Chart
77
+ st.write("### Class Count")
78
+ fig, ax = plt.subplots()
79
+ sns.barplot(x=list(class_counts.keys()), y=list(class_counts.values()), palette="husl")
80
+ plt.xticks(rotation=45)
81
+ st.pyplot(fig)
82
 
83
  # βœ… Model Metrics Page
84
  elif page == "Model Metrics":
85
  st.title("πŸ“Š Model Performance")
86
 
87
+ try:
88
+ # Load True Labels and Predictions
89
+ y_true = torch.load("y_true.pth")
90
+ y_pred = torch.load("y_pred.pth")
91
+
92
+ # Accuracy
93
+ accuracy = accuracy_score(y_true, y_pred)
94
+ st.write(f"### βœ… Accuracy: {accuracy:.2f}")
95
+
96
+ # Classification Report
97
+ st.write("### πŸ“‹ Classification Report")
98
+ report = classification_report(y_true, y_pred, target_names=CLASS_NAMES, output_dict=True)
99
+ st.write(pd.DataFrame(report).T)
100
+
101
+ # Confusion Matrix
102
+ st.write("### πŸ”€ Confusion Matrix")
103
+ cm = confusion_matrix(y_true, y_pred)
104
+ fig, ax = plt.subplots()
105
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
106
+ st.pyplot(fig)
107
+ except:
108
+ st.error("🚨 Model metrics files (`y_true.pth` and `y_pred.pth`) not found!")
109
 
110
  # βœ… Disease Predictor Page
111
  elif page == "Disease Predictor":
112
  st.title("🌿 Plant Disease Classifier")
113
+
114
  # File Upload
115
  uploaded_file = st.file_uploader("Upload a plant leaf image", type=["jpg", "png", "jpeg"])
116
+
117
  if uploaded_file is not None:
118
  image = Image.open(uploaded_file)
119
  st.image(image, caption="Uploaded Image", use_column_width=True)