Lucky Sharma
Update app.py
54e24db verified
raw
history blame contribute delete
923 Bytes
import joblib
import numpy as np
from sklearn import datasets
import gradio as gr
import os
# Load the model and class names.
# A retrained artifact ('new_iris_model.pkl') takes precedence over the original
# 'iris_model.pkl' whenever it exists; both paths are relative to the current
# working directory. joblib.load unpickles the file, so only trusted artifacts
# should be placed here.
model = joblib.load(
    'new_iris_model.pkl' if os.path.exists('new_iris_model.pkl') else 'iris_model.pkl'
)
# Human-readable species names from scikit-learn's bundled Iris dataset, used to
# turn the model's predicted class into a label.
iris = datasets.load_iris()
class_names = iris.target_names
# Define prediction function
def predict_species(sepal_length, sepal_width, petal_length, petal_width):
    """Classify a single iris flower from its four measurements.

    Args:
        sepal_length: Sepal length as entered in the form.
        sepal_width: Sepal width as entered in the form.
        petal_length: Petal length as entered in the form.
        petal_width: Petal width as entered in the form.

    Returns:
        A string of the form ``"Iris <species>"``, where ``<species>`` is the
        entry of ``class_names`` selected by the model's prediction.

    Raises:
        ValueError: If any of the four measurements is missing (``None``).
    """
    measurements = (sepal_length, sepal_width, petal_length, petal_width)
    # A gr.Number field left empty can deliver None; without this check NumPy
    # would build an object-dtype array and the failure would surface as an
    # opaque error from inside model.predict.
    if any(value is None for value in measurements):
        raise ValueError(
            "All four measurements (sepal length/width, petal length/width) are required."
        )
    # One row, four columns: assumes a scikit-learn-style estimator expecting
    # input shaped (n_samples, n_features).
    features = np.array([measurements], dtype=float)
    prediction = model.predict(features)[0]
    return f"Iris {class_names[prediction]}"
# Create Gradio interface: four numeric inputs are passed positionally to
# predict_species, and its string result is shown in a text output.
demo = gr.Interface(
    fn=predict_species,
    inputs=[
        gr.Number(label="Sepal Length"),
        gr.Number(label="Sepal Width"),
        gr.Number(label="Petal Length"),
        gr.Number(label="Petal Width"),
    ],
    outputs="text",
    title="Iris Flower Classifier",
)

# Start the server only when this file is executed as a script (e.g.
# `python app.py`); importing the module just builds `demo` without launching
# a blocking server.
if __name__ == "__main__":
    demo.launch()