Regino
commited on
Commit
Β·
1773e4e
1
Parent(s):
138a538
shdvfsdj
Browse files- splitdata.py +37 -0
- train.py +75 -0
splitdata.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import random
|
4 |
+
|
5 |
+
# β
Define paths
|
6 |
+
data_dir = "PlantVillage"
|
7 |
+
train_dir = "dataset/train"
|
8 |
+
test_dir = "dataset/test"
|
9 |
+
split_ratio = 0.8 # 80% train, 20% test
|
10 |
+
|
11 |
+
# β
Ensure train/test folders exist
|
12 |
+
os.makedirs(train_dir, exist_ok=True)
|
13 |
+
os.makedirs(test_dir, exist_ok=True)
|
14 |
+
|
15 |
+
# β
Split dataset
|
16 |
+
for category in os.listdir(data_dir): # Loop through plant disease categories
|
17 |
+
category_path = os.path.join(data_dir, category)
|
18 |
+
|
19 |
+
if os.path.isdir(category_path): # Ensure it's a folder
|
20 |
+
images = os.listdir(category_path)
|
21 |
+
random.shuffle(images) # Shuffle to ensure randomness
|
22 |
+
|
23 |
+
split_index = int(len(images) * split_ratio)
|
24 |
+
train_images = images[:split_index]
|
25 |
+
test_images = images[split_index:]
|
26 |
+
|
27 |
+
# β
Create category folders
|
28 |
+
os.makedirs(os.path.join(train_dir, category), exist_ok=True)
|
29 |
+
os.makedirs(os.path.join(test_dir, category), exist_ok=True)
|
30 |
+
|
31 |
+
# β
Move images
|
32 |
+
for img in train_images:
|
33 |
+
shutil.move(os.path.join(category_path, img), os.path.join(train_dir, category, img))
|
34 |
+
for img in test_images:
|
35 |
+
shutil.move(os.path.join(category_path, img), os.path.join(test_dir, category, img))
|
36 |
+
|
37 |
+
print("β
Dataset successfully split into train/test!")
|
train.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.optim as optim
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
import torchvision.datasets as datasets
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torchvision import models
|
9 |
+
from tqdm import tqdm # β
Progress bar
|
10 |
+
|
11 |
+
# β
Define dataset paths
|
12 |
+
train_dir = "dataset/train"
|
13 |
+
test_dir = "dataset/test"
|
14 |
+
|
15 |
+
# β
Optimized Transformations (Smaller image size)
|
16 |
+
transform = transforms.Compose([
|
17 |
+
transforms.Resize((128, 128)), # Reduce size for faster training
|
18 |
+
transforms.RandomHorizontalFlip(),
|
19 |
+
transforms.ToTensor(),
|
20 |
+
transforms.Normalize([0.5], [0.5])
|
21 |
+
])
|
22 |
+
|
23 |
+
# β
Load datasets
|
24 |
+
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
|
25 |
+
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
|
26 |
+
|
27 |
+
# β
Get class names
|
28 |
+
class_names = train_dataset.classes
|
29 |
+
print(f"Class Names: {class_names}")
|
30 |
+
|
31 |
+
# β
Save class names for later use in `app.py`
|
32 |
+
with open("class_names.txt", "w") as f:
|
33 |
+
for name in class_names:
|
34 |
+
f.write(name + "\n")
|
35 |
+
|
36 |
+
# β
Optimized DataLoaders (Smaller batch size)
|
37 |
+
batch_size = 16 # Reduce batch size for speed
|
38 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
39 |
+
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
40 |
+
|
41 |
+
# β
Use a Faster Model (MobileNetV2)
|
42 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
43 |
+
model = models.mobilenet_v2(pretrained=True)
|
44 |
+
model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(class_names))
|
45 |
+
model = model.to(device)
|
46 |
+
|
47 |
+
# β
Define Loss Function & Optimizer
|
48 |
+
criterion = nn.CrossEntropyLoss()
|
49 |
+
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
50 |
+
|
51 |
+
# β
Train the Model with Progress Bar
|
52 |
+
num_epochs = 3 # Reduce epochs for faster training
|
53 |
+
for epoch in range(num_epochs):
|
54 |
+
model.train()
|
55 |
+
running_loss = 0.0
|
56 |
+
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
|
57 |
+
|
58 |
+
for images, labels in progress_bar:
|
59 |
+
images, labels = images.to(device), labels.to(device)
|
60 |
+
|
61 |
+
optimizer.zero_grad()
|
62 |
+
outputs = model(images)
|
63 |
+
loss = criterion(outputs, labels)
|
64 |
+
loss.backward()
|
65 |
+
optimizer.step()
|
66 |
+
|
67 |
+
running_loss += loss.item()
|
68 |
+
progress_bar.set_postfix(loss=f"{running_loss/len(train_loader):.4f}")
|
69 |
+
|
70 |
+
print(f"β
Epoch [{epoch+1}/{num_epochs}] - Loss: {running_loss/len(train_loader):.4f}")
|
71 |
+
|
72 |
+
# β
Save the Trained Model
|
73 |
+
torch.save(model.state_dict(), "plant_disease_model.pth")
|
74 |
+
print("β
Model training complete and saved as plant_disease_model.pth")
|
75 |
+
|