|
import torch |
|
import torchvision.transforms as transforms |
|
from torchvision import models |
|
from PIL import Image |
|
|
|
|
|
# Class labels, one per line. The list order must match the model's output
# indices: its length sizes the classifier head and predict_image() indexes it
# with the argmax of the logits.
# Read with an explicit encoding so behaviour does not depend on the platform
# default, and skip blank lines (e.g. a doubled trailing newline) so they don't
# become bogus empty classes that change the head size.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [line.strip() for line in f if line.strip()]
|
|
|
|
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _load_classifier(num_classes, checkpoint_path, map_location):
    """Build a MobileNetV2 with a ``num_classes``-way head and load its weights.

    The network is created without pretrained weights. Its final linear layer
    (``classifier[1]``) is swapped for one whose output size is ``num_classes``,
    and then the state dict stored at ``checkpoint_path`` is loaded into it.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained`` in
    # favour of ``weights=None``; kept as-is for older-version compatibility.
    net = models.mobilenet_v2(pretrained=False)
    old_head = net.classifier[1]
    net.classifier[1] = torch.nn.Linear(old_head.in_features, num_classes)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files
    # (consider weights_only=True where the installed PyTorch supports it).
    net.load_state_dict(torch.load(checkpoint_path, map_location=map_location))
    return net


# Inference-ready model: weights loaded, moved to ``device``, and put in eval
# mode (affects layers such as dropout and batch norm).
model = _load_classifier(len(class_names), "plant_disease_model.pth", device)
model = model.to(device)
model.eval()
|
|
|
|
|
# Spatial size every input image is resized to before inference.
_INPUT_SIZE = (128, 128)

# Preprocessing pipeline:
# 1. resize to _INPUT_SIZE;
# 2. convert the PIL image to a tensor;
# 3. normalize with mean 0.5 / std 0.5 (the one-element lists are applied to
#    every channel).
# NOTE(review): presumably this must mirror the training-time preprocessing.
# Confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
|
|
|
|
|
def predict_image(image_path):
    """Classify a single image and return its predicted class name.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a filename, a
            ``pathlib.Path``, or a binary file object).

    Returns:
        The entry of ``class_names`` at the index of the highest model logit.
    """
    # PIL opens files lazily and keeps the handle open until the image is
    # closed. Use a context manager so the handle is released promptly.
    # convert() loads the pixel data, so the result stays usable after the
    # file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")  # force 3 channels (grayscale, RGBA, ...)

    # Add a leading batch dimension of size 1 and move to the model's device.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(batch)

    predicted_class = torch.argmax(logits, dim=1).item()
    return class_names[predicted_class]
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Usage: python <script> [image_path]. The image path defaults to
    # "test_image.jpg" when no argument is given.
    sample_image = sys.argv[1] if len(sys.argv) > 1 else "test_image.jpg"

    prediction = predict_image(sample_image)
    print(f"Predicted Class: {prediction}")
|
|