Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| from transformers import TFAutoModel, AutoTokenizer | |
| import numpy as np | |
| import shap | |
| from scipy.special import softmax | |
# ---- Constants --------------------------------------------------------------
# Maximum token length used when tokenizing symptom descriptions.
SEQ_LEN = 128

# Condition labels reported by predict_condition, in output order.
CONDITIONS = [
    "Common Cold",
    "COVID-19",
    "Allergies",
    "Anxiety Disorder",
    "Skin Infection",
    "Heart Condition",
    "Digestive Issues",
    "Migraine",
    "Muscle Strain",
    "Arthritis",
]

# ---- Model and tokenizer setup (downloaded/loaded once at import time) -----
# NOTE(review): TFAutoModel loads only the bare encoder, so the checkpoint's
# fine-tuned classification head is not used. The checkpoint name also
# suggests it was tuned for SST-2 sentiment, not for CONDITIONS. Confirm the
# intended model before trusting any condition scores.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = TFAutoModel.from_pretrained(MODEL_NAME)
# Dynamic Condition Predictions
def predict_condition(description: str):
    """Score a free-text symptom description against each entry in CONDITIONS.

    The text is tokenized (padded/truncated to SEQ_LEN) and run through
    ``model``. The first ``len(CONDITIONS)`` features of the first-token
    ([CLS]) embedding are then turned into a probability distribution with a
    softmax.

    NOTE(review): ``model`` is a bare encoder with no head trained on
    CONDITIONS, so these scores are not meaningful condition probabilities.
    A classifier trained on these labels is needed for real predictions.

    Args:
        description: Symptom description entered by the user.

    Returns:
        Dict mapping each condition name to a Python float. The values sum
        to 1, as ``gr.Label`` expects for confidences.
    """
    tokens = tokenizer(
        description,
        max_length=SEQ_LEN,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    # First token of the single (batch size 1) sequence.
    cls_embedding = model(tokens).last_hidden_state[:, 0, :]
    # Keep one feature per condition and normalise over only those.
    # A softmax over the whole hidden vector would spread the probability
    # mass across every hidden unit, leaving each condition with a tiny score.
    logits = cls_embedding.numpy()[0, : len(CONDITIONS)]
    scores = softmax(logits)
    return {name: float(score) for name, score in zip(CONDITIONS, scores)}
# Lifestyle Tips
# Self-care suggestions keyed by condition name.
# TODO: "Heart Condition", "Digestive Issues", "Migraine", "Muscle Strain" and
# "Arthritis" (all offered in CONDITIONS) have no entry yet. They currently
# receive the generic fallback from get_lifestyle_advice.
LIFESTYLE_TIPS = {
    "Common Cold": "Rest, stay hydrated, and use saline nasal sprays.",
    "COVID-19": "Quarantine, stay hydrated, and seek medical attention if symptoms worsen.",
    "Allergies": "Avoid allergens, take antihistamines, and use air purifiers.",
    "Anxiety Disorder": "Practice mindfulness, exercise, and seek therapy if needed.",
    "Skin Infection": "Keep the area clean, use topical creams, and consult a dermatologist.",
}


def get_lifestyle_advice(condition: str):
    """Look up the self-care tip for ``condition``.

    Returns the generic "consult a professional" message for any condition
    without an entry, including ``None`` when nothing is selected.
    """
    fallback = "Consult a healthcare professional for guidance."
    if condition in LIFESTYLE_TIPS:
        return LIFESTYLE_TIPS[condition]
    return fallback
# Interactive Health Visualization (SHAP)
def explain_prediction(text: str):
    """Return an HTML SHAP text plot explaining predict_condition for ``text``.

    SHAP calls the wrapped model with a batch of masked strings and needs a
    numeric array of shape (batch, n_outputs). predict_condition returns a
    dict for a single string, so each masked variant is scored separately.
    Its scores are then stacked in CONDITIONS order.

    Args:
        text: Symptom description to explain.

    Returns:
        HTML string, from ``shap.plots.text(..., display=False)``, or a short
        HTML hint when ``text`` is blank.
    """
    if not text or not text.strip():
        return "<p>Enter a symptom description to explain.</p>"

    def score_batch(texts):
        # SHAP passes a numpy array of strings; the tokenizer wants plain str.
        rows = []
        for masked in texts:
            scores = predict_condition(str(masked))
            rows.append([scores[name] for name in CONDITIONS])
        return np.asarray(rows, dtype=float)

    explainer = shap.Explainer(score_batch, tokenizer, output_names=CONDITIONS)
    shap_values = explainer([text])
    return shap.plots.text(shap_values, display=False)
# Symptom Tracker (Simple Implementation)
# NOTE(review): this list lives at module level, so every visitor to the app
# shares, and can display, the same history. Per-session storage (e.g.
# gr.State) would be needed to keep users' entries private.
symptom_history = []


def log_symptom(symptom: str):
    """Record ``symptom`` in the shared history and report the running total.

    Surrounding whitespace is stripped. Blank input (including ``None``) is
    not recorded, so it does not inflate the count.
    """
    entry = (symptom or "").strip()
    if not entry:
        return (
            "Nothing logged: please enter a symptom. "
            f"Total symptoms logged: {len(symptom_history)}"
        )
    symptom_history.append(entry)
    return f"Logged: {entry}. Total symptoms logged: {len(symptom_history)}"


def display_symptom_trends():
    """Return up to the 10 most recent symptoms, oldest first, one per line."""
    return "\n".join(symptom_history[-10:])
# Gradio UI Design
css = """
textarea { background-color: transparent; border: 1px solid #6366f1; }
"""

# Build the whole page. Component creation order determines layout, and event
# listeners are registered inside the Blocks context as Gradio requires.
# Section headings are placed above their Row: inside a Row they would be laid
# out as an extra column beside the inputs.
with gr.Blocks(title="MedAI Compass", css=css, theme=gr.themes.Soft()) as app:
    # Header
    gr.HTML("<h1>MedAI Compass: Comprehensive Symptom and Health Guide</h1>")
    # The condition scores do not come from a model trained on these
    # conditions, so make the informational nature explicit to users.
    gr.Markdown(
        "**Disclaimer:** MedAI Compass is an informational demo, not a medical "
        "device, and its output is not a diagnosis. Consult a qualified "
        "healthcare professional about any health concern, and seek emergency "
        "care for severe symptoms such as chest pain or difficulty breathing."
    )

    # Section: Symptom Diagnosis
    gr.Markdown("## Symptom Diagnosis")
    with gr.Row():
        input_description = gr.Textbox(label="Describe your symptom")
        diagnose_btn = gr.Button("Diagnose")
        diagnosis_output = gr.Label(label="Possible Conditions")
    diagnose_btn.click(predict_condition, inputs=input_description, outputs=diagnosis_output)

    # Section: SHAP Analysis
    gr.Markdown("## Explain Predictions")
    with gr.Row():
        shap_text_input = gr.Textbox(label="Enter Symptom Description for Analysis")
        shap_btn = gr.Button("Generate Explanation")
        shap_output = gr.HTML()
    shap_btn.click(explain_prediction, inputs=shap_text_input, outputs=shap_output)

    # Section: Personalized Advice
    gr.Markdown("## Personalized Health Advice")
    with gr.Row():
        condition_input = gr.Dropdown(choices=CONDITIONS, label="Select a Condition")
        advice_output = gr.Textbox(label="Advice")
        advice_btn = gr.Button("Get Advice")
    advice_btn.click(get_lifestyle_advice, inputs=condition_input, outputs=advice_output)

    # Section: Symptom Tracker
    gr.Markdown("## Symptom Tracker")
    with gr.Row():
        tracker_input = gr.Textbox(label="Log a Symptom")
        tracker_btn = gr.Button("Log Symptom")
        tracker_output = gr.Textbox(label="Logged Symptoms")
    tracker_btn.click(log_symptom, inputs=tracker_input, outputs=tracker_output)
    with gr.Row():
        tracker_display_btn = gr.Button("Display Trends")
        tracker_trends_output = gr.Textbox(label="Symptom Trends")
    tracker_display_btn.click(display_symptom_trends, outputs=tracker_trends_output)

    # Footer
    gr.HTML("<p>© 2024 MedAI Compass. All Rights Reserved.</p>")

# Only start the server when run as a script (e.g. `python app.py` on Spaces),
# not as a side effect of importing this module.
if __name__ == "__main__":
    app.launch()