Regino commited on
Commit
1773e4e
Β·
1 Parent(s): 138a538
Files changed (2) hide show
  1. splitdata.py +37 -0
  2. train.py +75 -0
splitdata.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import random
4
+
5
+ # βœ… Define paths
6
+ data_dir = "PlantVillage"
7
+ train_dir = "dataset/train"
8
+ test_dir = "dataset/test"
9
+ split_ratio = 0.8 # 80% train, 20% test
10
+
11
+ # βœ… Ensure train/test folders exist
12
+ os.makedirs(train_dir, exist_ok=True)
13
+ os.makedirs(test_dir, exist_ok=True)
14
+
15
+ # βœ… Split dataset
16
+ for category in os.listdir(data_dir): # Loop through plant disease categories
17
+ category_path = os.path.join(data_dir, category)
18
+
19
+ if os.path.isdir(category_path): # Ensure it's a folder
20
+ images = os.listdir(category_path)
21
+ random.shuffle(images) # Shuffle to ensure randomness
22
+
23
+ split_index = int(len(images) * split_ratio)
24
+ train_images = images[:split_index]
25
+ test_images = images[split_index:]
26
+
27
+ # βœ… Create category folders
28
+ os.makedirs(os.path.join(train_dir, category), exist_ok=True)
29
+ os.makedirs(os.path.join(test_dir, category), exist_ok=True)
30
+
31
+ # βœ… Move images
32
+ for img in train_images:
33
+ shutil.move(os.path.join(category_path, img), os.path.join(train_dir, category, img))
34
+ for img in test_images:
35
+ shutil.move(os.path.join(category_path, img), os.path.join(test_dir, category, img))
36
+
37
+ print("βœ… Dataset successfully split into train/test!")
train.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision.transforms as transforms
6
+ import torchvision.datasets as datasets
7
+ from torch.utils.data import DataLoader
8
+ from torchvision import models
9
+ from tqdm import tqdm # βœ… Progress bar
10
+
11
+ # βœ… Define dataset paths
12
+ train_dir = "dataset/train"
13
+ test_dir = "dataset/test"
14
+
15
+ # βœ… Optimized Transformations (Smaller image size)
16
+ transform = transforms.Compose([
17
+ transforms.Resize((128, 128)), # Reduce size for faster training
18
+ transforms.RandomHorizontalFlip(),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.5], [0.5])
21
+ ])
22
+
23
+ # βœ… Load datasets
24
+ train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
25
+ test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
26
+
27
+ # βœ… Get class names
28
+ class_names = train_dataset.classes
29
+ print(f"Class Names: {class_names}")
30
+
31
+ # βœ… Save class names for later use in `app.py`
32
+ with open("class_names.txt", "w") as f:
33
+ for name in class_names:
34
+ f.write(name + "\n")
35
+
36
+ # βœ… Optimized DataLoaders (Smaller batch size)
37
+ batch_size = 16 # Reduce batch size for speed
38
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
39
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
40
+
41
+ # βœ… Use a Faster Model (MobileNetV2)
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ model = models.mobilenet_v2(pretrained=True)
44
+ model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(class_names))
45
+ model = model.to(device)
46
+
47
+ # βœ… Define Loss Function & Optimizer
48
+ criterion = nn.CrossEntropyLoss()
49
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
50
+
51
+ # βœ… Train the Model with Progress Bar
52
+ num_epochs = 3 # Reduce epochs for faster training
53
+ for epoch in range(num_epochs):
54
+ model.train()
55
+ running_loss = 0.0
56
+ progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
57
+
58
+ for images, labels in progress_bar:
59
+ images, labels = images.to(device), labels.to(device)
60
+
61
+ optimizer.zero_grad()
62
+ outputs = model(images)
63
+ loss = criterion(outputs, labels)
64
+ loss.backward()
65
+ optimizer.step()
66
+
67
+ running_loss += loss.item()
68
+ progress_bar.set_postfix(loss=f"{running_loss/len(train_loader):.4f}")
69
+
70
+ print(f"βœ… Epoch [{epoch+1}/{num_epochs}] - Loss: {running_loss/len(train_loader):.4f}")
71
+
72
+ # βœ… Save the Trained Model
73
+ torch.save(model.state_dict(), "plant_disease_model.pth")
74
+ print("βœ… Model training complete and saved as plant_disease_model.pth")
75
+