shaheer-data's picture
Update app.py
0953fe9 verified
raw
history blame
2.72 kB
import streamlit as st
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from PIL import Image
import requests
import tempfile
import os
def download_model_from_gdrive():
    """Download the model file from Google Drive and return its local path.

    The response is streamed into a temporary file created next to the final
    destination and only renamed to ``final_meta_model.keras`` once the whole
    body has been written. An interrupted or failed download therefore never
    leaves a truncated file at the final path, where it would later be
    mistaken for a valid local model.

    Returns:
        The path of the downloaded model file (``"final_meta_model.keras"``).

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-2xx HTTP status.
        OSError: If the file cannot be written or moved into place.
    """
    gdrive_url = "https://drive.google.com/uc?id=1EnokggrC6ymrSibtj2t7IHWb9QVtrehS"
    output_path = "final_meta_model.keras"

    # Create the temp file in the destination directory so that os.replace()
    # is a same-filesystem rename rather than a cross-device copy.
    target_dir = os.path.dirname(os.path.abspath(output_path))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f:
            # (connect timeout, read timeout) in seconds: without these a
            # stalled connection would block the app indefinitely.
            with requests.get(gdrive_url, stream=True, timeout=(10, 60)) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_path, output_path)
    finally:
        # After a successful replace the temp path no longer exists; on any
        # failure this removes the partial download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    # NOTE(review): this does not verify that the payload is actually a Keras
    # model (e.g. rather than an HTML page served by the host) - confirm
    # whether a content check is needed for this file.
    return output_path
def load_or_download_model():
    """Return the loaded model, downloading the file first if it is missing.

    If ``final_meta_model.keras`` exists in the working directory it is
    loaded directly. Otherwise a notice is written to the Streamlit page, the
    file is fetched via ``download_model_from_gdrive()``, and the downloaded
    file is loaded.
    """
    hf_model_path = "final_meta_model.keras"

    # Fallback path: no local copy (e.g. in the Hugging Face Space storage).
    if not os.path.exists(hf_model_path):
        st.write("Model not found locally. Downloading from Google Drive...")
        return load_model(download_model_from_gdrive())

    return load_model(hf_model_path)
def preprocess_image(image):
    """Convert an uploaded PIL image into a model-ready batch array.

    Args:
        image: A PIL image in any mode (RGB, RGBA, grayscale, palette, ...).

    Returns:
        A float array of shape ``(1, 224, 224, 3)`` with pixel values scaled
        to the range [0, 1].
    """
    # PNG uploads are often RGBA or palette-based and JPEGs may be grayscale;
    # without this conversion the array would have 4, 1 or no channel axis
    # instead of the 3 channels produced for RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((224, 224))  # Assuming input size of 224x224 for the model
    image_array = np.array(image)
    image_array = image_array / 255.0  # Normalize pixel values to [0, 1]
    image_array = np.expand_dims(image_array, axis=0)  # Add batch dimension
    return image_array
def predict_severity(model, image):
    """Run the model on a preprocessed image and pick the top-scoring class.

    Args:
        model: An object exposing ``predict(image)``.
        image: The preprocessed input passed straight to ``model.predict``.

    Returns:
        A ``(predicted_class, predictions)`` tuple: the label whose position
        matches the arg-max of the raw prediction output, and that raw output.
    """
    severity_labels = ['0', 'MR', 'MRMS', 'MS', 'R', 'S']
    predictions = model.predict(image)
    # np.argmax works on the flattened output, so this indexes the label
    # list directly for a single-image batch.
    best_index = np.argmax(predictions)
    return severity_labels[best_index], predictions
# ---- Streamlit page: header and instructions ----
st.title("Disease Severity Prediction App")
st.write("Upload an image to predict the severity of the disease.")

# ---- File picker (the rest of the page only runs once a file is chosen) ----
uploaded_image = st.file_uploader(
    "Upload an image",
    type=["jpg", "jpeg", "png"],
)

if uploaded_image:
    # Echo the upload back to the user.
    st.image(
        uploaded_image,
        caption="Uploaded Image",
        use_column_width=True,
    )

    # Obtain the model (loaded from disk, or downloaded first if missing).
    with st.spinner("Loading model..."):
        model = load_or_download_model()

    # Decode the upload and turn it into the model's input batch.
    image = Image.open(uploaded_image)
    preprocessed_image = preprocess_image(image)

    # Run inference.
    with st.spinner("Predicting severity..."):
        predicted_class, prediction_scores = predict_severity(
            model, preprocessed_image
        )

    # Report the winning class, then one line per class score.
    st.success(f"Predicted Class: {predicted_class}")
    st.write("Prediction Scores:")
    score_labels = ['0', 'MR', 'MRMS', 'MS', 'R', 'S']
    for class_name, score in zip(score_labels, prediction_scores[0]):
        st.write(f"{class_name}: {score:.4f}")