# Spaces:
# Sleeping
# Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import os | |
import joblib | |
from sklearn.model_selection import train_test_split | |
from sklearn.preprocessing import StandardScaler | |
from tensorflow import keras | |
# -------------------------
# Streamlit page setup and sidebar navigation
# -------------------------
# NOTE(review): the symbols prefixing the labels below look like mis-encoded
# emoji. They are kept verbatim because the page sections further down compare
# `page` against these exact strings.
st.set_page_config(layout="wide", page_title="Fertility Prediction")
st.title("๐งฌ Fertility Health Prediction App")

# The label picked in the sidebar decides which page section is rendered.
nav_options = ["๐ EDA", "๐ค Model Training", "๐ฎ Prediction"]
page = st.sidebar.radio("๐ Navigate", nav_options)
# -------------------------
# Load Data
# -------------------------
@st.cache_data
def load_data() -> pd.DataFrame:
    """Read the fertility dataset and return it without duplicate rows.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_data``: the CSV is parsed and
    de-duplicated once instead of on every rerun.

    Returns:
        The contents of ``fertility_synthetic_50000.csv`` with exact
        duplicate rows dropped (original index labels are kept).

    Raises:
        FileNotFoundError: Propagated from ``pd.read_csv`` when the CSV is
            not present in the working directory.
    """
    # Non-inplace drop keeps the function free of mutation on a cached object.
    return pd.read_csv("fertility_synthetic_50000.csv").drop_duplicates()


df = load_data()
# -------------------------
# EDA Page
# -------------------------
# Renders a dataset overview, missing-value counts and two plots for `df`.
if page == "๐ EDA":
    st.header("๐ Exploratory Data Analysis")

    st.subheader("๐ Dataset Overview")
    st.write(f"๐๏ธ Shape of dataset: {df.shape}")
    col1, col2 = st.columns(2)
    with col1:
        st.write("๐ First 5 rows:")
        st.dataframe(df.head())
    with col2:
        st.write("๐ Basic statistics:")
        st.dataframe(df.describe())

    st.subheader("โ Missing Values")
    st.write(df.isna().sum())

    st.subheader("๐ Data Visualization")

    # Target vs Sperm Count: mean sperm count for each target class.
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(
        data=df,
        x='Target_HealthyOffspring',
        y='Male_SpermCount_million_per_mL',
        estimator=np.mean,
        ax=ax,
        palette="viridis",
    )
    ax.set_title('๐ฏ Target vs Male Sperm Count (mean)')
    st.pyplot(fig)
    # pyplot keeps every figure registered until it is closed; without this
    # each script rerun would leave another figure in memory.
    plt.close(fig)

    # Correlation heatmap over the numeric columns only.
    st.write("๐ก๏ธ Correlation Heatmap:")
    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(df.corr(numeric_only=True), annot=False, cmap="coolwarm", ax=ax)
    st.pyplot(fig)
    plt.close(fig)
# -------------------------
# Model Training Page
# -------------------------
# `page` holds exactly one sidebar label, so this check is independent of the
# other page sections.
if page == "๐ค Model Training":
    st.header("โ๏ธ Model Training")

    # Everything expensive (split, scaling, model build, fit) and every write
    # to disk happens only after the button is pressed. Streamlit reruns the
    # script on each interaction, so doing this work unconditionally would
    # repeat it on every page render and could leave a scaler on disk without
    # a matching model.
    if st.button("๐ Train Model"):
        with st.spinner("โณ Training in progress..."):
            # Prepare data: all columns except the target are used as features.
            # NOTE(review): assumes every remaining column is numeric — confirm
            # against the dataset.
            X = df.drop("Target_HealthyOffspring", axis=1)
            y = df["Target_HealthyOffspring"]
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=29
            )

            # Scale data: fit on the training split only, then reuse the same
            # statistics for the test split.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            # Model architecture: small dense network with a 2-unit softmax
            # head, trained with integer class labels.
            model = keras.Sequential([
                keras.layers.Input(shape=(X_train.shape[1],)),
                keras.layers.Dense(7, activation="relu"),
                keras.layers.Dense(5, activation="relu"),
                keras.layers.Dense(4, activation="relu"),
                keras.layers.Dense(2, activation="softmax"),
            ])
            model.compile(
                loss="sparse_categorical_crossentropy",
                optimizer="adam",
                metrics=["accuracy"],
            )

            # Train model
            history = model.fit(
                X_train_scaled, y_train, epochs=10, validation_split=0.2, verbose=1
            )

            # Save scaler and model together so the files on disk always
            # belong to the same training run.
            os.makedirs("models", exist_ok=True)
            joblib.dump(scaler, "models/fertility_scaler.pkl")
            model.save("models/fertility_model.h5")

        st.success("โ Model trained and saved successfully!")

        # Plot training history
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        ax1.plot(history.history["accuracy"], label="Training Accuracy", color="green")
        ax1.plot(history.history["val_accuracy"], label="Validation Accuracy", color="blue")
        ax1.set_title("๐ Accuracy")
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Accuracy")
        ax1.legend()
        ax2.plot(history.history["loss"], label="Training Loss", color="red")
        ax2.plot(history.history["val_loss"], label="Validation Loss", color="orange")
        ax2.set_title("๐ Loss")
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("Loss")
        ax2.legend()
        st.pyplot(fig)
        # Release the figure; pyplot otherwise keeps it alive across reruns.
        plt.close(fig)

        # Evaluate on test set
        test_loss, test_acc = model.evaluate(X_test_scaled, y_test, verbose=0)
        st.metric("๐งช Test Accuracy", f"{test_acc:.4f}")
        st.metric("๐ Test Loss", f"{test_loss:.4f}")
# -------------------------
# Prediction Page
# -------------------------
# `page` holds exactly one sidebar label, so this check is independent of the
# other page sections.
if page == "๐ฎ Prediction":
    st.header("๐ฎ Make a Prediction")

    # Load the persisted artifacts. Only the two load calls sit in the `try`
    # so that unrelated failures are not reported as a missing model.
    try:
        model = keras.models.load_model("models/fertility_model.h5")
        scaler = joblib.load("models/fertility_scaler.pkl")
    except Exception as exc:  # UI boundary: report the problem and halt the page
        st.error("โ Model not found. Please train it first under 'Model Training'.")
        # Surface the real cause (missing file, version mismatch, ...) instead
        # of hiding it behind the generic message above.
        st.caption(f"Details: {type(exc).__name__}: {exc}")
        st.stop()
    st.success("๐ Model & Scaler loaded successfully!")

    # Create input form
    with st.form("prediction_form"):
        st.subheader("๐งพ Enter Patient Details")
        col1, col2 = st.columns(2)
        with col1:
            st.markdown("**๐จ Male Factors**")
            male_sperm_count = st.number_input("Sperm Count (million/mL)", min_value=0.0, value=15.0)
            male_sperm_motility = st.number_input("Sperm Motility (%)", min_value=0.0, max_value=100.0, value=40.0)
            male_sperm_morphology = st.number_input("Sperm Morphology (%)", min_value=0.0, max_value=100.0, value=4.0)
            male_testosterone = st.number_input("Testosterone (ng/dL)", min_value=0.0, value=300.0)
            male_fsh = st.number_input("Male FSH (mIU/mL)", min_value=0.0, value=1.5)
        with col2:
            st.markdown("**๐ฉ Female Factors**")
            female_age = st.number_input("Female Age (years)", min_value=18, max_value=50, value=30)
            female_ovulation = st.number_input("Ovulation Regularity (days)", min_value=0, value=28)
            female_estradiol = st.number_input("Estradiol (pg/mL)", min_value=0.0, value=20.0)
            female_progesterone = st.number_input("Progesterone (ng/mL)", min_value=0.0, value=10.0)
            female_fsh = st.number_input("Female FSH (mIU/mL)", min_value=0.0, value=3.0)

        st.markdown("**๐ Lifestyle Factors**")
        col3, col4 = st.columns(2)
        with col3:
            intercourse_freq = st.number_input("Intercourse Frequency (per week)", min_value=0, value=2)
            folic_acid = st.number_input("Folic Acid Intake (mcg/day)", min_value=0, value=400)
        with col4:
            smoking = st.number_input("Cigarettes per day", min_value=0, value=0)
            alcohol = st.number_input("Alcoholic drinks per week", min_value=0, value=0)
            hba1c = st.number_input("HbA1c (%)", min_value=0.0, max_value=20.0, value=5.0)

        submitted = st.form_submit_button("โจ Predict")

    if submitted:
        # Single-row feature matrix.
        # NOTE(review): the order here presumably has to match the feature
        # column order the scaler and model were fitted on — confirm against
        # the training data.
        input_data = np.array([[male_sperm_count, male_sperm_motility, male_sperm_morphology,
                                male_testosterone, male_fsh, female_age, female_ovulation,
                                female_estradiol, female_progesterone, female_fsh,
                                intercourse_freq, folic_acid, smoking, alcohol, hba1c]])
        scaled_input = scaler.transform(input_data)
        prediction = model.predict(scaled_input)

        # Index of the larger softmax output is the predicted class; its value
        # (as a percentage) is shown as the confidence.
        predicted_class = np.argmax(prediction, axis=1)
        confidence = np.max(prediction) * 100

        st.subheader("๐ Prediction Results")
        if predicted_class[0] == 1:
            st.success(f"โ Likely to have healthy offspring (Confidence: {confidence:.2f}%)")
        else:
            st.error(f"โ Unlikely to have healthy offspring (Confidence: {confidence:.2f}%)")
        st.progress(int(confidence))

        # Probability distribution
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.bar(['โ Unlikely (0)', 'โ Likely (1)'], prediction[0],
               color=['crimson', 'seagreen'])
        ax.set_title('๐ Prediction Probability Distribution')
        ax.set_ylabel('Probability')
        st.pyplot(fig)
        # Release the figure; pyplot otherwise keeps it alive across reruns.
        plt.close(fig)