import gradio as gr
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import tensorflow as tf

# Load the trained classifier once at start-up so every request reuses it.
model = load_model('acres-ppdc-01.keras')

# Class names in the order of the model's output units: index i of the
# prediction vector corresponds to class_labels[i].
class_labels = ['Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy']

# Human-readable text shown in the UI for each raw class label.
_DISPLAY_NAMES = {
    'Potato___Early_blight': 'Early blight',
    'Potato___Late_blight': 'Late blight',
    'Potato___healthy': 'Healthy',
}

# (width, height) the image is resized to before inference; the original
# comment states this matches the size the model was trained on.
_INPUT_SIZE = (128, 128)


def _center_crop_square(img):
    """Return the largest centered square crop of a PIL image."""
    width, height = img.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))


def _preprocess(img):
    """Turn a PIL image into a single-image float batch scaled to [0, 1].

    With Keras' default channels_last data format the result has shape
    (1, 128, 128, 3).
    """
    # Force 3 channels: greyscale/palette uploads would otherwise break the
    # shape handling, and RGBA would feed the model a 4th channel.
    img = img.convert("RGB")
    # Crop in PIL directly. Going through numpy and image.array_to_img (whose
    # default scale=True min-max stretches pixels to 0..255) silently changed
    # the image's contrast before inference.
    img = _center_crop_square(img)
    img = img.resize(_INPUT_SIZE)
    arr = image.img_to_array(img) / 255.0
    return np.expand_dims(arr, axis=0)


def classify_potato_plant(img):
    """Classify a potato plant image into one of ``class_labels``.

    Args:
        img: PIL image supplied by the Gradio ``Image(type="pil")`` input,
            or ``None`` when the input has been cleared.

    Returns:
        ``(label, confidence)``: the human-readable predicted class and the
        model's score for that class as a plain Python float (presumably a
        softmax probability -- depends on the saved model's last layer).

    Raises:
        gr.Error: if no image was provided; Gradio shows the message to the user.
    """
    if img is None:
        raise gr.Error("Please upload an image of a potato plant.")

    batch = _preprocess(img)
    # verbose=0 stops Keras printing a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]
    predicted_class = int(np.argmax(scores))
    confidence = float(scores[predicted_class])

    label = class_labels[predicted_class]
    # "None" mirrors the original fallback for a label with no display name.
    return _DISPLAY_NAMES.get(label, "None"), confidence


# Build the web UI: one image in, predicted label and confidence out.
interface = gr.Interface(
    fn=classify_potato_plant,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Predicted Output"), gr.Textbox(label="Confidence Score")],
    title="Acres - PPDC",
    description=(
        "Acres PPDC, is our Potato Plant Disease Classification vision model, "
        "capable of accurately classifying potato plant disease, based on a "
        "single image."
    ),
)

# Launch the app
if __name__ == "__main__":
    interface.launch()