# Tumor_Detection / app.py
# DHEIVER's picture
# Update app.py
# e1af9e0 verified
# raw
# history blame
# 2.64 kB
from PIL import Image, ImageOps
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from keras.models import load_model
import gradio as gr
# Load the model and class names once at import time, outside the prediction
# function, so every request reuses them.
model = load_model('keras_model.h5', compile=False)
# One class name per line of labels.txt; the list index must match the model's
# output index. The `with` block closes the file instead of leaking the handle.
with open('labels.txt', 'r', encoding='utf-8') as labels_file:
    class_names = [line.strip() for line in labels_file]
def create_plot(data):
    """Draw a horizontal bar chart of per-label confidence scores.

    Args:
        data: A DataFrame with "Labels", "Total" and "Confidence Score"
            columns, as built by ``predict_tumor``.

    Returns:
        A matplotlib ``Figure``. Each label gets a pastel "Total" bar with a
        muted "Confidence Score" bar drawn over it.
    """
    # Build the figure without pyplot so it is not added to pyplot's global
    # figure registry. Figures from plt.subplots() that are never closed pile up
    # with every request in a long-running server.
    from matplotlib.figure import Figure

    sns.set_theme(style="whitegrid")
    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()

    # Background bars: the full "Total" value for each label.
    sns.set_color_codes("pastel")
    sns.barplot(x="Total", y="Labels", data=data, label="Total", color="b", ax=ax)

    # Foreground bars: the confidence score, drawn on top of the totals.
    sns.set_color_codes("muted")
    sns.barplot(
        x="Confidence Score",
        y="Labels",
        data=data,
        label="Confidence Score",
        color="b",
        ax=ax,
    )

    ax.legend(ncol=2, loc="lower right", frameon=True)
    sns.despine(ax=ax, left=True, bottom=True)
    return fig
def predict_tumor(img):
    """Classify an image with the loaded model and build a confidence plot.

    Args:
        img: The image from the Gradio ``gr.Image`` input. It is passed to
            ``Image.fromarray``, so it is expected to be a NumPy pixel array.
            It is ``None`` when the user submits without an image.

    Returns:
        A ``(message, figure)`` tuple:
            - message: "Model detected Tumor" when the top class is "Yes",
              otherwise "Model did not detect Tumor".
            - figure: the bar chart returned by ``create_plot``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Without this guard, Image.fromarray(None) fails with an opaque error.
        raise gr.Error("Please upload an image before submitting.")

    # Resize and center-crop to the model's 224x224 input. convert("RGB")
    # guarantees three channels, so grayscale or RGBA uploads cannot break the
    # (1, 224, 224, 3) batch shape.
    size = (224, 224)
    image = ImageOps.fit(Image.fromarray(img).convert("RGB"), size, Image.LANCZOS)
    image_array = np.asarray(image)

    # Map pixels from [0, 255] to roughly [-1, 1].
    # NOTE(review): dividing by 127.5 would map exactly onto [-1, 1]. 127.0 is
    # kept so model inputs are unchanged; confirm against the training preprocessing.
    normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
    data = np.expand_dims(normalized_image_array, axis=0)  # batch of one

    # verbose=0 stops Keras from printing a progress bar on every request.
    probabilities = model.predict(data, verbose=0)[0]
    index = int(np.argmax(probabilities))
    c_name = class_names[index].strip()
    confidence_score = probabilities[index]

    tumor_prediction = f"Model {'detected' if c_name == 'Yes' else 'did not detect'} Tumor"
    # This assumes a two-class "Yes"/"No" model: the other bar is the complement
    # of the winning class's probability.
    other_class = 'No' if c_name == 'Yes' else 'Yes'

    # Prepare data for plotting. The scalar "Total" is broadcast to every row.
    res = {
        "Labels": [c_name, other_class],
        "Confidence Score": [confidence_score * 100, (1 - confidence_score) * 100],
        "Total": 100,
    }
    data_for_plot = pd.DataFrame.from_dict(res)
    tumor_conf_plt = create_plot(data_for_plot)
    return tumor_prediction, tumor_conf_plt
# Gradio interface. The image input sits in a wide left column. The result
# text, confidence plot and submit button sit in a narrow right column, with
# clickable examples below.
# NOTE(review): the original indentation was lost. This nesting is reconstructed
# from the statement order; confirm it matches the intended layout.
with gr.Blocks(title="Brain Tumor Detection | Data Science Dojo", css="styles.css") as demo:
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Row():
                imgInput = gr.Image()
        with gr.Column(scale=1):
            tumor = gr.Textbox(label='Presence of Tumor')
            plot = gr.Plot(label="Plot")
            submit_button = gr.Button(value="Submit")
    submit_button.click(fn=predict_tumor, inputs=[imgInput], outputs=[tumor, plot])
    # Example images wired to the same prediction function and outputs.
    gr.Examples(
        examples=["pred2.jpg", "pred3.jpg"],
        inputs=imgInput,
        outputs=[tumor, plot],
        fn=predict_tumor,
        cache_examples=True,
    )

# Start the server only when run as a script. Importing the module (for example
# via Gradio's reload mode, which looks up `demo`) then does not launch a
# second server.
if __name__ == "__main__":
    demo.launch()