|
import os
|
|
import json
|
|
import base64
|
|
import io
|
|
import ast
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Optional, Any
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
import streamlit as st
|
|
import spacy
|
|
|
|
from scipy.stats import ttest_ind, f_oneway
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
from statsmodels.tsa.seasonal import seasonal_decompose
|
|
from statsmodels.tsa.stattools import adfuller
|
|
|
|
from pydantic import BaseModel, Field
|
|
from Bio import Entrez
|
|
|
|
from dotenv import load_dotenv
|
|
import requests
|
|
import openai
|
|
from openai.error import APIError, RateLimitError, InvalidRequestError
|
|
|
|
|
|
load_dotenv()

# Append-mode file logging used by every component in this module.
logging.basicConfig(
    filename='app.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger()

# Must run before any other Streamlit command in the script.
st.set_page_config(page_title="AI Clinical Intelligence Hub", layout="wide")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Email passed to NCBI Entrez; PubMed lookups are skipped when it is empty.
PUB_EMAIL = os.getenv("PUB_EMAIL", "")

if not OPENAI_API_KEY:
    st.error("OpenAI API key must be set as an environment variable (OPENAI_API_KEY).")
    st.stop()

openai.api_key = OPENAI_API_KEY

try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # The model is not installed: download it into the running interpreter's
    # environment and retry. check=True makes a failed download raise
    # CalledProcessError here, instead of a confusing second OSError from
    # spacy.load below.
    import subprocess
    import sys

    logger.warning("spaCy model 'en_core_web_sm' not found; downloading it.")
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    nlp = spacy.load("en_core_web_sm")
|
|
|
|
|
|
|
|
class ResearchInput(BaseModel):
    """Base schema for research tool inputs."""

    # Streamlit session-state key under which the DataFrame is stored.
    data_key: str = Field(default=..., description="Session state key containing DataFrame")
    # Optional subset of columns to analyze.
    columns: Optional[List[str]] = Field(default=None, description="List of columns to analyze")


class TemporalAnalysisInput(ResearchInput):
    """Schema for temporal analysis."""

    time_col: str = Field(default=..., description="Name of timestamp column")
    value_col: str = Field(default=..., description="Name of value column to analyze")


class HypothesisInput(ResearchInput):
    """Schema for hypothesis testing."""

    group_col: str = Field(default=..., description="Categorical column defining groups")
    value_col: str = Field(default=..., description="Numerical column to compare")


class ModelTrainingInput(ResearchInput):
    """Schema for model training."""

    target_col: str = Field(default=..., description="Name of target column")
|
|
|
|
class DataAnalyzer(ABC):
    """Interface for pluggable analysis modules.

    Subclasses implement :meth:`invoke`. It receives the dataset plus
    analysis-specific keyword arguments and returns the analysis result.
    """

    @abstractmethod
    def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
        """Run the analysis on *data* and return its result."""
|
|
|
|
|
|
|
|
class AdvancedEDA(DataAnalyzer):
    """Comprehensive Exploratory Data Analysis."""

    def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
        """Profile *data*: size, summary statistics, date ranges and data quality.

        Args:
            data: Dataset to profile.
            **kwargs: Ignored; accepted for DataAnalyzer interface compatibility.

        Returns:
            Dict with ``dimensionality``, ``statistical_profile``,
            ``temporal_analysis`` and ``data_quality`` sections, or
            ``{"error": ...}`` if profiling raises (e.g. ``describe`` on a
            DataFrame without columns).
        """
        try:
            datetime_cols = data.select_dtypes(include='datetime').columns
            return {
                "dimensionality": {
                    "rows": len(data),
                    "columns": list(data.columns),
                    "memory_usage_MB": f"{data.memory_usage().sum() / 1e6:.2f} MB",
                },
                "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
                "temporal_analysis": {
                    "date_ranges": {
                        col: {"min": data[col].min(), "max": data[col].max()}
                        for col in datetime_cols
                    }
                },
                "data_quality": {
                    # Counts are cast to plain int: numpy integers (for example
                    # the result of duplicated().sum()) are not JSON-native.
                    "missing_values": {
                        col: int(count) for col, count in data.isnull().sum().items()
                    },
                    "duplicates": int(data.duplicated().sum()),
                    "cardinality": {col: int(data[col].nunique()) for col in data.columns},
                },
            }
        except Exception as e:
            logger.exception(f"EDA Failed: {str(e)}")
            return {"error": f"EDA Failed: {str(e)}"}
|
|
|
|
class DistributionVisualizer(DataAnalyzer):
    """Distribution visualizations (histogram + KDE per column)."""

    def invoke(self, data: pd.DataFrame, columns: List[str], **kwargs) -> str:
        """Plot one density histogram per column, side by side.

        Args:
            data: Source dataset.
            columns: Columns to plot, one subplot each.
            **kwargs: Ignored.

        Returns:
            Base64-encoded PNG. On failure, a message starting with
            ``"Visualization Error"`` (callers test for that prefix).
        """
        if not columns:
            return "Visualization Error: no columns selected"
        fig = None
        try:
            # An explicit figure/axes pair avoids pyplot's global "current
            # figure" state, which is shared across Streamlit sessions.
            fig, axes = plt.subplots(1, len(columns), figsize=(12, 6), squeeze=False)
            for ax, col in zip(axes[0], columns):
                sns.histplot(data[col], kde=True, stat="density", ax=ax)
                ax.set_title(f'Distribution of {col}', fontsize=10)
                ax.tick_params(axis="both", labelsize=8)
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
            return base64.b64encode(buf.getvalue()).decode()
        except Exception as e:
            logger.error(f"Visualization Error: {str(e)}")
            return f"Visualization Error: {str(e)}"
        finally:
            # Close on success and on failure so figures never accumulate.
            if fig is not None:
                plt.close(fig)
|
|
|
|
class TemporalAnalyzer(DataAnalyzer):
    """Time series analysis: seasonal decomposition plus an ADF stationarity test."""

    def invoke(self, data: pd.DataFrame, time_col: str, value_col: str,
               period: int = 365, **kwargs) -> Dict[str, Any]:
        """Decompose *value_col* over *time_col* and test it for stationarity.

        Args:
            data: Source dataset.
            time_col: Column parsed with ``pd.to_datetime`` and used as the index.
            value_col: Numeric column to analyse.
            period: Seasonal period, in observations, passed to
                ``seasonal_decompose``. The default of 365 matches the
                previous hard-coded value. statsmodels requires at least
                ``2 * period`` observations.
            **kwargs: Ignored.

        Returns:
            ``{"trend_statistics": {...}, "visualization": <base64 PNG>}`` or
            ``{"error": ...}``.
        """
        fig = None
        try:
            # Decomposition works on positional order, so sort by time.
            # seasonal_decompose rejects missing values, so drop them first.
            ts_data = (
                data.set_index(pd.to_datetime(data[time_col]))[value_col]
                .sort_index()
                .dropna()
            )
            decomposition = seasonal_decompose(ts_data, period=period)

            # DecomposeResult.plot() creates and returns its own figure, so size
            # and close that figure rather than a separate pre-created one.
            fig = decomposition.plot()
            fig.set_size_inches(12, 8)
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            plot_data = base64.b64encode(buf.getvalue()).decode()

            stationarity_p_value = float(adfuller(ts_data)[1])

            return {
                "trend_statistics": {
                    "stationarity_p_value": stationarity_p_value,
                    # NOTE: this is the peak value of the seasonal component,
                    # not a normalised (0-1) seasonality-strength measure.
                    "seasonality_strength": float(decomposition.seasonal.max()),
                },
                "visualization": plot_data,
            }
        except Exception as e:
            logger.error(f"Temporal Analysis Failed: {str(e)}")
            return {"error": f"Temporal Analysis Failed: {str(e)}"}
        finally:
            if fig is not None:
                plt.close(fig)
|
|
|
|
class HypothesisTester(DataAnalyzer):
    """Statistical hypothesis testing across groups."""

    def invoke(self, data: pd.DataFrame, group_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
        """Compare *value_col* across the levels of *group_col*.

        Two groups get an independent two-sample t-test plus Cohen's d. Three
        or more groups get a one-way ANOVA (no effect size).

        Rows with a missing group label or a missing value are dropped first.
        Otherwise a NaN label forms a spurious empty group, and NaN values
        turn every statistic into NaN.

        Returns:
            Dict with ``test_type``, ``test_statistic``, ``p_value``,
            ``effect_size`` and ``interpretation``, or ``{"error": ...}``.
        """
        try:
            # dict.fromkeys dedupes while keeping order (group_col == value_col).
            cols = list(dict.fromkeys([group_col, value_col]))
            clean = data[cols].dropna()
            groups = clean[group_col].unique()

            if len(groups) < 2:
                return {"error": "Insufficient groups for comparison"}

            group_data = [clean.loc[clean[group_col] == g, value_col] for g in groups]

            if len(groups) == 2:
                stat, p = ttest_ind(*group_data)
                test_type = "Independent t-test"
                effect_size = self.calculate_cohens_d(group_data[0], group_data[1])
            else:
                stat, p = f_oneway(*group_data)
                test_type = "ANOVA"
                effect_size = None

            return {
                "test_type": test_type,
                # Plain floats keep the result JSON-native.
                "test_statistic": float(stat),
                "p_value": float(p),
                "effect_size": effect_size,
                "interpretation": self.interpret_p_value(p),
            }
        except Exception as e:
            logger.error(f"Hypothesis Testing Failed: {str(e)}")
            return {"error": f"Hypothesis Testing Failed: {str(e)}"}

    @staticmethod
    def calculate_cohens_d(x: pd.Series, y: pd.Series) -> Optional[float]:
        """Absolute Cohen's d using the degrees-of-freedom weighted pooled SD.

        pooled_var = (SS_x + SS_y) / (n_x + n_y - 2), where SS is the sum of
        squared deviations from each group's mean. For equal group sizes this
        equals the mean of the two sample variances.

        Returns:
            The effect size, or None when it is undefined (too few
            observations or zero pooled variance) or computation fails.
        """
        try:
            dof = len(x) + len(y) - 2
            if dof <= 0:
                return None
            sum_sq = ((x - x.mean()) ** 2).sum() + ((y - y.mean()) ** 2).sum()
            pooled_var = sum_sq / dof
            if not np.isfinite(pooled_var) or pooled_var <= 0:
                return None
            return float(abs(x.mean() - y.mean()) / np.sqrt(pooled_var))
        except Exception as e:
            logger.error(f"Error calculating Cohen's d: {str(e)}")
            return None

    @staticmethod
    def interpret_p_value(p: float) -> str:
        """Map a p-value to a conventional verbal evidence scale."""
        # A NaN p-value (e.g. constant groups) is not "no evidence".
        if p is None or np.isnan(p):
            return "Test could not be computed (p-value is undefined)"
        if p < 0.001:
            return "Very strong evidence against H0"
        elif p < 0.01:
            return "Strong evidence against H0"
        elif p < 0.05:
            return "Evidence against H0"
        elif p < 0.1:
            return "Weak evidence against H0"
        else:
            return "No significant evidence against H0"
|
|
|
|
class LogisticRegressionTrainer(DataAnalyzer):
    """Logistic Regression Model Trainer."""

    def invoke(self, data: pd.DataFrame, target_col: str, columns: List[str],
               test_size: float = 0.2, random_state: int = 42,
               **kwargs) -> Dict[str, Any]:
        """Fit a logistic regression and report hold-out accuracy.

        Args:
            data: Source dataset.
            target_col: Label column.
            columns: Feature columns. The target is excluded if listed, and
                duplicates are ignored.
            test_size: Fraction of rows held out for evaluation.
            random_state: Seed for the train/test split.
            **kwargs: Ignored.

        Returns:
            ``{"model_type", "accuracy", "model_params"}`` or ``{"error": ...}``.
        """
        try:
            # Using the label as its own predictor leaks the answer.
            features = [c for c in dict.fromkeys(columns) if c != target_col]
            if not features:
                return {"error": "Logistic Regression Model Error: no feature columns selected"}

            # LogisticRegression rejects NaN input, so keep complete rows only.
            subset = data[features + [target_col]].dropna()
            X = subset[features]
            y = subset[target_col]

            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=random_state
            )
            model = LogisticRegression(max_iter=1000)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            accuracy = float(accuracy_score(y_test, y_pred))
            return {
                "model_type": "Logistic Regression",
                "accuracy": accuracy,
                "model_params": model.get_params()
            }
        except Exception as e:
            logger.error(f"Logistic Regression Model Error: {str(e)}")
            return {"error": f"Logistic Regression Model Error: {str(e)}"}
|
|
|
|
|
|
|
|
class ClinicalRule(BaseModel):
    """A named clinical rule: a condition expression plus the action it triggers."""

    # Identifier; the rules engine stores rules keyed by this name.
    name: str = Field(...)
    # Expression evaluated with the patient DataFrame bound to `df`.
    condition: str = Field(...)
    # Text reported when the condition matches.
    action: str = Field(...)
    # Severity label reported when the condition matches.
    severity: str = Field(...)
|
|
|
|
class ClinicalRulesEngine:
    """Executes rules against patient data.

    A rule's ``condition`` is an expression over ``df``, for example
    ``df['blood_pressure'] > 140``. It is evaluated by :meth:`safe_eval`.
    Element-wise results (a boolean Series) count as a match when any element
    is true.
    """

    # AST node types an expression may contain. The operator base classes
    # (boolop/operator/unaryop/cmpop) are required because ast.walk yields
    # operator nodes as children of BoolOp/BinOp/UnaryOp/Compare. Without
    # them, even ``df['x'] > 1`` is rejected.
    _ALLOWED_NODES = (
        ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
        ast.Call, ast.keyword, ast.Name, ast.Load, ast.Constant,
        ast.Attribute, ast.Subscript, ast.Slice,
        ast.List, ast.Tuple, ast.Dict,
        ast.boolop, ast.operator, ast.unaryop, ast.cmpop,
        # Python 3.8 wraps subscripts in ast.Index; newer versions do not.
        getattr(ast, "Index", ast.Subscript),
    )
    # Attribute names that reach interpreter internals (_/__dunder__) or do
    # file I/O / secondary evaluation (pandas to_*, read_*, eval, query).
    _BLOCKED_ATTR_PREFIXES = ("_", "to_", "read_")
    _BLOCKED_ATTRS = frozenset({"eval", "query"})

    def __init__(self):
        self.rules: Dict[str, ClinicalRule] = {}

    def add_rule(self, rule: "ClinicalRule"):
        """Register *rule*, replacing any existing rule with the same name."""
        self.rules[rule.name] = rule

    def execute_rules(self, data: pd.DataFrame) -> Dict[str, Any]:
        """Evaluate every registered rule against *data*.

        Returns:
            Mapping of rule name to ``{"rule_matched", "action", "severity"}``.
            When a rule fails to evaluate, the mapping is
            ``{"rule_matched": False, "error", "severity": None}``.
        """
        results: Dict[str, Any] = {}
        for rule_name, rule in self.rules.items():
            try:
                rule_matched = self._as_bool(self.safe_eval(rule.condition, {"df": data}))
            except Exception as e:
                logger.error(f"Error executing rule '{rule_name}': {str(e)}")
                results[rule_name] = {
                    "rule_matched": False,
                    "error": str(e),
                    "severity": None
                }
                continue
            results[rule_name] = {
                "rule_matched": rule_matched,
                "action": rule.action if rule_matched else None,
                "severity": rule.severity if rule_matched else None
            }
        return results

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Collapse a condition result to one flag (any true element matches)."""
        if isinstance(value, (pd.Series, pd.DataFrame)):
            return bool(value.to_numpy(dtype=bool, na_value=False).any())
        if isinstance(value, np.ndarray):
            return bool(value.any())
        return bool(value)

    @classmethod
    def safe_eval(cls, expr: str, variables: Dict[str, Any]) -> Any:
        """Evaluate *expr* after validating its AST against an allow-list.

        Only names present in *variables* may be referenced. Builtins are
        unavailable, comprehensions and lambdas are rejected, and private,
        dunder or I/O attribute names are refused.

        NOTE(review): this blocks the usual attribute/builtin escape routes,
        but it is not a hardened sandbox. For example, it does not bound CPU
        or memory use.

        Raises:
            ValueError: if *expr* does not parse, uses a disallowed construct,
                or raises while being evaluated. Logging is left to callers.
        """
        try:
            tree = ast.parse(expr, mode='eval')
            for node in ast.walk(tree):
                if not isinstance(node, cls._ALLOWED_NODES):
                    raise ValueError(f"Unsupported expression: {expr}")
                if isinstance(node, ast.Name) and node.id not in variables:
                    raise ValueError(f"Unknown name '{node.id}' in expression")
                if isinstance(node, ast.Attribute) and (
                    node.attr.startswith(cls._BLOCKED_ATTR_PREFIXES)
                    or node.attr in cls._BLOCKED_ATTRS
                ):
                    raise ValueError(f"Attribute '{node.attr}' is not allowed")
            # An empty builtins dict gives a clean NameError for any builtin
            # lookup, where None would produce an internal TypeError.
            return eval(compile(tree, '<string>', mode='eval'),
                        {"__builtins__": {}}, dict(variables))
        except Exception as e:
            raise ValueError(f"Invalid expression: {e}") from e
|
|
|
|
class ClinicalKPI(BaseModel):
    """A named KPI: an expression over `df` plus an optional comparison threshold."""

    # Identifier; KPI monitoring stores KPIs keyed by this name.
    name: str = Field(...)
    # Expression evaluated with the DataFrame bound to `df`.
    calculation: str = Field(...)
    # Optional value the KPI result is compared against.
    threshold: Optional[float] = Field(default=None)
|
|
|
|
class ClinicalKPIMonitoring:
    """Calculates KPIs based on data.

    Each KPI's ``calculation`` is an expression over ``df`` (for example
    ``df['los_days'].mean()``), evaluated by :meth:`safe_eval`.
    """

    # AST node types an expression may contain. The operator base classes are
    # required because ast.walk yields operator nodes as children of
    # BoolOp/BinOp/UnaryOp/Compare. Without them, ``df['x'].mean() * 2`` is
    # rejected.
    _ALLOWED_NODES = (
        ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
        ast.Call, ast.keyword, ast.Name, ast.Load, ast.Constant,
        ast.Attribute, ast.Subscript, ast.Slice,
        ast.List, ast.Tuple, ast.Dict,
        ast.boolop, ast.operator, ast.unaryop, ast.cmpop,
        # Python 3.8 wraps subscripts in ast.Index; newer versions do not.
        getattr(ast, "Index", ast.Subscript),
    )
    # Attribute names that reach interpreter internals or do file I/O /
    # secondary evaluation.
    _BLOCKED_ATTR_PREFIXES = ("_", "to_", "read_")
    _BLOCKED_ATTRS = frozenset({"eval", "query"})

    def __init__(self):
        self.kpis: Dict[str, ClinicalKPI] = {}

    def add_kpi(self, kpi: "ClinicalKPI"):
        """Register *kpi*, replacing any existing KPI with the same name."""
        self.kpis[kpi.name] = kpi

    def calculate_kpis(self, data: pd.DataFrame) -> Dict[str, Any]:
        """Evaluate every KPI on *data*.

        Returns:
            Mapping of KPI name to ``{"value", "threshold", "status"}``, or
            ``{"error": ...}`` for a KPI that failed to evaluate.
        """
        results: Dict[str, Any] = {}
        for kpi_name, kpi in self.kpis.items():
            try:
                kpi_value = self._to_native(self.safe_eval(kpi.calculation, {"df": data}))
                status = self.evaluate_threshold(kpi_value, kpi.threshold)
            except Exception as e:
                logger.error(f"Error calculating KPI '{kpi_name}': {str(e)}")
                results[kpi_name] = {"error": str(e)}
                continue
            results[kpi_name] = {
                "value": kpi_value,
                "threshold": kpi.threshold,
                "status": status
            }
        return results

    @staticmethod
    def _to_native(value: Any) -> Any:
        """Convert numpy scalars (e.g. from ``.mean()``) to plain Python values."""
        return value.item() if isinstance(value, np.generic) else value

    @staticmethod
    def evaluate_threshold(value: Any, threshold: Optional[float]) -> Optional[str]:
        """Compare *value* to *threshold* (equal counts as below).

        Returns None when no threshold is set. Returns "Threshold Evaluation
        Not Applicable" when the value cannot be compared to a single number,
        for example a non-numeric value or a Series result.
        """
        if threshold is None:
            return None
        try:
            return "Above Threshold" if value > threshold else "Below Threshold"
        except (TypeError, ValueError):
            # ValueError: an array/Series comparison has an ambiguous truth value.
            return "Threshold Evaluation Not Applicable"

    @classmethod
    def safe_eval(cls, expr: str, variables: Dict[str, Any]) -> Any:
        """Evaluate *expr* after validating its AST against an allow-list.

        Only names present in *variables* may be referenced. Builtins are
        unavailable, and private, dunder or I/O attribute names are refused.

        NOTE(review): not a hardened sandbox (no CPU/memory bounds).

        Raises:
            ValueError: if *expr* does not parse, uses a disallowed construct,
                or raises while being evaluated. Logging is left to callers.
        """
        try:
            tree = ast.parse(expr, mode='eval')
            for node in ast.walk(tree):
                if not isinstance(node, cls._ALLOWED_NODES):
                    raise ValueError(f"Unsupported expression: {expr}")
                if isinstance(node, ast.Name) and node.id not in variables:
                    raise ValueError(f"Unknown name '{node.id}' in expression")
                if isinstance(node, ast.Attribute) and (
                    node.attr.startswith(cls._BLOCKED_ATTR_PREFIXES)
                    or node.attr in cls._BLOCKED_ATTRS
                ):
                    raise ValueError(f"Attribute '{node.attr}' is not allowed")
            return eval(compile(tree, '<string>', mode='eval'),
                        {"__builtins__": {}}, dict(variables))
        except Exception as e:
            raise ValueError(f"Invalid expression: {e}") from e
|
|
|
|
class DiagnosisSupport(ABC):
    """Interface for diagnosis-support strategies.

    Implementations analyse *data* using *columns* as inputs and *target_col*
    as the outcome, and return a DataFrame. *diagnosis_key* names the result
    column.
    """

    @abstractmethod
    def diagnose(
        self,
        data: pd.DataFrame,
        target_col: str,
        columns: List[str],
        diagnosis_key: str = "diagnosis",
        **kwargs
    ) -> pd.DataFrame:
        """Produce a diagnosis table for *data*."""
|
|
|
|
class SimpleDiagnosis(DiagnosisSupport):
    """Example diagnosis that reports a logistic-regression model's accuracy."""

    def __init__(self):
        self.model_trainer: LogisticRegressionTrainer = LogisticRegressionTrainer()

    def diagnose(
        self,
        data: pd.DataFrame,
        target_col: str,
        columns: List[str],
        diagnosis_key: str = "diagnosis",
        **kwargs
    ) -> pd.DataFrame:
        """Train on *data* and return a one-row table describing the outcome.

        The *diagnosis_key* column holds either the formatted accuracy or a
        failure message. A ``model`` column is added on success.
        """
        try:
            outcome = self.model_trainer.invoke(data, target_col=target_col, columns=columns)
            if "accuracy" not in outcome:
                reason = outcome.get('error', 'Unknown error')
                return pd.DataFrame({diagnosis_key: [f"Diagnosis failed: {reason}"]})
            summary = f"Model Accuracy: {outcome['accuracy']:.2%}"
            return pd.DataFrame({
                diagnosis_key: [summary],
                "model": [outcome["model_type"]],
            })
        except Exception as e:
            logger.error(f"Error during diagnosis: {str(e)}")
            return pd.DataFrame({diagnosis_key: [f"Error during diagnosis: {e}"]})
|
|
|
|
class TreatmentRecommendation(ABC):
    """Interface for treatment-recommendation strategies.

    Implementations inspect *condition_col* and *treatment_col* of *data* and
    return a DataFrame. *recommendation_key* names the result column.
    """

    @abstractmethod
    def recommend(
        self,
        data: pd.DataFrame,
        condition_col: str,
        treatment_col: str,
        recommendation_key: str = "recommendation",
        **kwargs
    ) -> pd.DataFrame:
        """Produce a recommendation table for *data*."""
|
|
|
|
class BasicTreatmentRecommendation(TreatmentRecommendation):
    """Placeholder recommender: lists treatments recorded for high-risk rows."""

    def recommend(
        self,
        data: pd.DataFrame,
        condition_col: str,
        treatment_col: str,
        recommendation_key: str = "recommendation",
        condition_value: str = "High",
        **kwargs
    ) -> pd.DataFrame:
        """List distinct treatments given to rows whose condition equals *condition_value*.

        Args:
            data: Patient-level dataset.
            condition_col: Risk/condition column to filter on.
            treatment_col: Column holding the treatment given.
            recommendation_key: Name of the output column.
            condition_value: Condition value that selects rows (default "High").
            **kwargs: Ignored.

        Returns:
            One-row DataFrame holding the recommendation text or a message.
        """
        if condition_col not in data.columns or treatment_col not in data.columns:
            logger.warning(f"Condition or Treatment columns not found: {condition_col}, {treatment_col}")
            return pd.DataFrame({
                recommendation_key: ["Condition or Treatment columns not found!"]
            })

        # Each treatment is listed once (first-seen order) rather than once
        # per patient row. Missing treatment entries are skipped.
        treatments = (
            data.loc[data[condition_col] == condition_value, treatment_col]
            .dropna()
            .drop_duplicates()
            .tolist()
        )
        if not treatments:
            return pd.DataFrame({
                recommendation_key: ["No treatment recommendation found!"]
            })
        return pd.DataFrame({
            recommendation_key: [f"Treatment recommended for {condition_value} risk patients: {treatments}"]
        })
|
|
|
|
|
|
|
|
class MedicalKnowledgeBase(ABC):
    """Interface for medical-knowledge lookups that return display text."""

    @abstractmethod
    def search_medical_info(self, query: str, pub_email: str = "") -> str:
        """Answer *query*; *pub_email* is the contact address for PubMed access."""
|
|
|
|
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
    """Medical Q&A using OpenAI GPT-4, supplemented with a PubMed abstract."""

    def __init__(self, nlp_model):
        # spaCy pipeline used to extract entities for the PubMed search term.
        self.nlp = nlp_model

    def search_medical_info(self, query: str, pub_email: str = "") -> str:
        """Answer *query* with GPT-4 and append the top PubMed abstract.

        The full question is sent to the model. Named entities extracted by
        spaCy (or the lower-cased query if there are none) are added to the
        prompt as key terms and used as the PubMed search term.

        Returns:
            Markdown text with the answer and abstract, or an error message.
            OpenAI and other errors are caught and reported as text.
        """
        # NOTE(review): raw queries are written to app.log; confirm this is
        # acceptable if users may enter patient-identifying details.
        logger.info(f"Received medical query: {query}")
        try:
            # Run NER on the original text: spaCy's English NER uses
            # capitalisation cues, so lower-casing first loses entities.
            doc = self.nlp(query)
            entities = [ent.text for ent in doc.ents]
            processed_query = " ".join(entities) if entities else query.lower()
            logger.info(f"Processed query: {processed_query}")

            prompt = (
                "You are a medical assistant. Provide a comprehensive and accurate "
                "response to the following medical query:\n\n"
                f"Query: {query}\n"
                f"Key terms: {processed_query}\n\n"
                "Please ensure the information is clear, concise, and evidence-based."
            )

            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": "You are a helpful medical assistant."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=500,
                temperature=0.7,
            )
            answer = response.choices[0].message['content'].strip()
            logger.info("Successfully retrieved data from OpenAI GPT-4.")

            pubmed_abstract = self.fetch_pubmed_abstract(processed_query, pub_email)
            return f"**Based on your query:** {answer}\n\n**PubMed Abstract:**\n\n{pubmed_abstract}"

        except RateLimitError as e:
            logger.error(f"Rate Limit Exceeded: {str(e)}")
            return "Rate limit exceeded. Please try again later."
        except InvalidRequestError as e:
            logger.error(f"Invalid Request: {str(e)}")
            return f"Invalid request: {str(e)}"
        except APIError as e:
            logger.error(f"OpenAI API Error: {str(e)}")
            return f"OpenAI API Error: {str(e)}"
        except Exception as e:
            logger.error(f"Medical Knowledge Search Failed: {str(e)}")
            return f"Medical Knowledge Search Failed: {str(e)}"

    def fetch_pubmed_abstract(self, query: str, email: str) -> str:
        """Return the plain-text abstract of the most relevant PubMed hit.

        Args:
            query: PubMed search term.
            email: Contact email required by NCBI Entrez. The lookup is
                skipped when it is empty.

        Returns:
            The abstract text, or a human-readable message when nothing is
            found or the request fails.
        """
        if not email:
            logger.warning("PubMed abstract retrieval skipped: Email not provided.")
            return "No PubMed abstract available: Email not provided."
        try:
            Entrez.email = email

            # Close each Entrez handle even if parsing/reading fails.
            handle = Entrez.esearch(db="pubmed", term=query, retmax=1, sort='relevance')
            try:
                record = Entrez.read(handle)
            finally:
                handle.close()

            id_list = record["IdList"]
            logger.info(f"PubMed search for query '{query}' returned IDs: {id_list}")
            if not id_list:
                logger.info(f"No PubMed abstracts found for query '{query}'.")
                return "No abstracts found for this query on PubMed."

            handle = Entrez.efetch(db="pubmed", id=id_list[0], rettype="abstract", retmode="text")
            try:
                abstract = handle.read()
            finally:
                handle.close()
            logger.info(f"Fetched PubMed abstract for ID {id_list[0]}")
            return abstract
        except Exception as e:
            logger.error(f"Error searching PubMed: {e}")
            return f"Error searching PubMed: {e}"
|
|
|
|
|
|
|
|
class ForecastingEngine(ABC):
    """Interface for forecasting engines that return a forecast table."""

    @abstractmethod
    def predict(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Forecast from *data*; engine-specific options arrive as kwargs."""
|
|
|
|
class SimpleForecasting(ForecastingEngine):
    """Placeholder forecaster: returns a descriptive message, not numbers."""

    def predict(self, data: pd.DataFrame, period: int = 7, **kwargs) -> pd.DataFrame:
        """Return a one-row table naming the forecast horizon.

        *data* is currently unused; *period* is the horizon in days.
        """
        message = f"Forecast for the next {period} days"
        return pd.DataFrame(data={"forecast": [message]})
|
|
|
|
|
|
|
|
class AutomatedInsights:
    """Runs a chosen subset of registered analyzers over one dataset."""

    def __init__(self):
        registry = {
            "EDA": AdvancedEDA(),
            "temporal": TemporalAnalyzer(),
            "distribution": DistributionVisualizer(),
            "hypothesis": HypothesisTester(),
            "model": LogisticRegressionTrainer(),
        }
        self.analyses: Dict[str, DataAnalyzer] = registry

    def generate_insights(self, data: pd.DataFrame, analysis_names: List[str], **kwargs) -> Dict[str, Any]:
        """Run each named analysis on *data*, forwarding *kwargs* to every one.

        Returns:
            Mapping of analysis name to its result. Unknown names and
            analyzers that raise are reported as ``{"error": ...}``.
        """
        results = {}
        for name in analysis_names:
            analyzer = self.analyses.get(name)
            if not analyzer:
                logger.warning(f"Analysis '{name}' not found.")
                results[name] = {"error": "Analysis not found"}
                continue
            try:
                results[name] = analyzer.invoke(data=data, **kwargs)
            except Exception as e:
                logger.error(f"Error in analysis '{name}': {str(e)}")
                results[name] = {"error": str(e)}
        return results
|
|
|
|
class Dashboard:
    """Handles the creation and display of the dashboard."""

    def __init__(self):
        # Visualisation name -> type ("table" or "plot"). Names are also the
        # keys looked up in the data dict passed to display_dashboard.
        self.layout: Dict[str, str] = {}

    def add_visualisation(self, vis_name: str, vis_type: str):
        """Register (or replace) a visualisation slot."""
        self.layout[vis_name] = vis_type

    def display_dashboard(self, data_dict: Dict[str, pd.DataFrame]):
        """Render every registered visualisation from *data_dict* in Streamlit."""
        st.header("Dashboard")
        for vis_name, vis_type in self.layout.items():
            st.subheader(vis_name)
            df = data_dict.get(vis_name)
            if df is None:
                st.write("Data Not Found")
                continue
            if vis_type == "table":
                st.table(df)
            elif vis_type == "plot":
                if len(df.columns) > 1:
                    fig, ax = plt.subplots()
                    try:
                        sns.lineplot(data=df, ax=ax)
                        st.pyplot(fig)
                    finally:
                        # Streamlit reruns the script on every interaction,
                        # so unclosed figures would accumulate in memory.
                        plt.close(fig)
                else:
                    st.write("Please select a DataFrame with more than 1 column for plotting.")
|
|
|
|
class AutomatedReports:
    """Registry of named report definitions that snapshot DataFrames into dicts."""

    def __init__(self):
        self.report_definitions: Dict[str, str] = {}

    def create_report_definition(self, report_name: str, definition: str):
        """Register or overwrite the definition text for *report_name*."""
        self.report_definitions[report_name] = definition

    def generate_report(self, report_name: str, data: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
        """Build a report dict: name, definition, and each DataFrame via ``to_dict()``.

        Returns ``{"error": "Report name not found"}`` for unknown names.
        """
        try:
            definition = self.report_definitions[report_name]
        except KeyError:
            return {"error": "Report name not found"}

        serialized = {}
        for frame_name, frame in data.items():
            serialized[frame_name] = frame.to_dict()

        return {
            "Report Name": report_name,
            "Report Definition": definition,
            "Data": serialized,
        }
|
|
|
|
|
|
|
|
class DataSource(ABC):
    """Interface for data sources: connect once, then fetch DataFrames."""

    @abstractmethod
    def connect(self) -> None:
        """Open or load the underlying source."""

    @abstractmethod
    def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
        """Return the data selected by *query*."""
|
|
|
|
class CSVDataSource(DataSource):
    """Data source for CSV files (a path or a binary file-like object)."""

    def __init__(self, file_path: io.BytesIO):
        # Path or buffer (e.g. a Streamlit upload) passed to pandas.read_csv.
        self.file_path = file_path
        self.data: Optional[pd.DataFrame] = None

    def connect(self):
        """(Re)load the CSV into memory.

        File-like inputs are rewound first. After a read the stream position
        is at the end, so a second connect() would otherwise see an empty
        stream. DataIngestion.ingest_data calls connect() on every ingest.
        """
        if hasattr(self.file_path, "seek"):
            self.file_path.seek(0)
        self.data = pd.read_csv(self.file_path)

    def fetch_data(self, query: Optional[str] = None, **kwargs) -> pd.DataFrame:
        """Return the loaded DataFrame; *query* is ignored for CSV sources.

        Raises:
            RuntimeError: if connect() has not been called.
        """
        if self.data is None:
            raise RuntimeError("No connection is made, call connect()")
        return self.data
|
|
|
|
class DatabaseSource(DataSource):
    """Placeholder SQL source: records a connection string, returns a stub result."""

    def __init__(self, connection_string: str, database_type: str):
        # NOTE(review): connection_string is stored but not used yet.
        self.connection_string = connection_string
        self.database_type = database_type.lower()
        self.connection = None

    def connect(self):
        """Mark the source as connected; only the "sql" type is accepted."""
        if self.database_type != "sql":
            raise Exception(f"Database type '{self.database_type}' is not supported.")
        # Placeholder marker rather than a real DB connection object.
        self.connection = "Connected to SQL Database"

    def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
        """Return a one-row stub echoing *query*; no SQL is executed."""
        if self.connection is None:
            raise Exception("No connection is made, call connect()")
        stub_text = f"Fetched data based on query: {query}"
        return pd.DataFrame({"result": [stub_text]})
|
|
|
|
class DataIngestion:
    """Registry of named data sources plus a connect-then-fetch helper."""

    def __init__(self):
        self.sources: Dict[str, DataSource] = {}

    def add_source(self, source_name: str, source: DataSource):
        """Register or replace *source* under *source_name*."""
        self.sources[source_name] = source

    def ingest_data(self, source_name: str, query: str = None, **kwargs) -> pd.DataFrame:
        """Connect to the named source and return its data for *query*.

        Raises:
            Exception: if *source_name* is not registered.
        """
        try:
            source = self.sources[source_name]
        except KeyError:
            raise Exception(f"Source '{source_name}' not found.") from None
        source.connect()
        return source.fetch_data(query, **kwargs)
|
|
|
|
class DataModel(BaseModel):
    """Defines a data model: KPIs, dimensions, custom calculations and relations."""

    name: str
    kpis: List[str] = Field(default_factory=list)
    dimensions: List[str] = Field(default_factory=list)
    custom_calculations: Optional[Dict[str, str]] = None
    relations: Optional[Dict[str, str]] = None

    def to_json(self) -> str:
        """Serialise the model's fields to a JSON string."""
        # pydantic v2 renamed .dict() to .model_dump() and deprecated the old
        # name; prefer the new method when it exists, else fall back (v1).
        dump = getattr(self, "model_dump", None) or self.dict
        return json.dumps(dump())

    @classmethod
    def from_json(cls, json_str: str) -> 'DataModel':
        """Build an instance (of the calling class) from :meth:`to_json` output."""
        return cls(**json.loads(json_str))
|
|
|
|
class DataModelling:
    """In-memory registry of DataModel objects keyed by name."""

    def __init__(self):
        self.models: Dict[str, DataModel] = {}

    def add_model(self, model: DataModel):
        """Register or replace *model* under its ``name``."""
        self.models[model.name] = model

    def get_model(self, model_name: str) -> DataModel:
        """Return the registered model.

        Raises:
            Exception: if *model_name* is not registered.
        """
        try:
            return self.models[model_name]
        except KeyError:
            raise Exception(f"Model '{model_name}' not found.") from None
|
|
|
|
|
|
|
|
def main():
    """Entry point: page title, sidebar, and the workspace once data is ingested."""
    st.title("🏥 AI-Powered Clinical Intelligence Hub")
    initialize_session_state()

    with st.sidebar:
        data_management_section()

    # Nothing to show in the workspace until a dataset has been ingested.
    if not st.session_state.data:
        return

    metadata_col, workspace_col = st.columns([1, 3])
    with metadata_col:
        dataset_metadata_section()
    with workspace_col:
        main_tabs_section()
|
|
|
|
def initialize_session_state():
    """Create each shared component in st.session_state once per session.

    Factories are only called for keys that are not already present, so
    state survives Streamlit reruns.
    """
    component_factories = (
        ('data', dict),
        ('data_ingestion', DataIngestion),
        ('data_modelling', DataModelling),
        ('clinical_rules', ClinicalRulesEngine),
        ('kpi_monitoring', ClinicalKPIMonitoring),
        ('forecasting_engine', SimpleForecasting),
        ('automated_insights', AutomatedInsights),
        ('dashboard', Dashboard),
        ('automated_reports', AutomatedReports),
        ('diagnosis_support', SimpleDiagnosis),
        ('treatment_recommendation', BasicTreatmentRecommendation),
        ('knowledge_base', lambda: SimpleMedicalKnowledge(nlp_model=nlp)),
        ('pub_email', lambda: PUB_EMAIL),
    )
    for key, factory in component_factories:
        if key not in st.session_state:
            st.session_state[key] = factory()
|
|
|
|
def data_management_section():
    """Sidebar: register data sources and ingest one into session state."""
    st.header("⚙️ Data Management")
    data_source_selection = st.selectbox("Select Data Source Type", ["CSV", "SQL Database"])

    if data_source_selection == "CSV":
        handle_csv_upload()
    elif data_source_selection == "SQL Database":
        handle_sql_database()

    # The ingestion widgets must be rendered *before* the button. Streamlit
    # reruns the script on every interaction and st.button is True only on
    # the click run, so widgets created inside the button branch vanish as
    # soon as the user touches them, and the defaults are always used.
    sources = list(st.session_state.data_ingestion.sources.keys())
    source_name_to_fetch = None
    query = ""
    if sources:
        source_name_to_fetch = st.selectbox("Select Data Source to Ingest", sources)
        query = st.text_area("Optional Query to Fetch data")

    if st.button("Ingest Data"):
        ingest_data_action(source_name_to_fetch, query)


def handle_csv_upload():
    """Register an uploaded CSV file as a named data source."""
    uploaded_file = st.file_uploader("Upload research dataset (CSV)", type=["csv"])
    if uploaded_file:
        source_name = st.text_input("Data Source Name")
        if source_name:
            try:
                csv_source = CSVDataSource(file_path=uploaded_file)
                st.session_state.data_ingestion.add_source(source_name, csv_source)
                st.success(f"Uploaded {uploaded_file.name} as '{source_name}'.")
            except Exception as e:
                st.error(f"Error loading dataset: {e}")


def handle_sql_database():
    """Register a SQL connection string as a named data source."""
    conn_str = st.text_input("Enter connection string for SQL DB")
    if conn_str:
        source_name = st.text_input("Data Source Name")
        if source_name:
            try:
                sql_source = DatabaseSource(connection_string=conn_str, database_type="sql")
                st.session_state.data_ingestion.add_source(source_name, sql_source)
                st.success(f"Added SQL DB Source '{source_name}'.")
            except Exception as e:
                st.error(f"Error loading database source: {e}")


def ingest_data_action(source_name_to_fetch: Optional[str] = None, query: Optional[str] = None):
    """Ingest one registered source into ``st.session_state.data``.

    Args:
        source_name_to_fetch: Source to ingest. When omitted, the selection
            widgets are rendered here instead (for argument-less callers).
        query: Optional query forwarded to the source.
    """
    sources = st.session_state.data_ingestion.sources
    if not sources:
        st.error("No data source added. Please add a data source.")
        return

    if source_name_to_fetch is None:
        source_name_to_fetch = st.selectbox("Select Data Source to Ingest", list(sources.keys()))
        query = st.text_area("Optional Query to Fetch data")

    if not source_name_to_fetch:
        return

    with st.spinner("Ingesting data..."):
        try:
            data = st.session_state.data_ingestion.ingest_data(source_name_to_fetch, query)
            st.session_state.data[source_name_to_fetch] = data
            st.success(f"Ingested data from '{source_name_to_fetch}'.")
        except Exception as e:
            st.error(f"Ingestion failed: {e}")
|
|
|
|
def dataset_metadata_section():
    """Let the user pick an ingested dataset and show a short metadata summary."""
    st.subheader("Dataset Metadata")
    selected_data_key = st.selectbox("Select Dataset", list(st.session_state.data.keys()))

    if selected_data_key:
        frame = st.session_state.data[selected_data_key]
        datetime_cols = frame.select_dtypes(include='datetime').columns
        time_range = {
            col: {"min": frame[col].min(), "max": frame[col].max()}
            for col in datetime_cols
        }
        st.json({
            "Variables": list(frame.columns),
            "Time Range": time_range,
            "Size": f"{frame.memory_usage().sum() / 1e6:.2f} MB",
        })

    # Remember the choice so the analysis tabs operate on the same dataset.
    st.session_state.selected_data_key = selected_data_key
|
|
|
|
def main_tabs_section():
    """Render the five workspace tabs, each delegating to its section function."""
    tab_specs = [
        ("Data Analysis", data_analysis_section),
        ("Clinical Logic", clinical_logic_section),
        ("Insights", insights_section),
        ("Reports", reports_section),
        ("Medical Knowledge", medical_knowledge_section),
    ]
    tab_containers = st.tabs([label for label, _ in tab_specs])
    for container, (_, render_section) in zip(tab_containers, tab_specs):
        with container:
            render_section()
|
|
|
|
def data_analysis_section():
    """Data Analysis tab: choose an analysis mode and run it on the selected dataset."""
    selected_data_key = st.session_state.get('selected_data_key', None)
    if not selected_data_key:
        st.warning("Please select a dataset from the metadata section.")
        return

    data = st.session_state.data[selected_data_key]

    # Mode label -> handler; dict order defines the selectbox option order.
    handlers = {
        "Exploratory Data Analysis": perform_eda,
        "Temporal Pattern Analysis": perform_temporal_analysis,
        "Comparative Statistics": perform_comparative_statistics,
        "Distribution Analysis": perform_distribution_analysis,
        "Train Logistic Regression Model": perform_logistic_regression_training,
    }
    analysis_type = st.selectbox("Select Analysis Mode", list(handlers))

    handler = handlers.get(analysis_type)
    if handler is not None:
        handler(data)
|
|
|
|
def perform_eda(data: pd.DataFrame):
    """Run AdvancedEDA on *data* and show the report as JSON."""
    report = AdvancedEDA().invoke(data=data)
    st.subheader("Data Quality Report")
    st.json(report)
|
|
|
|
def perform_temporal_analysis(data: pd.DataFrame):
    """Temporal Pattern Analysis UI: pick time/value columns, show decomposition."""
    time_cols = data.select_dtypes(include='datetime').columns
    num_cols = data.select_dtypes(include=np.number).columns

    if len(time_cols) == 0:
        st.warning("No datetime columns available for temporal analysis.")
        return
    if len(num_cols) == 0:
        st.warning("No numerical columns available for temporal analysis.")
        return

    time_col = st.selectbox("Select Temporal Variable", time_cols)
    value_col = st.selectbox("Select Analysis Variable", num_cols)
    if not (time_col and value_col):
        return

    result = TemporalAnalyzer().invoke(data=data, time_col=time_col, value_col=value_col)
    plot_b64 = result.get("visualization")
    if plot_b64:
        st.image(f"data:image/png;base64,{plot_b64}", use_column_width=True)
    # The PNG is already displayed above. Leave its (very large) base64 text
    # out of the JSON view so the statistics stay readable.
    st.json({key: value for key, value in result.items() if key != "visualization"})
|
|
|
|
def perform_comparative_statistics(data: pd.DataFrame):
    """Compare a numeric metric across the levels of a categorical variable."""
    categorical_cols = data.select_dtypes(include=['category', 'object']).columns
    numeric_cols = data.select_dtypes(include=np.number).columns

    for candidates, kind in ((categorical_cols, "categorical"), (numeric_cols, "numerical")):
        if len(candidates) == 0:
            st.warning(f"No {kind} columns available for hypothesis testing.")
            return

    group_col = st.selectbox("Select Grouping Variable", categorical_cols)
    value_col = st.selectbox("Select Metric Variable", numeric_cols)
    if not (group_col and value_col):
        return

    outcome = HypothesisTester().invoke(data=data, group_col=group_col, value_col=value_col)
    st.subheader("Statistical Test Results")
    st.json(outcome)
|
|
|
|
def perform_distribution_analysis(data: pd.DataFrame):
    """Performs Distribution Analysis.

    Lets the user pick numeric columns and renders the base64 PNG produced by
    ``DistributionVisualizer``. A returned string starting with
    "Visualization Error" is shown as an error instead of an image.
    """
    numeric_options = data.select_dtypes(include=np.number).columns.tolist()
    chosen_columns = st.multiselect("Select Variables for Distribution Analysis", numeric_options)

    if not chosen_columns:
        st.info("Please select at least one numerical column to visualize.")
        return

    encoded_png = DistributionVisualizer().invoke(data=data, columns=chosen_columns)
    if encoded_png.startswith("Visualization Error"):
        st.error(encoded_png)
    else:
        st.image(f"data:image/png;base64,{encoded_png}", use_column_width=True)
|
|
|
|
def perform_logistic_regression_training(data: pd.DataFrame):
    """Trains a Logistic Regression model.

    The target may be any column; features are restricted to numeric columns
    other than the chosen target. The result of
    ``LogisticRegressionTrainer.invoke`` is shown as JSON.

    Args:
        data: The dataset to train on.
    """
    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
    target_col = st.selectbox("Select Target Variable", data.columns.tolist())

    # A numeric target must not be offered as one of its own features: selecting
    # it would leak the label into the inputs and yield a meaningless,
    # perfect-looking model.
    feature_options = [col for col in numeric_cols if col != target_col]
    selected_cols = st.multiselect("Select Feature Variables", feature_options)

    if not (selected_cols and target_col):
        st.warning("Please select both target and feature variables for model training.")
        return

    result = LogisticRegressionTrainer().invoke(data=data, target_col=target_col, columns=selected_cols)
    st.subheader("Logistic Regression Model Results")
    st.json(result)
|
|
|
|
def clinical_logic_section():
|
|
"""Handles the Clinical Logic tab."""
|
|
st.header("Clinical Logic")
|
|
|
|
|
|
st.subheader("Clinical Rules")
|
|
rule_name = st.text_input("Enter Rule Name")
|
|
condition = st.text_area("Enter Rule Condition (use 'df' for DataFrame)",
|
|
help="Example: df['blood_pressure'] > 140")
|
|
action = st.text_area("Enter Action to be Taken on Rule Match")
|
|
severity = st.selectbox("Enter Severity for the Rule", ["low", "medium", "high"])
|
|
|
|
if st.button("Add Clinical Rule"):
|
|
if rule_name and condition and action and severity:
|
|
try:
|
|
rule = ClinicalRule(
|
|
name=rule_name,
|
|
condition=condition,
|
|
action=action,
|
|
severity=severity
|
|
)
|
|
st.session_state.clinical_rules.add_rule(rule)
|
|
st.success("Added Clinical Rule successfully.")
|
|
except Exception as e:
|
|
st.error(f"Error in rule definition: {e}")
|
|
else:
|
|
st.error("Please fill in all fields to add a clinical rule.")
|
|
|
|
|
|
st.subheader("Clinical KPI Definition")
|
|
kpi_name = st.text_input("Enter KPI Name")
|
|
kpi_calculation = st.text_area("Enter KPI Calculation (use 'df' for DataFrame)",
|
|
help="Example: df['patient_count'].sum()")
|
|
threshold = st.text_input("Enter Threshold for KPI (Optional)", help="Leave blank if not applicable")
|
|
|
|
if st.button("Add Clinical KPI"):
|
|
if kpi_name and kpi_calculation:
|
|
try:
|
|
threshold_value = float(threshold) if threshold else None
|
|
kpi = ClinicalKPI(
|
|
name=kpi_name,
|
|
calculation=kpi_calculation,
|
|
threshold=threshold_value
|
|
)
|
|
st.session_state.kpi_monitoring.add_kpi(kpi)
|
|
st.success(f"Added KPI '{kpi_name}' successfully.")
|
|
except ValueError:
|
|
st.error("Threshold must be a numeric value.")
|
|
except Exception as e:
|
|
st.error(f"Error creating KPI: {e}")
|
|
else:
|
|
st.error("Please provide both KPI name and calculation.")
|
|
|
|
|
|
selected_data_key = st.selectbox("Select Dataset for Clinical Logic", list(st.session_state.data.keys()))
|
|
if selected_data_key:
|
|
data = st.session_state.data[selected_data_key]
|
|
if st.button("Execute Clinical Rules"):
|
|
with st.spinner("Executing Clinical Rules..."):
|
|
result = st.session_state.clinical_rules.execute_rules(data)
|
|
st.json(result)
|
|
if st.button("Calculate Clinical KPIs"):
|
|
with st.spinner("Calculating Clinical KPIs..."):
|
|
result = st.session_state.kpi_monitoring.calculate_kpis(data)
|
|
st.json(result)
|
|
else:
|
|
st.warning("Please ingest data to execute clinical rules and calculate KPIs.")
|
|
|
|
def insights_section():
    """Handles the Insights tab.

    For one dataset chosen from ``st.session_state.data``, renders three panels.
    Each panel does work only when its button is pressed and shows the returned
    result with ``st.json``:

    * Automated insights: runs the selected analyses via
      ``st.session_state.automated_insights.generate_insights``.
    * Diagnosis support: passes a target column and numeric feature columns
      (excluding the target) to ``st.session_state.diagnosis_support.diagnose``.
    * Treatment recommendation: passes two distinct columns (condition and
      treatment) to ``st.session_state.treatment_recommendation.recommend``.
    """
    st.header("Automated Insights")

    selected_data_key = st.selectbox("Select Dataset for Insights", list(st.session_state.data.keys()))
    if not selected_data_key:
        st.warning("Please select a dataset to generate insights.")
        return

    data = st.session_state.data[selected_data_key]
    all_cols = data.columns.tolist()

    # --- Automated insights -------------------------------------------------
    available_analyses = ["EDA", "temporal", "distribution", "hypothesis", "model"]
    selected_analyses = st.multiselect("Select Analyses for Insights", available_analyses)

    if st.button("Generate Automated Insights"):
        if selected_analyses:
            with st.spinner("Generating Insights..."):
                results = st.session_state.automated_insights.generate_insights(
                    data, analysis_names=selected_analyses
                )
            st.json(results)
        else:
            st.warning("Please select at least one analysis to generate insights.")

    # --- Diagnosis support --------------------------------------------------
    st.subheader("Diagnosis Support")
    target_col = st.selectbox("Select Target Variable for Diagnosis", all_cols)
    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
    # Never offer the target as a feature: using the label as its own
    # predictor leaks it into the model.
    feature_options = [col for col in numeric_cols if col != target_col]
    selected_feature_cols = st.multiselect("Select Feature Variables for Diagnosis", feature_options)

    if st.button("Generate Diagnosis"):
        if target_col and selected_feature_cols:
            with st.spinner("Generating Diagnosis..."):
                result = st.session_state.diagnosis_support.diagnose(
                    data, target_col=target_col, columns=selected_feature_cols, diagnosis_key="diagnosis_result"
                )
            st.json(result)
        else:
            st.error("Please select both target and feature variables for diagnosis.")

    # --- Treatment recommendation -------------------------------------------
    st.subheader("Treatment Recommendation")
    condition_col = st.selectbox("Select Condition Column for Treatment Recommendation", all_cols)
    treatment_col = st.selectbox("Select Treatment Column for Treatment Recommendation", all_cols)

    if st.button("Generate Treatment Recommendation"):
        if not (condition_col and treatment_col):
            st.error("Please select both condition and treatment columns.")
        elif condition_col == treatment_col:
            # Both select boxes share the same options; relating a column to
            # itself yields no meaningful recommendation.
            st.error("Condition and treatment columns must be different.")
        else:
            with st.spinner("Generating Treatment Recommendation..."):
                result = st.session_state.treatment_recommendation.recommend(
                    data, condition_col=condition_col, treatment_col=treatment_col, recommendation_key="treatment_recommendation"
                )
            st.json(result)
|
|
|
|
def reports_section():
    """Handles the Reports tab.

    Two parts:

    * a form that registers a named report definition with
      ``st.session_state.automated_reports``;
    * a picker that generates one of the registered reports against
      ``st.session_state.data`` and renders its name, definition and data
      tables. A report containing an ``"error"`` key is shown as an error.
    """

    def render_generated(report):
        # A report carrying an "error" key is shown as an error message only.
        if "error" in report:
            st.error(report["error"])
            return
        st.header(f"Report: {report['Report Name']}")
        st.markdown(f"**Definition:** {report['Report Definition']}")
        for table_name, table_rows in report["Data"].items():
            st.subheader(f"Data: {table_name}")
            st.dataframe(pd.DataFrame(table_rows))

    st.header("Automated Reports")

    # --- Define a report ----------------------------------------------------
    st.subheader("Create Report Definition")
    new_name = st.text_input("Report Name")
    new_definition = st.text_area("Report Definition", help="Describe the structure and content of the report.")

    if st.button("Create Report Definition"):
        if not (new_name and new_definition):
            st.error("Please provide both report name and definition.")
        else:
            st.session_state.automated_reports.create_report_definition(new_name, new_definition)
            st.success("Report definition created successfully.")

    # --- Generate a report --------------------------------------------------
    st.subheader("Generate Report")
    known_reports = list(st.session_state.automated_reports.report_definitions.keys())
    if not known_reports:
        st.info("No report definitions found. Please create a report definition first.")
        return

    chosen_report = st.selectbox("Select Report to Generate", known_reports)
    if st.button("Generate Report"):
        with st.spinner("Generating Report..."):
            generated = st.session_state.automated_reports.generate_report(chosen_report, st.session_state.data)
        render_generated(generated)
|
|
|
|
def medical_knowledge_section():
    """Handles the Medical Knowledge tab.

    On "Search", sends the typed question (with ``st.session_state.pub_email``)
    to ``st.session_state.knowledge_base.search_medical_info`` and renders the
    answer as Markdown. A blank or whitespace-only question shows an error.
    """
    st.header("Medical Knowledge")
    question = st.text_input("Enter your medical question here:")

    if not st.button("Search"):
        return

    if not question.strip():
        st.error("Please enter a medical question to search.")
        return

    with st.spinner("Searching..."):
        answer = st.session_state.knowledge_base.search_medical_info(
            question, pub_email=st.session_state.pub_email
        )
    st.markdown(answer)
|
|
|
|
# Script entry point: hand control to `main`, which is defined elsewhere in this
# module (presumably the function that lays out the app's tabs - confirm).
if __name__ == "__main__":
    main()
|
|
|