mgbam commited on
Commit
ae9ec5a
·
1 Parent(s): a273fc4

Add application file

Browse files
Files changed (1) hide show
  1. app.py +712 -0
app.py ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from typing import Dict, List, Optional, Any
4
+ from pydantic import BaseModel, Field
5
+ import base64
6
+ import io
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ from abc import ABC, abstractmethod
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn.linear_model import LogisticRegression
12
+ from sklearn.metrics import accuracy_score
13
+ from statsmodels.tsa.seasonal import seasonal_decompose
14
+ from statsmodels.tsa.stattools import adfuller
15
+ from langchain.prompts import PromptTemplate
16
+ from groq import Groq
17
+ import os
18
+ import numpy as np
19
+ from scipy.stats import ttest_ind, f_oneway
20
+ import json
21
+
22
+ # Initialize Groq Client
23
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
24
+
25
+ # ---------------------- Base Classes and Schemas ---------------------------
26
+ class ResearchInput(BaseModel):
27
+ """Base schema for research tool inputs"""
28
+ data_key: str = Field(..., description="Session state key containing DataFrame")
29
+ columns: Optional[List[str]] = Field(None, description="List of columns to analyze")
30
+
31
+ class TemporalAnalysisInput(ResearchInput):
32
+ """Schema for temporal analysis"""
33
+ time_col: str = Field(..., description="Name of timestamp column")
34
+ value_col: str = Field(..., description="Name of value column to analyze")
35
+
36
+ class HypothesisInput(ResearchInput):
37
+ """Schema for hypothesis testing"""
38
+ group_col: str = Field(..., description="Categorical column defining groups")
39
+ value_col: str = Field(..., description="Numerical column to compare")
40
+
41
+ class ModelTrainingInput(ResearchInput):
42
+ """Schema for model training"""
43
+ target_col: str = Field(..., description="Name of target column")
44
+
45
+ class DataAnalyzer(ABC):
46
+ """Abstract base class for data analysis modules"""
47
+ @abstractmethod
48
+ def invoke(self, **kwargs) -> Dict[str, Any]:
49
+ pass
50
+
51
+ # ---------------------- Concrete Analyzer Implementations ---------------------------
52
+ class AdvancedEDA(DataAnalyzer):
53
+ """Comprehensive Exploratory Data Analysis"""
54
+ def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
55
+ try:
56
+ analysis = {
57
+ "dimensionality": {
58
+ "rows": len(data),
59
+ "columns": list(data.columns),
60
+ "memory_usage": f"{data.memory_usage().sum() / 1e6:.2f} MB"
61
+ },
62
+ "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
63
+ "temporal_analysis": {
64
+ "date_ranges": {
65
+ col: {
66
+ "min": data[col].min(),
67
+ "max": data[col].max()
68
+ } for col in data.select_dtypes(include='datetime').columns
69
+ }
70
+ },
71
+ "data_quality": {
72
+ "missing_values": data.isnull().sum().to_dict(),
73
+ "duplicates": data.duplicated().sum(),
74
+ "cardinality": {
75
+ col: data[col].nunique() for col in data.columns
76
+ }
77
+ }
78
+ }
79
+ return analysis
80
+ except Exception as e:
81
+ return {"error": f"EDA Failed: {str(e)}"}
82
+
83
+ class DistributionVisualizer(DataAnalyzer):
84
+ """Distribution visualizations"""
85
+ def invoke(self, data: pd.DataFrame, columns: List[str], **kwargs) -> str:
86
+ try:
87
+ plt.figure(figsize=(12, 6))
88
+ for i, col in enumerate(columns, 1):
89
+ plt.subplot(1, len(columns), i)
90
+ sns.histplot(data[col], kde=True, stat="density")
91
+ plt.title(f'Distribution of {col}', fontsize=10)
92
+ plt.xticks(fontsize=8)
93
+ plt.yticks(fontsize=8)
94
+ plt.tight_layout()
95
+
96
+ buf = io.BytesIO()
97
+ plt.savefig(buf, format='png', dpi=300, bbox_inches='tight')
98
+ plt.close()
99
+ return base64.b64encode(buf.getvalue()).decode()
100
+ except Exception as e:
101
+ return f"Visualization Error: {str(e)}"
102
+
103
+ class TemporalAnalyzer(DataAnalyzer):
104
+ """Time series analysis"""
105
+ def invoke(self, data: pd.DataFrame, time_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
106
+ try:
107
+ ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col]
108
+ decomposition = seasonal_decompose(ts_data, period=365)
109
+
110
+ plt.figure(figsize=(12, 8))
111
+ decomposition.plot()
112
+ plt.tight_layout()
113
+
114
+ buf = io.BytesIO()
115
+ plt.savefig(buf, format='png')
116
+ plt.close()
117
+ plot_data = base64.b64encode(buf.getvalue()).decode()
118
+
119
+ return {
120
+ "trend_statistics": {
121
+ "stationarity": adfuller(ts_data)[1],
122
+ "seasonality_strength": max(decomposition.seasonal)
123
+ },
124
+ "visualization": plot_data
125
+ }
126
+ except Exception as e:
127
+ return {"error": f"Temporal Analysis Failed: {str(e)}"}
128
+
129
+ class HypothesisTester(DataAnalyzer):
130
+ """Statistical hypothesis testing"""
131
+ def invoke(self, data: pd.DataFrame, group_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
132
+ try:
133
+ groups = data[group_col].unique()
134
+
135
+ if len(groups) < 2:
136
+ return {"error": "Insufficient groups for comparison"}
137
+
138
+ if len(groups) == 2:
139
+ group_data = [data[data[group_col] == g][value_col] for g in groups]
140
+ stat, p = ttest_ind(*group_data)
141
+ test_type = "Independent t-test"
142
+ else:
143
+ group_data = [data[data[group_col] == g][value_col] for g in groups]
144
+ stat, p = f_oneway(*group_data)
145
+ test_type = "ANOVA"
146
+
147
+ return {
148
+ "test_type": test_type,
149
+ "test_statistic": stat,
150
+ "p_value": p,
151
+ "effect_size": {
152
+ "cohens_d": abs(group_data[0].mean() - group_data[1].mean())/np.sqrt(
153
+ (group_data[0].var() + group_data[1].var())/2
154
+ ) if len(groups) == 2 else None
155
+ },
156
+ "interpretation": self.interpret_p_value(p)
157
+ }
158
+ except Exception as e:
159
+ return {"error": f"Hypothesis Testing Failed: {str(e)}"}
160
+
161
+ def interpret_p_value(self, p: float) -> str:
162
+ if p < 0.001: return "Very strong evidence against H0"
163
+ elif p < 0.01: return "Strong evidence against H0"
164
+ elif p < 0.05: return "Evidence against H0"
165
+ elif p < 0.1: return "Weak evidence against H0"
166
+ else: return "No significant evidence against H0"
167
+
168
+ class LogisticRegressionTrainer(DataAnalyzer):
169
+ """Logistic Regression Model Trainer"""
170
+ def invoke(self, data: pd.DataFrame, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
171
+ try:
172
+ X = data[columns]
173
+ y = data[target_col]
174
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
175
+ model = LogisticRegression(max_iter=1000)
176
+ model.fit(X_train, y_train)
177
+ y_pred = model.predict(X_test)
178
+ accuracy = accuracy_score(y_test, y_pred)
179
+ return {
180
+ "model_type": "Logistic Regression",
181
+ "accuracy": accuracy,
182
+ "model_params": model.get_params()
183
+ }
184
+ except Exception as e:
185
+ return {"error": f"Logistic Regression Model Error: {str(e)}"}
186
+ # ---------------------- Business Logic Layer ---------------------------
187
+
188
+ class ClinicalRule(BaseModel):
189
+ """Defines a clinical rule"""
190
+ name: str
191
+ condition: str
192
+ action: str
193
+ severity: str # low, medium or high
194
+
195
+ class ClinicalRulesEngine():
196
+ """Executes rules against patient data."""
197
+ def __init__(self):
198
+ self.rules: Dict[str, ClinicalRule] = {}
199
+
200
+ def add_rule(self, rule: ClinicalRule):
201
+ self.rules[rule.name] = rule
202
+
203
+ def execute_rules(self, data: pd.DataFrame):
204
+ results = {}
205
+ for rule_name, rule in self.rules.items():
206
+ try:
207
+ if eval(rule.condition, {}, {"df":data}):
208
+ results[rule_name] = {"rule_matched": True,
209
+ "action": rule.action,
210
+ "severity": rule.severity
211
+ }
212
+ else:
213
+ results[rule_name] = {"rule_matched": False, "action": None, "severity": None}
214
+ except Exception as e:
215
+ results[rule_name] = {"rule_matched": False, "error": str(e), "severity": None}
216
+ return results
217
+
218
+ class ClinicalKPI(BaseModel):
219
+ """Define a clinical KPI"""
220
+ name: str
221
+ calculation: str
222
+ threshold: Optional[float] = None
223
+
224
+ class ClinicalKPIMonitoring():
225
+ """Calculates KPIs based on data"""
226
+ def __init__(self):
227
+ self.kpis : Dict[str, ClinicalKPI] = {}
228
+
229
+ def add_kpi(self, kpi:ClinicalKPI):
230
+ self.kpis[kpi.name] = kpi
231
+
232
+ def calculate_kpis(self, data: pd.DataFrame):
233
+ results = {}
234
+ for kpi_name, kpi in self.kpis.items():
235
+ try:
236
+ results[kpi_name] = eval(kpi.calculation, {}, {"df": data})
237
+ except Exception as e:
238
+ results[kpi_name] = {"error": str(e)}
239
+ return results
240
+
241
+ class DiagnosisSupport(ABC):
242
+ """Abstract class for implementing clinical diagnoses."""
243
+ @abstractmethod
244
+ def diagnose(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
245
+ pass
246
+
247
+ class SimpleDiagnosis(DiagnosisSupport):
248
+ """Provides a simple diagnosis example, based on the Logistic regression model"""
249
+ def __init__(self):
250
+ self.model : LogisticRegressionTrainer = LogisticRegressionTrainer()
251
+
252
+ def diagnose(self, data: pd.DataFrame, target_col: str, columns: List[str], **kwargs) -> pd.DataFrame:
253
+ try:
254
+ result = self.model.invoke(data, target_col=target_col, columns = columns)
255
+ if "accuracy" in result:
256
+ return pd.DataFrame({"diagnosis": [f"Accuracy {result['accuracy']}"],
257
+ "model": result["model_type"]})
258
+ else:
259
+ return pd.DataFrame({"diagnosis": [f"Diagnosis failed: {result}"]})
260
+
261
+ except Exception as e:
262
+ return pd.DataFrame({"diagnosis":[f"Error during diagnosis {e}"]})
263
+
264
+
265
+ class TreatmentRecommendation(ABC):
266
+ """Abstract class for treatment recommendations"""
267
+ @abstractmethod
268
+ def recommend(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
269
+ pass
270
+
271
+ class BasicTreatmentRecommendation(TreatmentRecommendation):
272
+ """A placeholder class for basic treatment recommendations"""
273
+ def recommend(self, data: pd.DataFrame, condition_col: str, treatment_col:str, **kwargs) -> pd.DataFrame:
274
+ if condition_col not in data.columns or treatment_col not in data.columns:
275
+ return pd.DataFrame({"recommendation": ["Condition or Treatment columns not found!"]})
276
+ treatment = data[data[condition_col] == "High"][treatment_col].to_list()
277
+ if len(treatment)>0:
278
+ return pd.DataFrame({"recommendation": [f"Treatment recommended for High risk patients: {treatment}"]})
279
+ else:
280
+ return pd.DataFrame({"recommendation": [f"No treatment recommendation found!"]})
281
+
282
+
283
+ class MedicalKnowledgeBase():
284
+ """Abstract class for Medical Knowledge"""
285
+ @abstractmethod
286
+ def search_medical_info(self, query: str) -> str:
287
+ pass
288
+
289
+ class SimpleMedicalKnowledge(MedicalKnowledgeBase):
290
+ """Simple Medical Knowledge Class"""
291
+ def search_medical_info(self, query: str) -> str:
292
+ if "diabetes treatment" in query.lower():
293
+ return "The recommended treatment for diabetes includes lifestyle changes, medication, and monitoring"
294
+ elif "heart disease risk factors" in query.lower():
295
+ return "Risk factors for heart disease include high blood pressure, high cholesterol, and smoking"
296
+ else:
297
+ return "No specific information is available"
298
+
299
+
300
+ class ForecastingEngine(ABC):
301
+ @abstractmethod
302
+ def predict(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
303
+ pass
304
+
305
+ class SimpleForecasting(ForecastingEngine):
306
+ def predict(self, data: pd.DataFrame, period: int = 7, **kwargs) -> pd.DataFrame:
307
+ #Placeholder for actual forecasting
308
+ return pd.DataFrame({"forecast":[f"Forecast for the next {period} days"]})
309
+
310
+ # ---------------------- Insights and Reporting Layer ---------------------------
311
+ class AutomatedInsights():
312
+ def __init__(self):
313
+ self.analyses : Dict[str, DataAnalyzer] = {
314
+ "EDA": AdvancedEDA(),
315
+ "temporal": TemporalAnalyzer(),
316
+ "distribution": DistributionVisualizer(),
317
+ "hypothesis": HypothesisTester(),
318
+ "model": LogisticRegressionTrainer()
319
+ }
320
+
321
+ def generate_insights(self, data: pd.DataFrame, analysis_names: List[str], **kwargs):
322
+ results = {}
323
+ for name in analysis_names:
324
+ if name in self.analyses:
325
+ analyzer = self.analyses[name]
326
+ results[name] = analyzer.invoke(data=data, **kwargs)
327
+ else:
328
+ results[name] = {"error": "Analysis not found"}
329
+ return results
330
+
331
+ class Dashboard():
332
+ def __init__(self):
333
+ self.layout: Dict[str,str] = {}
334
+
335
+ def add_visualisation(self, vis_name: str, vis_type: str):
336
+ self.layout[vis_name] = vis_type
337
+
338
+ def display_dashboard(self, data_dict: Dict[str,pd.DataFrame]):
339
+ st.header("Dashboard")
340
+ for vis_name, vis_type in self.layout.items():
341
+ st.subheader(vis_name)
342
+ if vis_type == "table":
343
+ if vis_name in data_dict:
344
+ st.table(data_dict[vis_name])
345
+ else:
346
+ st.write("Data Not Found")
347
+ elif vis_type == "plot":
348
+ if vis_name in data_dict:
349
+ df = data_dict[vis_name]
350
+ if len(df.columns) > 1:
351
+ fig = plt.figure()
352
+ sns.lineplot(data=df)
353
+ st.pyplot(fig)
354
+ else:
355
+ st.write("Please have more than 1 column")
356
+ else:
357
+ st.write("Data not found")
358
+ class AutomatedReports():
359
+ def __init__(self):
360
+ self.report_definition: Dict[str,str] = {}
361
+
362
+ def create_report_definition(self, report_name: str, definition: str):
363
+ self.report_definition[report_name] = definition
364
+
365
+ def generate_report(self, report_name: str, data:Dict[str, pd.DataFrame]):
366
+ if report_name not in self.report_definition:
367
+ return {"error":"Report name not found"}
368
+ st.header(f"Report : {report_name}")
369
+ st.write(f"Report Definition: {self.report_definition[report_name]}")
370
+ for df_name, df in data.items():
371
+ st.subheader(f"Data: {df_name}")
372
+ st.table(df)
373
+
374
+ # ---------------------- Data Acquisition Layer ---------------------------
375
+ class DataSource(ABC):
376
+ """Base class for data sources."""
377
+ @abstractmethod
378
+ def connect(self) -> None:
379
+ """Connect to the data source."""
380
+ pass
381
+
382
+ @abstractmethod
383
+ def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
384
+ """Fetch the data based on a specific query."""
385
+ pass
386
+
387
+
388
+ class CSVDataSource(DataSource):
389
+ """Data source for CSV files."""
390
+ def __init__(self, file_path: str):
391
+ self.file_path = file_path
392
+ self.data: Optional[pd.DataFrame] = None
393
+
394
+ def connect(self):
395
+ self.data = pd.read_csv(self.file_path)
396
+
397
+ def fetch_data(self, query: str = None, **kwargs) -> pd.DataFrame:
398
+ if self.data is None:
399
+ raise Exception("No connection is made, call connect()")
400
+ return self.data
401
+
402
+ class DatabaseSource(DataSource):
403
+ def __init__(self, connection_string: str, database_type: str):
404
+ self.connection_string = connection_string
405
+ self.database_type = database_type
406
+ self.connection = None
407
+
408
+ def connect(self):
409
+ if self.database_type.lower() == "sql":
410
+ #Placeholder for the actual database connection
411
+ self.connection = "Connected to SQL Database"
412
+ else:
413
+ raise Exception(f"Database type '{self.database_type}' is not supported")
414
+
415
+ def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
416
+ if self.connection is None:
417
+ raise Exception("No connection is made, call connect()")
418
+ #Placeholder for the data fetching
419
+ return pd.DataFrame({"result":[f"Fetched data based on query: {query}"]})
420
+
421
+
422
+ class DataIngestion:
423
+ def __init__(self):
424
+ self.sources : Dict[str, DataSource] = {}
425
+
426
+ def add_source(self, source_name: str, source: DataSource):
427
+ self.sources[source_name] = source
428
+
429
+ def ingest_data(self, source_name: str, query: str = None, **kwargs) -> pd.DataFrame:
430
+ if source_name not in self.sources:
431
+ raise Exception(f"Source '{source_name}' not found")
432
+ source = self.sources[source_name]
433
+ source.connect()
434
+ return source.fetch_data(query, **kwargs)
435
+
436
+ class DataModel(BaseModel):
437
+ name : str
438
+ kpis : List[str] = Field(default_factory=list)
439
+ dimensions : List[str] = Field(default_factory=list)
440
+ custom_calculations : Optional[Dict[str, str]] = None
441
+ relations: Optional[Dict[str,str]] = None #Example {table1: table2}
442
+
443
+ def to_json(self):
444
+ return json.dumps(self.dict())
445
+
446
+ @staticmethod
447
+ def from_json(json_str):
448
+ return DataModel(**json.loads(json_str))
449
+
450
+ class DataModelling():
451
+ def __init__(self):
452
+ self.models : Dict[str, DataModel] = {}
453
+
454
+ def add_model(self, model:DataModel):
455
+ self.models[model.name] = model
456
+
457
+ def get_model(self, model_name: str) -> DataModel:
458
+ if model_name not in self.models:
459
+ raise Exception(f"Model '{model_name}' not found")
460
+ return self.models[model_name]
461
+ # ---------------------- Main Streamlit Application ---------------------------
462
+ def main():
463
+ st.set_page_config(page_title="AI Clinical Intelligence Hub", layout="wide")
464
+ st.title("🏥 AI-Powered Clinical Intelligence Hub")
465
+
466
+ # Session State
467
+ if 'data' not in st.session_state:
468
+ st.session_state.data = {} # store pd.DataFrame under a name
469
+ if 'data_ingestion' not in st.session_state:
470
+ st.session_state.data_ingestion = DataIngestion()
471
+ if 'data_modelling' not in st.session_state:
472
+ st.session_state.data_modelling = DataModelling()
473
+ if 'clinical_rules' not in st.session_state:
474
+ st.session_state.clinical_rules = ClinicalRulesEngine()
475
+ if 'kpi_monitoring' not in st.session_state:
476
+ st.session_state.kpi_monitoring = ClinicalKPIMonitoring()
477
+ if 'forecasting_engine' not in st.session_state:
478
+ st.session_state.forecasting_engine = SimpleForecasting()
479
+ if 'automated_insights' not in st.session_state:
480
+ st.session_state.automated_insights = AutomatedInsights()
481
+ if 'dashboard' not in st.session_state:
482
+ st.session_state.dashboard = Dashboard()
483
+ if 'automated_reports' not in st.session_state:
484
+ st.session_state.automated_reports = AutomatedReports()
485
+ if 'diagnosis_support' not in st.session_state:
486
+ st.session_state.diagnosis_support = SimpleDiagnosis()
487
+ if 'treatment_recommendation' not in st.session_state:
488
+ st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
489
+ if 'knowledge_base' not in st.session_state:
490
+ st.session_state.knowledge_base = SimpleMedicalKnowledge()
491
+
492
+
493
+ # Sidebar for Data Management
494
+ with st.sidebar:
495
+ st.header("⚙️ Data Management")
496
+ data_source_selection = st.selectbox("Select Data Source Type",["CSV","SQL Database"])
497
+ if data_source_selection == "CSV":
498
+ uploaded_file = st.file_uploader("Upload research dataset (CSV)", type=["csv"])
499
+ if uploaded_file:
500
+ source_name = st.text_input("Data Source Name")
501
+ if source_name:
502
+ try:
503
+ csv_source = CSVDataSource(file_path=uploaded_file)
504
+ st.session_state.data_ingestion.add_source(source_name,csv_source)
505
+ st.success(f"Uploaded {uploaded_file.name}")
506
+ except Exception as e:
507
+ st.error(f"Error loading dataset: {e}")
508
+ elif data_source_selection == "SQL Database":
509
+ conn_str = st.text_input("Enter connection string for SQL DB")
510
+ if conn_str:
511
+ source_name = st.text_input("Data Source Name")
512
+ if source_name:
513
+ try:
514
+ sql_source = DatabaseSource(connection_string=conn_str, database_type="sql")
515
+ st.session_state.data_ingestion.add_source(source_name, sql_source)
516
+ st.success(f"Added SQL DB Source {source_name}")
517
+ except Exception as e:
518
+ st.error(f"Error loading database source {e}")
519
+
520
+
521
+ if st.button("Ingest Data"):
522
+ if st.session_state.data_ingestion.sources:
523
+ source_name_to_fetch = st.selectbox("Select Data Source to Ingest", list(st.session_state.data_ingestion.sources.keys()))
524
+ query = st.text_area("Optional Query to Fetch data")
525
+ if source_name_to_fetch:
526
+ with st.spinner("Ingesting data..."):
527
+ try:
528
+ data = st.session_state.data_ingestion.ingest_data(source_name_to_fetch, query)
529
+ st.session_state.data[source_name_to_fetch] = data
530
+ st.success(f"Ingested data from {source_name_to_fetch}")
531
+ except Exception as e:
532
+ st.error(f"Ingestion failed: {e}")
533
+ else:
534
+ st.error("No data source added, please add data source")
535
+
536
+ if st.session_state.data:
537
+ col1, col2 = st.columns([1, 3])
538
+
539
+ with col1:
540
+ st.subheader("Dataset Metadata")
541
+
542
+ data_source_keys = list(st.session_state.data.keys())
543
+ selected_data_key = st.selectbox("Select Dataset", data_source_keys)
544
+
545
+ if selected_data_key:
546
+ data = st.session_state.data[selected_data_key]
547
+ st.json({
548
+ "Variables": list(data.columns),
549
+ "Time Range": {
550
+ col: {
551
+ "min": data[col].min(),
552
+ "max": data[col].max()
553
+ } for col in data.select_dtypes(include='datetime').columns
554
+ },
555
+ "Size": f"{data.memory_usage().sum() / 1e6:.2f} MB"
556
+ })
557
+ with col2:
558
+ analysis_tab, clinical_logic_tab, insights_tab, reports_tab, knowledge_tab = st.tabs([
559
+ "Data Analysis",
560
+ "Clinical Logic",
561
+ "Insights",
562
+ "Reports",
563
+ "Medical Knowledge"
564
+ ])
565
+
566
+ with analysis_tab:
567
+ if selected_data_key:
568
+ analysis_type = st.selectbox("Select Analysis Mode", [
569
+ "Exploratory Data Analysis",
570
+ "Temporal Pattern Analysis",
571
+ "Comparative Statistics",
572
+ "Distribution Analysis",
573
+ "Train Logistic Regression Model"
574
+ ])
575
+ data = st.session_state.data[selected_data_key]
576
+ if analysis_type == "Exploratory Data Analysis":
577
+ analyzer = AdvancedEDA()
578
+ eda_result = analyzer.invoke(data=data)
579
+ st.subheader("Data Quality Report")
580
+ st.json(eda_result)
581
+
582
+ elif analysis_type == "Temporal Pattern Analysis":
583
+ time_col = st.selectbox("Temporal Variable",
584
+ data.select_dtypes(include='datetime').columns)
585
+ value_col = st.selectbox("Analysis Variable",
586
+ data.select_dtypes(include=np.number).columns)
587
+
588
+ if time_col and value_col:
589
+ analyzer = TemporalAnalyzer()
590
+ result = analyzer.invoke(data=data, time_col=time_col, value_col=value_col)
591
+ if "visualization" in result:
592
+ st.image(f"data:image/png;base64,{result['visualization']}")
593
+ st.json(result)
594
+
595
+ elif analysis_type == "Comparative Statistics":
596
+ group_col = st.selectbox("Grouping Variable",
597
+ data.select_dtypes(include='category').columns)
598
+ value_col = st.selectbox("Metric Variable",
599
+ data.select_dtypes(include=np.number).columns)
600
+
601
+ if group_col and value_col:
602
+ analyzer = HypothesisTester()
603
+ result = analyzer.invoke(data=data, group_col=group_col, value_col=value_col)
604
+ st.subheader("Statistical Test Results")
605
+ st.json(result)
606
+
607
+ elif analysis_type == "Distribution Analysis":
608
+ num_cols = data.select_dtypes(include=np.number).columns.tolist()
609
+ selected_cols = st.multiselect("Select Variables", num_cols)
610
+ if selected_cols:
611
+ analyzer = DistributionVisualizer()
612
+ img_data = analyzer.invoke(data=data, columns=selected_cols)
613
+ st.image(f"data:image/png;base64,{img_data}")
614
+
615
+ elif analysis_type == "Train Logistic Regression Model":
616
+ num_cols = data.select_dtypes(include=np.number).columns.tolist()
617
+ target_col = st.selectbox("Select Target Variable",
618
+ data.columns.tolist())
619
+ selected_cols = st.multiselect("Select Feature Variables", num_cols)
620
+ if selected_cols and target_col:
621
+ analyzer = LogisticRegressionTrainer()
622
+ result = analyzer.invoke(data=data, target_col=target_col, columns=selected_cols)
623
+ st.subheader("Logistic Regression Model Results")
624
+ st.json(result)
625
+ with clinical_logic_tab:
626
+ st.header("Clinical Logic")
627
+ st.subheader("Clinical Rules")
628
+ rule_name = st.text_input("Enter Rule Name")
629
+ condition = st.text_area("Enter Rule Condition (use 'df' for data frame), Example df['blood_pressure'] > 140")
630
+ action = st.text_area("Enter Action to be Taken on Rule Match")
631
+ severity = st.selectbox("Enter Severity for the Rule", ["low","medium","high"])
632
+ if st.button("Add Clinical Rule"):
633
+ try:
634
+ rule = ClinicalRule(name=rule_name, condition=condition, action=action, severity=severity)
635
+ st.session_state.clinical_rules.add_rule(rule)
636
+ st.success("Added Clinical Rule")
637
+ except Exception as e:
638
+ st.error(f"Error in rule definition: {e}")
639
+
640
+ st.subheader("Clinical KPI Definition")
641
+ kpi_name = st.text_input("Enter KPI name")
642
+ kpi_calculation = st.text_area("Enter KPI calculation (use 'df' for data frame), Example df['patient_count'].sum()")
643
+ threshold = st.text_input("Enter Threshold for KPI")
644
+ if st.button("Add Clinical KPI"):
645
+ try:
646
+ threshold_value = float(threshold) if threshold else None
647
+ kpi = ClinicalKPI(name=kpi_name, calculation=kpi_calculation, threshold=threshold_value)
648
+ st.session_state.kpi_monitoring.add_kpi(kpi)
649
+ st.success(f"Added KPI {kpi_name}")
650
+ except Exception as e:
651
+ st.error(f"Error creating KPI: {e}")
652
+
653
+ if selected_data_key:
654
+ data = st.session_state.data[selected_data_key]
655
+ if st.button("Execute Clinical Rules"):
656
+ with st.spinner("Executing Clinical Rules.."):
657
+ result = st.session_state.clinical_rules.execute_rules(data)
658
+ st.json(result)
659
+ if st.button("Calculate Clinical KPIs"):
660
+ with st.spinner("Calculating Clinical KPIs..."):
661
+ result = st.session_state.kpi_monitoring.calculate_kpis(data)
662
+ st.json(result)
663
+ with insights_tab:
664
+ if selected_data_key:
665
+ data = st.session_state.data[selected_data_key]
666
+ available_analysis = ["EDA", "temporal", "distribution", "hypothesis", "model"]
667
+ selected_analysis = st.multiselect("Select Analysis", available_analysis)
668
+ if st.button("Generate Automated Insights"):
669
+ with st.spinner("Generating Insights"):
670
+ results = st.session_state.automated_insights.generate_insights(data, analysis_names=selected_analysis)
671
+ st.json(results)
672
+ st.subheader("Diagnosis Support")
673
+ target_col = st.selectbox("Select Target Variable for Diagnosis", data.columns.tolist())
674
+ num_cols = data.select_dtypes(include=np.number).columns.tolist()
675
+ selected_cols_diagnosis = st.multiselect("Select Feature Variables for Diagnosis", num_cols)
676
+ if st.button("Generate Diagnosis"):
677
+ if target_col and selected_cols_diagnosis:
678
+ with st.spinner("Generating Diagnosis"):
679
+ result = st.session_state.diagnosis_support.diagnose(data, target_col=target_col, columns=selected_cols_diagnosis)
680
+ st.json(result)
681
+
682
+ st.subheader("Treatment Recommendation")
683
+ condition_col = st.selectbox("Select Condition Column for Treatment Recommendation", data.columns.tolist())
684
+ treatment_col = st.selectbox("Select Treatment Column for Treatment Recommendation", data.columns.tolist())
685
+ if st.button("Generate Treatment Recommendation"):
686
+ if condition_col and treatment_col:
687
+ with st.spinner("Generating Treatment Recommendation"):
688
+ result = st.session_state.treatment_recommendation.recommend(data, condition_col = condition_col, treatment_col = treatment_col)
689
+ st.json(result)
690
+
691
+ with reports_tab:
692
+ st.header("Reports")
693
+ report_name = st.text_input("Report Name")
694
+ report_def = st.text_area("Report definition")
695
+ if st.button("Create Report Definition"):
696
+ st.session_state.automated_reports.create_report_definition(report_name, report_def)
697
+ st.success("Report definition created")
698
+ if selected_data_key:
699
+ data = st.session_state.data
700
+ if st.button("Generate Report"):
701
+ with st.spinner("Generating Report..."):
702
+ report = st.session_state.automated_reports.generate_report(report_name, data)
703
+ with knowledge_tab:
704
+ st.header("Medical Knowledge")
705
+ query = st.text_input("Enter your medical question here:")
706
+ if st.button("Search"):
707
+ with st.spinner("Searching..."):
708
+ result = st.session_state.knowledge_base.search_medical_info(query)
709
+ st.write(result)
710
+
711
+ if __name__ == "__main__":
712
+ main()