Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import pickle | |
import joblib | |
import os | |
# Load the trained model: try joblib first, then each pickle file in turn.
def load_model(joblib_path="pcos_model.joblib",
               pickle_paths=("random_forest_model", "random_forest_model.pkl")):
    """Load the PCOS classifier from disk, falling back to a dummy model.

    ``joblib_path`` is tried first, then every entry of ``pickle_paths`` in
    order. The first object that loads successfully is returned. If nothing
    loads, a RandomForestClassifier fitted on random data is returned so the
    UI can still start -- its predictions carry no meaning.

    Args:
        joblib_path: File loaded with ``joblib.load``.
        pickle_paths: Files tried, in order, with ``pickle.load``.

    Returns:
        The loaded model object, or the fitted fallback classifier.
    """
    # SECURITY: both joblib.load and pickle.load can execute arbitrary code
    # while unpickling; only point these paths at model files you trust.
    try:
        model = joblib.load(joblib_path)
    except Exception as e:  # missing file, incompatible pickle, ...
        print(f"Could not load {joblib_path} with joblib: {e}")
    else:
        print("Model loaded using joblib")
        return model

    last_error = None
    for path in pickle_paths:
        try:
            with open(path, "rb") as file:
                model = pickle.load(file)
        except Exception as e:
            last_error = e
            continue
        # Report the file that was actually loaded.
        print(f"Model loaded using pickle from {path}")
        return model

    print(f"Error loading model: {last_error if last_error is not None else 'no model files tried'}")
    # Fallback to a simple model for demo purposes
    from sklearn.ensemble import RandomForestClassifier
    print("Creating a fallback model for demonstration")
    print("WARNING: the fallback model is trained on random labels; its predictions are NOT meaningful")
    fallback_model = RandomForestClassifier(n_estimators=100, random_state=42)
    # 43 columns = number of names in the module-level `features` list, so the
    # fallback accepts the same input row as a real model would.
    X_dummy = np.random.rand(100, 43)
    y_dummy = np.random.choice([0, 1], 100)
    fallback_model.fit(X_dummy, y_dummy)
    return fallback_model
# Resolve the classifier a single time, at import; predict_pcos() and the
# Debug tab both read this module-level name.
model = load_model()
# Feature names required for prediction.
# Order matters: predict_pcos() pairs these names positionally with the form
# values in `input_components` and builds the model's input row in this column
# order. The trailing numbers are column indices (what the Debug tab's
# importance indices refer to, assuming the model was trained in this order).
# NOTE(review): "Weight gain(Y/N)" and "Weight gain" both appear (the form
# feeds the same checkbox to both), as do "Waist:Hip Ratio" and
# "Hip:Waist Ratio" -- confirm these match the model's training columns.
features = [
    "Age (yrs)",              # 0
    "Weight (Kg)",            # 1
    "Height(Cm)",             # 2
    "BMI",                    # 3
    "Blood Group",            # 4
    "Pulse rate(bpm)",        # 5
    "RR (breaths/min)",       # 6
    "Hb(g/dl)",               # 7
    "Cycle length(days)",     # 8
    "Cycle(R/I)",             # 9
    "Marraige Status (Yrs)",  # 10 (spelling kept: it is a column name)
    "Pregnant(Y/N)",          # 11
    "No. of abortions",       # 12
    "Hip(inch)",              # 13
    "Waist(inch)",            # 14
    "Waist:Hip Ratio",        # 15
    "Weight gain(Y/N)",       # 16
    "hair growth(Y/N)",       # 17
    "Skin darkening (Y/N)",   # 18
    "Hair loss(Y/N)",         # 19
    "Pimples(Y/N)",           # 20
    "Fast food (Y/N)",        # 21
    "Reg.Exercise(Y/N)",      # 22
    "BP _Systolic (mmHg)",    # 23
    "BP _Diastolic (mmHg)",   # 24
    "Follicle No. (L)",       # 25
    "Follicle No. (R)",       # 26
    "Avg. F size (L) (mm)",   # 27
    "Avg. F size (R) (mm)",   # 28
    "Endometrium (mm)",       # 29
    "FSH(mIU/mL)",            # 30
    "LH(mIU/mL)",             # 31
    "FSH/LH",                 # 32
    "Hip:Waist Ratio",        # 33
    "TSH (mIU/L)",            # 34
    "AMH(ng/mL)",             # 35
    "PRL(ng/mL)",             # 36
    "Vit D3 (ng/mL)",         # 37
    "PRG(ng/mL)",             # 38
    "RBS(mg/dl)",             # 39
    "Weight gain",            # 40
    "I beta-HCG(mIU/mL)",     # 41
    "II beta-HCG(mIU/mL)",    # 42
]
# Create visualizations for the dashboard
def create_visualizations():
    """Build the five demo figures shown in the "Visualizations" tab.

    The data is synthetic (100 rows drawn with a fixed seed), not the real
    dataset; PCOS rows are shifted so the groups visibly differ.

    Returns:
        list of matplotlib Figures, in order: BMI vs Age, cycle length vs Age,
        left vs right follicle count, left follicle boxplot, endometrium
        boxplot.
    """
    # A private RandomState seeded with 42 yields exactly the draws that
    # np.random.seed(42) + np.random.* produced, without resetting the
    # process-wide RNG as a side effect. Draw order (dict order) is preserved.
    rng = np.random.RandomState(42)
    n_samples = 100
    df = pd.DataFrame({
        "Age (yrs)": rng.normal(25, 5, n_samples),
        "PCOS (Y/N)": rng.choice([0, 1], n_samples, p=[0.6, 0.4]),
        "BMI": rng.normal(25, 5, n_samples),
        "Cycle length(days)": rng.normal(28, 5, n_samples),
        "Follicle No. (L)": rng.normal(12, 5, n_samples),
        "Follicle No. (R)": rng.normal(12, 5, n_samples),
        "Endometrium (mm)": rng.normal(8, 2, n_samples),
        "Cycle(R/I)": rng.choice([2, 4], n_samples),
        "Weight (Kg)": rng.normal(65, 10, n_samples),
        "Hb(g/dl)": rng.normal(12, 1.5, n_samples),
    })

    # For PCOS cases, adjust the values to show differences
    pcos_rows = df["PCOS (Y/N)"] == 1
    df.loc[pcos_rows, "BMI"] += 2
    df.loc[pcos_rows, "Cycle length(days)"] += 5
    df.loc[pcos_rows, "Follicle No. (L)"] += 8
    df.loc[pcos_rows, "Follicle No. (R)"] += 7
    df.loc[pcos_rows, "Cycle(R/I)"] = 4

    hue = "PCOS (Y/N)"
    palette = ["teal", "plum"]

    def _scatter(x, y, title):
        # One scatter plot of y against x, coloured by PCOS status.
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.scatterplot(x=x, y=y, hue=hue, data=df, palette=palette, ax=ax)
        ax.set_title(title)
        return fig

    def _box(y, title):
        # One boxplot of y grouped by PCOS status.
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.boxplot(x=hue, y=y, data=df, palette=palette, ax=ax)
        ax.set_title(title)
        return fig

    visualizations = [
        _scatter("Age (yrs)", "BMI", "BMI vs Age by PCOS Status"),
        _scatter("Age (yrs)", "Cycle length(days)",
                 "Menstrual Cycle Length vs Age by PCOS Status"),
        _scatter("Follicle No. (L)", "Follicle No. (R)",
                 "Follicle Distribution (Left vs Right Ovary)"),
        _box("Follicle No. (L)", "Follicle Count (Left Ovary) by PCOS Status"),
        _box("Endometrium (mm)", "Endometrium Thickness by PCOS Status"),
    ]
    # show_visualization() calls this on every radio change; figures left
    # registered with pyplot would accumulate for the life of the process.
    # Closing only deregisters them -- the Figure objects stay renderable.
    for fig in visualizations:
        plt.close(fig)
    return visualizations
# Helper function to get numerical value for categorical inputs
def get_numerical_value(value, options, default=0):
    """Label-encode ``value`` as its position in ``options``.

    Args:
        value: The categorical value to encode.
        options: Ordered sequence of allowed values; the index is the code.
        default: Returned when ``value`` is not in ``options`` (0 matches the
            previous hard-coded behaviour).

    Returns:
        The index of ``value`` in ``options``, or ``default`` if absent.
    """
    try:
        return options.index(value)
    except ValueError:
        # Only "not found" maps to the default; other errors (e.g. a
        # non-sequence ``options``) are real bugs and should surface.
        return default
# Helper function to preprocess inputs
def preprocess_inputs(input_dict):
    """Convert raw form values into numeric model inputs.

    * ``bool`` values (Gradio checkboxes) become ``1``/``0``.
    * ``"Blood Group"`` is label-encoded by its position in the dropdown list
      (``"A+"`` -> 0 ... ``"O-"`` -> 7); unknown values are left unchanged.

    Args:
        input_dict: Mapping of feature name -> raw form value.

    Returns:
        A new dict with the same keys in the same order; ``input_dict`` itself
        is not modified.
    """
    blood_groups = ["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"]
    processed = {
        key: int(value) if isinstance(value, bool) else value
        for key, value in input_dict.items()
    }
    # NOTE(review): this 0-7 encoding must match the one used when the model
    # was trained -- confirm against the training pipeline.
    group = processed.get("Blood Group")
    if group in blood_groups:
        processed["Blood Group"] = blood_groups.index(group)
    return processed
# Function to process input and make predictions
def predict_pcos(*args):
    """Predict PCOS status from the values entered in the Gradio form.

    Args:
        *args: One value per name in the module-level ``features`` list, in
            the same order (the "Predict" button passes ``input_components``).

    Returns:
        str: "Positive for PCOS"/"Negative for PCOS", followed by the model's
        confidence when ``predict_proba`` exists, or by a heuristic risk
        score when ``model`` has no ``predict`` method. Failures are returned
        as "Error making prediction: ..." instead of raised, so they appear in
        the output textbox.
    """
    if model is None:
        return "Model not loaded correctly. Please check if model files are available."
    try:
        # zip() would silently truncate on a count mismatch and hand the model
        # a row with missing columns; reject that explicitly.
        if len(args) != len(features):
            raise ValueError(f"expected {len(features)} inputs, got {len(args)}")
        input_dict = preprocess_inputs(dict(zip(features, args)))
        input_df = pd.DataFrame([input_dict])
        # Print for debugging
        print("Input shape:", input_df.shape)
        print("Input data types:", input_df.dtypes)

        if hasattr(model, "predict"):
            # Errors raised *inside* predict()/predict_proba() are deliberately
            # not converted into the heuristic below: they indicate a real
            # model problem and are reported by the outer handler.
            prediction = model.predict(input_df)[0]
            result = "Positive for PCOS" if prediction == 1 else "Negative for PCOS"
            if not hasattr(model, "predict_proba"):
                return result
            probability = model.predict_proba(input_df)[0]
            # Select the predicted label's probability column via classes_
            # instead of assuming the columns are ordered [0, 1].
            classes = list(getattr(model, "classes_", []))
            if prediction in classes:
                conf = probability[classes.index(prediction)]
            else:
                conf = probability[1] if prediction == 1 else probability[0]
            return f"{result} (Confidence: {conf:.2f})"

        # The loaded object is not a classifier (e.g. a bare array of
        # coefficients): use a simple threshold heuristic on a few inputs.
        print("Model is not a classifier object, using fallback prediction")
        risk_score = np.mean([
            input_df["BMI"].values[0] / 30,
            input_df["Follicle No. (L)"].values[0] / 15,
            input_df["Follicle No. (R)"].values[0] / 15,
            (1 if input_df["Cycle(R/I)"].values[0] > 3 else 0),
        ])
        prediction = 1 if risk_score > 0.6 else 0
        result = "Positive for PCOS" if prediction == 1 else "Negative for PCOS"
        return f"{result} (Risk Score: {risk_score:.2f})"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error making prediction: {str(e)}"
# Function to display visualizations
def show_visualization(visualization_index):
    """Return figure number ``visualization_index`` from create_visualizations().

    Args:
        visualization_index: Position in the list returned by
            create_visualizations().

    Returns:
        The selected matplotlib Figure, or None if the index is out of range.
    """
    visualizations = create_visualizations()
    # All figures are rebuilt on every call; deregister them from pyplot so
    # repeated radio changes do not keep piling up open figures. plt.close only
    # removes pyplot's reference -- the returned Figure can still be rendered.
    for fig in visualizations:
        plt.close(fig)
    if 0 <= visualization_index < len(visualizations):
        return visualizations[visualization_index]
    return None
# Create the Gradio interface
# Radio labels for the "Visualizations" tab, in the same order as the figures
# returned by create_visualizations(); shared by the radio and its handler so
# the two lists cannot drift apart.
_VISUALIZATION_NAMES = [
    "BMI vs Age",
    "Menstrual Cycle Length vs Age",
    "Follicle Distribution",
    "Follicle Count Boxplot",
    "Endometrium Thickness",
]


def _render_visualization(choice):
    """Return the figure matching the radio label ``choice``."""
    return show_visualization(_VISUALIZATION_NAMES.index(choice))


with gr.Blocks(title="PCOS Detection Tool") as app:
    gr.Markdown("# PCOS Detection and Analysis Tool")
    gr.Markdown("This application uses machine learning to detect Polycystic Ovary Syndrome (PCOS) based on patient data.")
    with gr.Tabs():
        with gr.TabItem("Make Prediction"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Patient Demographics")
                    age = gr.Slider(18, 50, value=25, label="Age (yrs)")
                    weight = gr.Slider(40, 120, value=60, label="Weight (Kg)")
                    height = gr.Slider(140, 190, value=160, label="Height (cm)")
                    blood_group = gr.Dropdown(["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"], value="A+", label="Blood Group")
                    bmi = gr.Slider(15, 40, value=22, label="BMI")
                with gr.Column():
                    gr.Markdown("### Vital Signs")
                    pulse = gr.Slider(60, 120, value=80, label="Pulse rate (bpm)")
                    rr = gr.Slider(12, 25, value=16, label="Respiratory Rate (breaths/min)")
                    systolic = gr.Slider(90, 180, value=120, label="BP Systolic (mmHg)")
                    diastolic = gr.Slider(60, 120, value=80, label="BP Diastolic (mmHg)")
                    hb = gr.Slider(8, 18, value=12, label="Hemoglobin (g/dl)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Menstrual History")
                    cycle_length = gr.Slider(21, 45, value=28, label="Cycle length (days)")
                    cycle_regularity = gr.Radio([2, 4], value=2, label="Cycle Regularity (2=Regular, 4=Irregular)")
                with gr.Column():
                    gr.Markdown("### Physical Measurements")
                    hip = gr.Slider(30, 60, value=40, label="Hip (inch)")
                    waist = gr.Slider(20, 50, value=30, label="Waist (inch)")
                    waist_hip_ratio = gr.Slider(0.6, 1.2, value=0.75, label="Waist:Hip Ratio")
                    hip_waist_ratio = gr.Slider(1.0, 2.0, value=1.33, label="Hip:Waist Ratio")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Symptoms")
                    weight_gain = gr.Checkbox(label="Weight gain", value=False)
                    hair_growth = gr.Checkbox(label="Excessive hair growth", value=False)
                    skin_darkening = gr.Checkbox(label="Skin darkening", value=False)
                    hair_loss = gr.Checkbox(label="Hair loss", value=False)
                    pimples = gr.Checkbox(label="Pimples", value=False)
                with gr.Column():
                    gr.Markdown("### Lifestyle")
                    fast_food = gr.Checkbox(label="Fast food consumption", value=False)
                    regular_exercise = gr.Checkbox(label="Regular exercise", value=False)
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Ultrasound Findings")
                    follicle_l = gr.Slider(0, 30, value=10, label="Follicle No. (Left)")
                    follicle_r = gr.Slider(0, 30, value=10, label="Follicle No. (Right)")
                    avg_fsize_l = gr.Slider(0, 25, value=5, label="Avg. Follicle size (Left) (mm)")
                    avg_fsize_r = gr.Slider(0, 25, value=5, label="Avg. Follicle size (Right) (mm)")
                    endometrium = gr.Slider(1, 20, value=8, label="Endometrium (mm)")
                with gr.Column():
                    gr.Markdown("### Hormone Levels")
                    fsh = gr.Slider(0, 20, value=6, label="FSH (mIU/mL)")
                    lh = gr.Slider(0, 20, value=7, label="LH (mIU/mL)")
                    fsh_lh_ratio = gr.Slider(0, 3, value=0.85, label="FSH/LH Ratio")
                    tsh = gr.Slider(0, 10, value=2.5, label="TSH (mIU/L)")
                    amh = gr.Slider(0, 10, value=3, label="AMH (ng/mL)")
                    prl = gr.Slider(0, 30, value=15, label="Prolactin (ng/mL)")
                    vit_d3 = gr.Slider(0, 100, value=30, label="Vitamin D3 (ng/mL)")
                    prg = gr.Slider(0, 20, value=5, label="Progesterone (ng/mL)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Other Medical")
                    married_years = gr.Slider(0, 20, value=0, label="Marriage Status (Years)")
                    pregnant = gr.Checkbox(label="Currently Pregnant", value=False)
                    abortions = gr.Slider(0, 5, value=0, label="Number of abortions")
                    rbs = gr.Slider(70, 200, value=90, label="Random Blood Sugar (mg/dl)")
                    beta_hcg1 = gr.Slider(0, 100, value=5, label="I beta-HCG (mIU/mL)")
                    beta_hcg2 = gr.Slider(0, 100, value=5, label="II beta-HCG (mIU/mL)")
            predict_btn = gr.Button("Predict PCOS Status")
            prediction_output = gr.Textbox(label="Prediction Result")
            # Connect inputs to prediction function. This order must mirror the
            # module-level `features` list one-to-one (predict_pcos pairs them
            # positionally). `weight_gain` appears twice because `features`
            # lists both "Weight gain(Y/N)" and "Weight gain".
            input_components = [
                age, weight, height, bmi, blood_group, pulse, rr, hb, cycle_length,
                cycle_regularity, married_years, pregnant, abortions, hip, waist,
                waist_hip_ratio, weight_gain, hair_growth, skin_darkening, hair_loss,
                pimples, fast_food, regular_exercise, systolic, diastolic, follicle_l,
                follicle_r, avg_fsize_l, avg_fsize_r, endometrium, fsh, lh, fsh_lh_ratio,
                hip_waist_ratio, tsh, amh, prl, vit_d3, prg, rbs, weight_gain, beta_hcg1, beta_hcg2,
            ]
            predict_btn.click(
                predict_pcos,
                inputs=input_components,
                outputs=prediction_output,
            )
        with gr.TabItem("Visualizations"):
            gr.Markdown("### PCOS Data Analysis Visualizations")
            visualization_choice = gr.Radio(
                list(_VISUALIZATION_NAMES),
                value=_VISUALIZATION_NAMES[0],
                label="Select Visualization",
            )
            visualization_output = gr.Plot()
            visualization_choice.change(
                _render_visualization,
                inputs=visualization_choice,
                outputs=visualization_output,
            )
        with gr.TabItem("About PCOS"):
            gr.Markdown("""
# Polycystic Ovary Syndrome (PCOS)
Polycystic ovary syndrome (PCOS) is a hormonal disorder common among women of reproductive age.
Women with PCOS may have infrequent or prolonged menstrual periods or excess male hormone (androgen) levels.
## Common Symptoms
- Irregular periods
- Excess androgen (elevated levels of male hormones)
- Polycystic ovaries
- Weight gain
- Acne
- Excessive hair growth (hirsutism)
- Thinning hair or hair loss
- Infertility
## Risk Factors
- Having a mother or sister with PCOS
- Insulin resistance
- Obesity
## Complications
- Infertility
- Gestational diabetes or pregnancy-induced high blood pressure
- Miscarriage or premature birth
- Type 2 diabetes or prediabetes
- Depression, anxiety, and eating disorders
- Sleep apnea
- Endometrial cancer
- Cardiovascular disease
## Treatment
Treatment focuses on managing your individual concerns, such as infertility, hirsutism, acne or obesity.
Specific treatment might involve lifestyle changes or medication.
""")
        with gr.TabItem("Debug Info"):
            gr.Markdown("### Model and System Information")
            debug_output = gr.Textbox(label="Debug Information", value=f"Model type: {type(model).__name__}")
            debug_btn = gr.Button("Check Model Status")

            def check_model():
                """Summarise the loaded model (type, size, top features, API) as text."""
                try:
                    if model is None:
                        return "Model not loaded"
                    model_info = f"Model type: {type(model).__name__}\n"
                    # Try to get additional info based on model type
                    if hasattr(model, 'n_estimators'):
                        model_info += f"Number of estimators: {model.n_estimators}\n"
                    if hasattr(model, 'feature_importances_'):
                        top_features = np.argsort(model.feature_importances_)[-5:]
                        model_info += "Top 5 important features (indices): " + str(top_features.tolist()) + "\n"
                        # feature_names_in_ is only set when the model was fitted
                        # on data with string column names; use it when present
                        # so the indices above are readable.
                        names = getattr(model, 'feature_names_in_', None)
                        if names is not None:
                            model_info += "Top 5 important features (names): " + str([str(names[i]) for i in top_features]) + "\n"
                    # Check if the model has predict and predict_proba methods
                    has_predict = hasattr(model, 'predict') and callable(getattr(model, 'predict'))
                    has_proba = hasattr(model, 'predict_proba') and callable(getattr(model, 'predict_proba'))
                    model_info += f"Has predict method: {has_predict}\n"
                    model_info += f"Has predict_proba method: {has_proba}\n"
                    return model_info
                except Exception as e:
                    return f"Error checking model: {str(e)}"

            debug_btn.click(check_model, outputs=debug_output)

    # `.change` does not fire for the radio's initial value, so without this the
    # pre-selected "BMI vs Age" plot would stay blank until the user switched
    # away and back. Render it once when the page loads.
    app.load(
        _render_visualization,
        inputs=visualization_choice,
        outputs=visualization_output,
    )

# Launch the app
if __name__ == "__main__":
    app.launch(share=True, debug=True)