mgbam committed
Commit 936885a · verified · 1 Parent(s): e91216d

Upload 3 files

Files changed (3)
  1. .gitattributes +35 -35
  2. 1.0.0 +16 -0
  3. app.py +1165 -1165
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
1.0.0 ADDED
@@ -0,0 +1,16 @@
+ Requirement already satisfied: openai in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (1.60.1)
+ Requirement already satisfied: anyio<5,>=3.5.0 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from openai) (4.8.0)
+ Requirement already satisfied: distro<2,>=1.7.0 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from openai) (1.9.0)
+ Requirement already satisfied: httpx<1,>=0.23.0 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from openai) (0.27.2)
+ Requirement already satisfied: jiter<1,>=0.4.0 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from openai) (0.8.2)
+ Requirement already satisfied: pydantic<3,>=1.9.0 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from openai) (2.10.6)
+ Requirement already satisfied: sniffio in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from openai) (1.3.1)
+ Requirement already satisfied: tqdm>4 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from openai) (4.67.1)
+ Requirement already satisfied: typing-extensions<5,>=4.11 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from openai) (4.12.2)
+ Requirement already satisfied: idna>=2.8 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from anyio<5,>=3.5.0->openai) (3.10)
+ Requirement already satisfied: certifi in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from httpx<1,>=0.23.0->openai) (2024.12.14)
+ Requirement already satisfied: httpcore==1.* in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from httpx<1,>=0.23.0->openai) (1.0.7)
+ Requirement already satisfied: h11<0.15,>=0.13 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.14.0)
+ Requirement already satisfied: annotated-types>=0.6.0 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from pydantic<3,>=1.9.0->openai) (0.7.0)
+ Requirement already satisfied: pydantic-core==2.27.2 in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from pydantic<3,>=1.9.0->openai) (2.27.2)
+ Requirement already satisfied: colorama in c:\users\adminidiakhoa\appdata\local\programs\python\python312\lib\site-packages (from tqdm>4->openai) (0.4.6)
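
The log above shows openai 1.60.1 installed, i.e. the v1 SDK. That matters for app.py below: the v1 SDK removed the module-level openai.api_key / openai.ChatCompletion interface and the openai.error module in favor of a client object, with error classes importable from the package root. A minimal sketch of the v1 call pattern (the model name here is an assumption; any chat-capable model works):

from openai import OpenAI

client = OpenAI(api_key="sk-...")  # or omit api_key and rely on the OPENAI_API_KEY env var
response = client.chat.completions.create(
    model="gpt-4",  # assumed model
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)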
app.py CHANGED
@@ -1,1165 +1,1165 @@
- import os
- import json
- import base64
- import io
- import ast
- import logging
- from abc import ABC, abstractmethod
- from typing import Dict, List, Optional, Any
-
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
- import seaborn as sns
- import streamlit as st
- import spacy
-
- from scipy.stats import ttest_ind, f_oneway
- from sklearn.model_selection import train_test_split
- from sklearn.linear_model import LogisticRegression
- from sklearn.metrics import accuracy_score
-
- from statsmodels.tsa.seasonal import seasonal_decompose
- from statsmodels.tsa.stattools import adfuller
-
- from pydantic import BaseModel, Field
- from Bio import Entrez  # Ensure BioPython is installed
-
- from dotenv import load_dotenv
- import requests
- # OpenAI SDK v1.0.0+: the client and error classes live at the package root
- from openai import OpenAI, APIError, RateLimitError, BadRequestError
-
- # ---------------------- Load Environment Variables ---------------------------
- load_dotenv()
-
- # ---------------------- Logging Configuration ---------------------------
- logging.basicConfig(
-     filename='app.log',
-     filemode='a',
-     format='%(asctime)s - %(levelname)s - %(message)s',
-     level=logging.INFO
- )
- logger = logging.getLogger()
-
- # ---------------------- Streamlit Page Configuration ---------------------------
- st.set_page_config(page_title="AI Clinical Intelligence Hub", layout="wide")
-
- # ---------------------- Initialize OpenAI SDK ---------------------------
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
- PUB_EMAIL = os.getenv("PUB_EMAIL", "")
-
- if not OPENAI_API_KEY:
-     st.error("OpenAI API key must be set as an environment variable (OPENAI_API_KEY).")
-     st.stop()
-
- # Create the OpenAI client (the v1 SDK replaces the module-level openai.api_key)
- client = OpenAI(api_key=OPENAI_API_KEY)
-
- # ---------------------- Load spaCy Model ---------------------------
- try:
-     nlp = spacy.load("en_core_web_sm")
- except OSError:
-     # Model missing: download it quietly, then load again
-     import subprocess
-     import sys
-     subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
-     nlp = spacy.load("en_core_web_sm")
-
- # ---------------------- Base Classes and Schemas ---------------------------
-
- class ResearchInput(BaseModel):
-     """Base schema for research tool inputs."""
-     data_key: str = Field(..., description="Session state key containing DataFrame")
-     columns: Optional[List[str]] = Field(None, description="List of columns to analyze")
-
- class TemporalAnalysisInput(ResearchInput):
-     """Schema for temporal analysis."""
-     time_col: str = Field(..., description="Name of timestamp column")
-     value_col: str = Field(..., description="Name of value column to analyze")
-
- class HypothesisInput(ResearchInput):
-     """Schema for hypothesis testing."""
-     group_col: str = Field(..., description="Categorical column defining groups")
-     value_col: str = Field(..., description="Numerical column to compare")
-
- class ModelTrainingInput(ResearchInput):
-     """Schema for model training."""
-     target_col: str = Field(..., description="Name of target column")
-
- class DataAnalyzer(ABC):
-     """Abstract base class for data analysis modules."""
-     @abstractmethod
-     def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
-         pass
-
- # ---------------------- Concrete Analyzer Implementations ---------------------------
-
- class AdvancedEDA(DataAnalyzer):
-     """Comprehensive Exploratory Data Analysis."""
-     def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
-         try:
-             analysis = {
-                 "dimensionality": {
-                     "rows": len(data),
-                     "columns": list(data.columns),
-                     "memory_usage_MB": f"{data.memory_usage().sum() / 1e6:.2f} MB"
-                 },
-                 "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
-                 "temporal_analysis": {
-                     "date_ranges": {
-                         col: {
-                             "min": data[col].min(),
-                             "max": data[col].max()
-                         } for col in data.select_dtypes(include='datetime').columns
-                     }
-                 },
-                 "data_quality": {
-                     "missing_values": data.isnull().sum().to_dict(),
-                     "duplicates": data.duplicated().sum(),
-                     "cardinality": {
-                         col: data[col].nunique() for col in data.columns
-                     }
-                 }
-             }
-             return analysis
-         except Exception as e:
-             logger.error(f"EDA Failed: {str(e)}")
-             return {"error": f"EDA Failed: {str(e)}"}
-
- class DistributionVisualizer(DataAnalyzer):
-     """Distribution visualizations."""
-     def invoke(self, data: pd.DataFrame, columns: List[str], **kwargs) -> str:
-         try:
-             plt.figure(figsize=(12, 6))
-             for i, col in enumerate(columns, 1):
-                 plt.subplot(1, len(columns), i)
-                 sns.histplot(data[col], kde=True, stat="density")
-                 plt.title(f'Distribution of {col}', fontsize=10)
-                 plt.xticks(fontsize=8)
-                 plt.yticks(fontsize=8)
-             plt.tight_layout()
-
-             buf = io.BytesIO()
-             plt.savefig(buf, format='png', dpi=300, bbox_inches='tight')
-             plt.close()
-             return base64.b64encode(buf.getvalue()).decode()
-         except Exception as e:
-             logger.error(f"Visualization Error: {str(e)}")
-             return f"Visualization Error: {str(e)}"
-
- class TemporalAnalyzer(DataAnalyzer):
-     """Time series analysis."""
-     def invoke(self, data: pd.DataFrame, time_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
-         try:
-             ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col]
-             decomposition = seasonal_decompose(ts_data, period=365)
-
-             plt.figure(figsize=(12, 8))
-             decomposition.plot()
-             plt.tight_layout()
-
-             buf = io.BytesIO()
-             plt.savefig(buf, format='png')
-             plt.close()
-             plot_data = base64.b64encode(buf.getvalue()).decode()
-
-             stationarity_p_value = adfuller(ts_data)[1]
-
-             return {
-                 "trend_statistics": {
-                     "stationarity_p_value": stationarity_p_value,
-                     "seasonality_strength": float(max(decomposition.seasonal))
-                 },
-                 "visualization": plot_data
-             }
-         except Exception as e:
-             logger.error(f"Temporal Analysis Failed: {str(e)}")
-             return {"error": f"Temporal Analysis Failed: {str(e)}"}
-
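One constraint worth remembering for TemporalAnalyzer: statsmodels' seasonal_decompose needs at least two complete cycles, so period=365 implies at least 730 rows of daily data. A toy invocation sketch with synthetic data (column names illustrative):

import numpy as np
import pandas as pd

idx = pd.date_range("2020-01-01", periods=730, freq="D")  # two full yearly cycles
values = np.sin(2 * np.pi * idx.dayofyear / 365) + np.random.normal(0, 0.1, len(idx))
df = pd.DataFrame({"date": idx, "value": values})
result = TemporalAnalyzer().invoke(data=df, time_col="date", value_col="value")
print(result["trend_statistics"])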
- class HypothesisTester(DataAnalyzer):
-     """Statistical hypothesis testing."""
-     def invoke(self, data: pd.DataFrame, group_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
-         try:
-             groups = data[group_col].unique()
-
-             if len(groups) < 2:
-                 return {"error": "Insufficient groups for comparison"}
-
-             group_data = [data[data[group_col] == g][value_col] for g in groups]
-
-             if len(groups) == 2:
-                 stat, p = ttest_ind(*group_data)
-                 test_type = "Independent t-test"
-                 effect_size = self.calculate_cohens_d(group_data[0], group_data[1])
-             else:
-                 stat, p = f_oneway(*group_data)
-                 test_type = "ANOVA"
-                 effect_size = None
-
-             return {
-                 "test_type": test_type,
-                 "test_statistic": stat,
-                 "p_value": p,
-                 "effect_size": effect_size,
-                 "interpretation": self.interpret_p_value(p)
-             }
-         except Exception as e:
-             logger.error(f"Hypothesis Testing Failed: {str(e)}")
-             return {"error": f"Hypothesis Testing Failed: {str(e)}"}
-
-     @staticmethod
-     def calculate_cohens_d(x: pd.Series, y: pd.Series) -> Optional[float]:
-         """Calculate Cohen's d for effect size."""
-         try:
-             mean_diff = abs(x.mean() - y.mean())
-             pooled_std = np.sqrt((x.var() + y.var()) / 2)
-             return mean_diff / pooled_std
-         except Exception as e:
-             logger.error(f"Error calculating Cohen's d: {str(e)}")
-             return None
-
-     @staticmethod
-     def interpret_p_value(p: float) -> str:
-         """Interpret the p-value."""
-         if p < 0.001:
-             return "Very strong evidence against H0"
-         elif p < 0.01:
-             return "Strong evidence against H0"
-         elif p < 0.05:
-             return "Evidence against H0"
-         elif p < 0.1:
-             return "Weak evidence against H0"
-         else:
-             return "No significant evidence against H0"
-
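A minimal usage sketch for HypothesisTester (toy data, illustrative names): two groups route to the independent t-test with Cohen's d, three or more to one-way ANOVA.

import pandas as pd

df = pd.DataFrame({
    "group": ["a"] * 30 + ["b"] * 30,
    "value": list(range(30)) + list(range(10, 40)),
})
result = HypothesisTester().invoke(data=df, group_col="group", value_col="value")
print(result["test_type"], result["p_value"], result["effect_size"])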
- class LogisticRegressionTrainer(DataAnalyzer):
-     """Logistic Regression Model Trainer."""
-     def invoke(self, data: pd.DataFrame, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
-         try:
-             X = data[columns]
-             y = data[target_col]
-             X_train, X_test, y_train, y_test = train_test_split(
-                 X, y, test_size=0.2, random_state=42
-             )
-             model = LogisticRegression(max_iter=1000)
-             model.fit(X_train, y_train)
-             y_pred = model.predict(X_test)
-             accuracy = accuracy_score(y_test, y_pred)
-             return {
-                 "model_type": "Logistic Regression",
-                 "accuracy": accuracy,
-                 "model_params": model.get_params()
-             }
-         except Exception as e:
-             logger.error(f"Logistic Regression Model Error: {str(e)}")
-             return {"error": f"Logistic Regression Model Error: {str(e)}"}
-
- # ---------------------- Business Logic Layer ---------------------------
-
- class ClinicalRule(BaseModel):
-     """Defines a clinical rule."""
-     name: str
-     condition: str
-     action: str
-     severity: str  # low, medium, or high
-
- class ClinicalRulesEngine:
-     """Executes rules against patient data."""
-     def __init__(self):
-         self.rules: Dict[str, ClinicalRule] = {}
-
-     def add_rule(self, rule: ClinicalRule):
-         self.rules[rule.name] = rule
-
-     def execute_rules(self, data: pd.DataFrame) -> Dict[str, Any]:
-         results = {}
-         for rule_name, rule in self.rules.items():
-             try:
-                 # Using safe_eval instead of eval for security
-                 rule_matched = self.safe_eval(rule.condition, {"df": data})
-                 results[rule_name] = {
-                     "rule_matched": rule_matched,
-                     "action": rule.action if rule_matched else None,
-                     "severity": rule.severity if rule_matched else None
-                 }
-             except Exception as e:
-                 logger.error(f"Error executing rule '{rule_name}': {str(e)}")
-                 results[rule_name] = {
-                     "rule_matched": False,
-                     "error": str(e),
-                     "severity": None
-                 }
-         return results
-
-     @staticmethod
-     def safe_eval(expr, variables):
-         """
-         Safely evaluate an expression using AST parsing.
-         Only allows certain node types to prevent execution of arbitrary code.
-         """
-         allowed_nodes = (
-             ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
-             ast.Call, ast.Name, ast.Load, ast.Constant, ast.Num, ast.Str,
-             ast.List, ast.Tuple, ast.Dict,
-             # Subscript/Attribute/Slice are required for conditions of the form
-             # the UI help text suggests, e.g. df['blood_pressure'].mean() > 140
-             ast.Subscript, ast.Attribute, ast.Slice
-         )
-         try:
-             node = ast.parse(expr, mode='eval')
-             for subnode in ast.walk(node):
-                 if not isinstance(subnode, allowed_nodes):
-                     raise ValueError(f"Unsupported expression: {expr}")
-             return eval(compile(node, '<string>', mode='eval'), {"__builtins__": None}, variables)
-         except Exception as e:
-             logger.error(f"safe_eval error: {str(e)}")
-             raise ValueError(f"Invalid expression: {e}")
-
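A usage sketch for the rules engine (names illustrative). A condition must evaluate to a single boolean, hence the .mean() aggregate rather than a raw column comparison, which would return a whole Series:

import pandas as pd

engine = ClinicalRulesEngine()
engine.add_rule(ClinicalRule(
    name="hypertension_screen",
    condition="df['blood_pressure'].mean() > 140",  # must yield one boolean
    action="Flag cohort for review",
    severity="high",
))
df = pd.DataFrame({"blood_pressure": [150, 160, 145]})
print(engine.execute_rules(df))  # rule_matched True, action and severity echoed back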
- class ClinicalKPI(BaseModel):
-     """Define a clinical KPI."""
-     name: str
-     calculation: str
-     threshold: Optional[float] = None
-
- class ClinicalKPIMonitoring:
-     """Calculates KPIs based on data."""
-     def __init__(self):
-         self.kpis: Dict[str, ClinicalKPI] = {}
-
-     def add_kpi(self, kpi: ClinicalKPI):
-         self.kpis[kpi.name] = kpi
-
-     def calculate_kpis(self, data: pd.DataFrame) -> Dict[str, Any]:
-         results = {}
-         for kpi_name, kpi in self.kpis.items():
-             try:
-                 # Using safe_eval instead of eval for security
-                 kpi_value = self.safe_eval(kpi.calculation, {"df": data})
-                 status = self.evaluate_threshold(kpi_value, kpi.threshold)
-                 results[kpi_name] = {
-                     "value": kpi_value,
-                     "threshold": kpi.threshold,
-                     "status": status
-                 }
-             except Exception as e:
-                 logger.error(f"Error calculating KPI '{kpi_name}': {str(e)}")
-                 results[kpi_name] = {"error": str(e)}
-         return results
-
-     @staticmethod
-     def evaluate_threshold(value: Any, threshold: Optional[float]) -> Optional[str]:
-         if threshold is None:
-             return None
-         try:
-             return "Above Threshold" if value > threshold else "Below Threshold"
-         except TypeError:
-             return "Threshold Evaluation Not Applicable"
-
-     @staticmethod
-     def safe_eval(expr, variables):
-         """
-         Safely evaluate an expression using AST parsing.
-         Only allows certain node types to prevent execution of arbitrary code.
-         """
-         allowed_nodes = (
-             ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
-             ast.Call, ast.Name, ast.Load, ast.Constant, ast.Num, ast.Str,
-             ast.List, ast.Tuple, ast.Dict,
-             # Subscript/Attribute/Slice are required for calculations of the form
-             # the UI help text suggests, e.g. df['patient_count'].sum()
-             ast.Subscript, ast.Attribute, ast.Slice
-         )
-         try:
-             node = ast.parse(expr, mode='eval')
-             for subnode in ast.walk(node):
-                 if not isinstance(subnode, allowed_nodes):
-                     raise ValueError(f"Unsupported expression: {expr}")
-             return eval(compile(node, '<string>', mode='eval'), {"__builtins__": None}, variables)
-         except Exception as e:
-             logger.error(f"safe_eval error: {str(e)}")
-             raise ValueError(f"Invalid expression: {e}")
-
- class DiagnosisSupport(ABC):
-     """Abstract class for implementing clinical diagnoses."""
-     @abstractmethod
-     def diagnose(
-         self,
-         data: pd.DataFrame,
-         target_col: str,
-         columns: List[str],
-         diagnosis_key: str = "diagnosis",
-         **kwargs
-     ) -> pd.DataFrame:
-         pass
-
- class SimpleDiagnosis(DiagnosisSupport):
-     """Provides a simple diagnosis example, based on the logistic regression model."""
-     def __init__(self):
-         self.model_trainer: LogisticRegressionTrainer = LogisticRegressionTrainer()
-
-     def diagnose(
-         self,
-         data: pd.DataFrame,
-         target_col: str,
-         columns: List[str],
-         diagnosis_key: str = "diagnosis",
-         **kwargs
-     ) -> pd.DataFrame:
-         try:
-             result = self.model_trainer.invoke(data, target_col=target_col, columns=columns)
-             if "accuracy" in result:
-                 return pd.DataFrame({
-                     diagnosis_key: [f"Model Accuracy: {result['accuracy']:.2%}"],
-                     "model": [result["model_type"]]
-                 })
-             else:
-                 return pd.DataFrame({
-                     diagnosis_key: [f"Diagnosis failed: {result.get('error', 'Unknown error')}"]
-                 })
-         except Exception as e:
-             logger.error(f"Error during diagnosis: {str(e)}")
-             return pd.DataFrame({
-                 diagnosis_key: [f"Error during diagnosis: {e}"]
-             })
-
- class TreatmentRecommendation(ABC):
-     """Abstract class for treatment recommendations."""
-     @abstractmethod
-     def recommend(
-         self,
-         data: pd.DataFrame,
-         condition_col: str,
-         treatment_col: str,
-         recommendation_key: str = "recommendation",
-         **kwargs
-     ) -> pd.DataFrame:
-         pass
-
- class BasicTreatmentRecommendation(TreatmentRecommendation):
-     """A placeholder class for basic treatment recommendations."""
-     def recommend(
-         self,
-         data: pd.DataFrame,
-         condition_col: str,
-         treatment_col: str,
-         recommendation_key: str = "recommendation",
-         **kwargs
-     ) -> pd.DataFrame:
-         if condition_col not in data.columns or treatment_col not in data.columns:
-             logger.warning(f"Condition or Treatment columns not found: {condition_col}, {treatment_col}")
-             return pd.DataFrame({
-                 recommendation_key: ["Condition or Treatment columns not found!"]
-             })
-
-         treatment = data[data[condition_col] == "High"][treatment_col].to_list()
-         if treatment:
-             return pd.DataFrame({
-                 recommendation_key: [f"Treatment recommended for High risk patients: {treatment}"]
-             })
-         else:
-             return pd.DataFrame({
-                 recommendation_key: ["No treatment recommendation found!"]
-             })
-
- # ---------------------- Medical Knowledge Base ---------------------------
-
- class MedicalKnowledgeBase(ABC):
-     """Abstract class for Medical Knowledge."""
-     @abstractmethod
-     def search_medical_info(self, query: str, pub_email: str = "") -> str:
-         pass
-
- class SimpleMedicalKnowledge(MedicalKnowledgeBase):
-     """Enhanced Medical Knowledge Class using OpenAI GPT-4."""
-     def __init__(self, nlp_model):
-         self.nlp = nlp_model  # Using the loaded spaCy model
-
-     def search_medical_info(self, query: str, pub_email: str = "") -> str:
-         """
-         Uses OpenAI's GPT-4 to fetch medical information based on the user's query.
-         """
-         logger.info(f"Received medical query: {query}")
-         try:
-             # Preprocess the query (e.g., entity recognition)
-             doc = self.nlp(query.lower())
-             entities = [ent.text for ent in doc.ents]
-             processed_query = " ".join(entities) if entities else query.lower()
-
-             logger.info(f"Processed query: {processed_query}")
-
-             # Create a prompt for GPT-4
-             prompt = f"""
-             You are a medical assistant. Provide a comprehensive and accurate response to the following medical query:
-
-             Query: {processed_query}
-
-             Please ensure the information is clear, concise, and evidence-based.
-             """
-
-             # Make the API request via the v1 client
-             response = client.chat.completions.create(
-                 model="gpt-4",
-                 messages=[
-                     {"role": "system", "content": "You are a helpful medical assistant."},
-                     {"role": "user", "content": prompt}
-                 ],
-                 max_tokens=500,
-                 temperature=0.7,
-             )
-
-             # Extract the answer from the response (v1 returns objects, not dicts)
-             answer = response.choices[0].message.content.strip()
-
-             logger.info("Successfully retrieved data from OpenAI GPT-4.")
-
-             # Fetch PubMed abstract related to the query
-             pubmed_abstract = self.fetch_pubmed_abstract(processed_query, pub_email)
-
-             # Format the response
-             return f"**Based on your query:** {answer}\n\n**PubMed Abstract:**\n\n{pubmed_abstract}"
-
-         except RateLimitError as e:
-             logger.error(f"Rate Limit Exceeded: {str(e)}")
-             return "Rate limit exceeded. Please try again later."
-         except BadRequestError as e:
-             logger.error(f"Invalid Request: {str(e)}")
-             return f"Invalid request: {str(e)}"
-         except APIError as e:
-             logger.error(f"OpenAI API Error: {str(e)}")
-             return f"OpenAI API Error: {str(e)}"
-         except Exception as e:
-             logger.error(f"Medical Knowledge Search Failed: {str(e)}")
-             return f"Medical Knowledge Search Failed: {str(e)}"
-
-     def fetch_pubmed_abstract(self, query: str, email: str) -> str:
-         """
-         Searches PubMed for abstracts related to the query.
-         """
-         try:
-             if not email:
-                 logger.warning("PubMed abstract retrieval skipped: Email not provided.")
-                 return "No PubMed abstract available: Email not provided."
-
-             Entrez.email = email
-             handle = Entrez.esearch(db="pubmed", term=query, retmax=1, sort='relevance')
-             record = Entrez.read(handle)
-             handle.close()
-             logger.info(f"PubMed search for query '{query}' returned IDs: {record['IdList']}")
-
-             if record["IdList"]:
-                 handle = Entrez.efetch(db="pubmed", id=record["IdList"][0], rettype="abstract", retmode="text")
-                 abstract = handle.read()
-                 handle.close()
-                 logger.info(f"Fetched PubMed abstract for ID {record['IdList'][0]}")
-                 return abstract
-             else:
-                 logger.info(f"No PubMed abstracts found for query '{query}'.")
-                 return "No abstracts found for this query on PubMed."
-         except Exception as e:
-             logger.error(f"Error searching PubMed: {e}")
-             return f"Error searching PubMed: {e}"
-
- # ---------------------- Forecasting Engine ---------------------------
-
- class ForecastingEngine(ABC):
-     """Abstract class for forecasting."""
-     @abstractmethod
-     def predict(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
-         pass
-
- class SimpleForecasting(ForecastingEngine):
-     """Simple forecasting engine."""
-     def predict(self, data: pd.DataFrame, period: int = 7, **kwargs) -> pd.DataFrame:
-         # Placeholder for actual forecasting logic
-         return pd.DataFrame({"forecast": [f"Forecast for the next {period} days"]})
-
- # ---------------------- Insights and Reporting Layer ---------------------------
-
- class AutomatedInsights:
-     """Generates automated insights based on selected analyses."""
-     def __init__(self):
-         self.analyses: Dict[str, DataAnalyzer] = {
-             "EDA": AdvancedEDA(),
-             "temporal": TemporalAnalyzer(),
-             "distribution": DistributionVisualizer(),
-             "hypothesis": HypothesisTester(),
-             "model": LogisticRegressionTrainer()
-         }
-
-     def generate_insights(self, data: pd.DataFrame, analysis_names: List[str], **kwargs) -> Dict[str, Any]:
-         results = {}
-         for name in analysis_names:
-             analyzer = self.analyses.get(name)
-             if analyzer:
-                 try:
-                     results[name] = analyzer.invoke(data=data, **kwargs)
-                 except Exception as e:
-                     logger.error(f"Error in analysis '{name}': {str(e)}")
-                     results[name] = {"error": str(e)}
-             else:
-                 logger.warning(f"Analysis '{name}' not found.")
-                 results[name] = {"error": "Analysis not found"}
-         return results
-
- class Dashboard:
-     """Handles the creation and display of the dashboard."""
-     def __init__(self):
-         self.layout: Dict[str, str] = {}
-
-     def add_visualisation(self, vis_name: str, vis_type: str):
-         self.layout[vis_name] = vis_type
-
-     def display_dashboard(self, data_dict: Dict[str, pd.DataFrame]):
-         st.header("Dashboard")
-         for vis_name, vis_type in self.layout.items():
-             st.subheader(vis_name)
-             df = data_dict.get(vis_name)
-             if df is not None:
-                 if vis_type == "table":
-                     st.table(df)
-                 elif vis_type == "plot":
-                     if len(df.columns) > 1:
-                         fig = plt.figure()
-                         sns.lineplot(data=df)
-                         st.pyplot(fig)
-                     else:
-                         st.write("Please select a DataFrame with more than 1 column for plotting.")
-             else:
-                 st.write("Data Not Found")
-
- class AutomatedReports:
-     """Manages automated report definitions and generation."""
-     def __init__(self):
-         self.report_definitions: Dict[str, str] = {}
-
-     def create_report_definition(self, report_name: str, definition: str):
-         self.report_definitions[report_name] = definition
-
-     def generate_report(self, report_name: str, data: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
-         if report_name not in self.report_definitions:
-             return {"error": "Report name not found"}
-         report_content = {
-             "Report Name": report_name,
-             "Report Definition": self.report_definitions[report_name],
-             "Data": {df_name: df.to_dict() for df_name, df in data.items()}
-         }
-         return report_content
-
- # ---------------------- Data Acquisition Layer ---------------------------
-
- class DataSource(ABC):
-     """Base class for data sources."""
-     @abstractmethod
-     def connect(self) -> None:
-         """Connect to the data source."""
-         pass
-
-     @abstractmethod
-     def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
-         """Fetch the data based on a specific query."""
-         pass
-
- class CSVDataSource(DataSource):
-     """Data source for CSV files."""
-     def __init__(self, file_path: io.BytesIO):
-         self.file_path = file_path
-         self.data: Optional[pd.DataFrame] = None
-
-     def connect(self):
-         self.data = pd.read_csv(self.file_path)
-
-     def fetch_data(self, query: str = None, **kwargs) -> pd.DataFrame:
-         if self.data is None:
-             raise Exception("No connection is made, call connect()")
-         return self.data
-
- class DatabaseSource(DataSource):
-     """Data source for SQL databases."""
-     def __init__(self, connection_string: str, database_type: str):
-         self.connection_string = connection_string
-         self.database_type = database_type.lower()
-         self.connection = None
-
-     def connect(self):
-         if self.database_type == "sql":
-             # Placeholder for actual SQL connection logic
-             self.connection = "Connected to SQL Database"
-         else:
-             raise Exception(f"Database type '{self.database_type}' is not supported.")
-
-     def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
-         if self.connection is None:
-             raise Exception("No connection is made, call connect()")
-         # Placeholder for data fetching logic
-         return pd.DataFrame({"result": [f"Fetched data based on query: {query}"]})
-
- class DataIngestion:
-     """Handles data ingestion from various sources."""
-     def __init__(self):
-         self.sources: Dict[str, DataSource] = {}
-
-     def add_source(self, source_name: str, source: DataSource):
-         self.sources[source_name] = source
-
-     def ingest_data(self, source_name: str, query: str = None, **kwargs) -> pd.DataFrame:
-         if source_name not in self.sources:
-             raise Exception(f"Source '{source_name}' not found.")
-         source = self.sources[source_name]
-         source.connect()
-         return source.fetch_data(query, **kwargs)
-
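A quick wiring sketch for the ingestion layer above (the CSV payload is hypothetical; any file-like object pd.read_csv accepts works):

import io

ingestion = DataIngestion()
csv_bytes = io.BytesIO(b"patient_id,blood_pressure\n1,150\n2,125\n")  # stand-in for an uploaded file
ingestion.add_source("trial_data", CSVDataSource(file_path=csv_bytes))
df = ingestion.ingest_data("trial_data")  # connect() + fetch_data() under the hood
print(df.shape)  # (2, 2)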
- class DataModel(BaseModel):
-     """Defines a data model."""
-     name: str
-     kpis: List[str] = Field(default_factory=list)
-     dimensions: List[str] = Field(default_factory=list)
-     custom_calculations: Optional[Dict[str, str]] = None
-     relations: Optional[Dict[str, str]] = None  # Example: {"table1": "table2"}
-
-     def to_json(self) -> str:
-         return json.dumps(self.dict())
-
-     @staticmethod
-     def from_json(json_str: str) -> 'DataModel':
-         return DataModel(**json.loads(json_str))
-
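And a round-trip sketch for DataModel's JSON helpers (values illustrative; .dict() is the pydantic v1 spelling and still works, with a deprecation warning, under the pydantic 2.10.6 shown in the install log, where model_dump() is the v2 name):

model = DataModel(name="cardiology", kpis=["readmission_rate"], dimensions=["hospital", "month"])
payload = model.to_json()                 # JSON string
restored = DataModel.from_json(payload)   # back to a DataModel
assert restored == model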
- class DataModelling:
-     """Manages data models."""
-     def __init__(self):
-         self.models: Dict[str, DataModel] = {}
-
-     def add_model(self, model: DataModel):
-         self.models[model.name] = model
-
-     def get_model(self, model_name: str) -> DataModel:
-         if model_name not in self.models:
-             raise Exception(f"Model '{model_name}' not found.")
-         return self.models[model_name]
-
- # ---------------------- Main Streamlit Application ---------------------------
-
- def main():
-     """Main function to run the Streamlit app."""
-     st.title("🏥 AI-Powered Clinical Intelligence Hub")
-
-     # Initialize Session State
-     initialize_session_state()
-
-     # Sidebar for Data Management
-     with st.sidebar:
-         data_management_section()
-
-     # Main Content
-     if st.session_state.data:
-         col1, col2 = st.columns([1, 3])
-
-         with col1:
-             dataset_metadata_section()
-
-         with col2:
-             main_tabs_section()
-
- def initialize_session_state():
-     """Initialize necessary components in Streamlit's session state."""
-     if 'data' not in st.session_state:
-         st.session_state.data = {}  # Store pd.DataFrame under a name
-     if 'data_ingestion' not in st.session_state:
-         st.session_state.data_ingestion = DataIngestion()
-     if 'data_modelling' not in st.session_state:
-         st.session_state.data_modelling = DataModelling()
-     if 'clinical_rules' not in st.session_state:
-         st.session_state.clinical_rules = ClinicalRulesEngine()
-     if 'kpi_monitoring' not in st.session_state:
-         st.session_state.kpi_monitoring = ClinicalKPIMonitoring()
-     if 'forecasting_engine' not in st.session_state:
-         st.session_state.forecasting_engine = SimpleForecasting()
-     if 'automated_insights' not in st.session_state:
-         st.session_state.automated_insights = AutomatedInsights()
-     if 'dashboard' not in st.session_state:
-         st.session_state.dashboard = Dashboard()
-     if 'automated_reports' not in st.session_state:
-         st.session_state.automated_reports = AutomatedReports()
-     if 'diagnosis_support' not in st.session_state:
-         st.session_state.diagnosis_support = SimpleDiagnosis()
-     if 'treatment_recommendation' not in st.session_state:
-         st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
-     if 'knowledge_base' not in st.session_state:
-         st.session_state.knowledge_base = SimpleMedicalKnowledge(nlp_model=nlp)
-     if 'pub_email' not in st.session_state:
-         st.session_state.pub_email = PUB_EMAIL  # Load PUB_EMAIL from environment variables
-
- def data_management_section():
-     """Handles the data management section in the sidebar."""
-     st.header("⚙️ Data Management")
-     data_source_selection = st.selectbox("Select Data Source Type", ["CSV", "SQL Database"])
-
-     if data_source_selection == "CSV":
-         handle_csv_upload()
-     elif data_source_selection == "SQL Database":
-         handle_sql_database()
-
-     # Render the ingestion controls directly: widgets created inside an
-     # `if st.button(...)` block disappear on the next rerun, so the button
-     # lives inside ingest_data_action() instead.
-     ingest_data_action()
-
- def handle_csv_upload():
-     """Handles CSV file uploads."""
-     uploaded_file = st.file_uploader("Upload research dataset (CSV)", type=["csv"])
-     if uploaded_file:
-         source_name = st.text_input("Data Source Name")
-         if source_name:
-             try:
-                 csv_source = CSVDataSource(file_path=uploaded_file)
-                 st.session_state.data_ingestion.add_source(source_name, csv_source)
-                 st.success(f"Uploaded {uploaded_file.name} as '{source_name}'.")
-             except Exception as e:
-                 st.error(f"Error loading dataset: {e}")
-
- def handle_sql_database():
-     """Handles SQL database connections."""
-     conn_str = st.text_input("Enter connection string for SQL DB")
-     if conn_str:
-         source_name = st.text_input("Data Source Name")
-         if source_name:
-             try:
-                 sql_source = DatabaseSource(connection_string=conn_str, database_type="sql")
-                 st.session_state.data_ingestion.add_source(source_name, sql_source)
-                 st.success(f"Added SQL DB Source '{source_name}'.")
-             except Exception as e:
-                 st.error(f"Error loading database source: {e}")
-
- def ingest_data_action():
-     """Performs data ingestion from the selected source."""
-     if st.session_state.data_ingestion.sources:
-         source_name_to_fetch = st.selectbox("Select Data Source to Ingest", list(st.session_state.data_ingestion.sources.keys()))
-         query = st.text_area("Optional Query to Fetch data")
-         if st.button("Ingest Data") and source_name_to_fetch:
-             with st.spinner("Ingesting data..."):
-                 try:
-                     data = st.session_state.data_ingestion.ingest_data(source_name_to_fetch, query)
-                     st.session_state.data[source_name_to_fetch] = data
-                     st.success(f"Ingested data from '{source_name_to_fetch}'.")
-                 except Exception as e:
-                     st.error(f"Ingestion failed: {e}")
-     else:
-         st.info("No data source added yet. Please add a data source above.")
-
- def dataset_metadata_section():
-     """Displays metadata for the selected dataset."""
-     st.subheader("Dataset Metadata")
-     data_source_keys = list(st.session_state.data.keys())
-     selected_data_key = st.selectbox("Select Dataset", data_source_keys)
-
-     if selected_data_key:
-         data = st.session_state.data[selected_data_key]
-         metadata = {
-             "Variables": list(data.columns),
-             "Time Range": {
-                 col: {
-                     "min": data[col].min(),
-                     "max": data[col].max()
-                 } for col in data.select_dtypes(include='datetime').columns
-             },
-             "Size": f"{data.memory_usage().sum() / 1e6:.2f} MB"
-         }
-         st.json(metadata)
-         # Store the selected dataset key in session state for use in analysis
-         st.session_state.selected_data_key = selected_data_key
-
- def main_tabs_section():
-     """Creates and manages the main tabs in the application."""
-     analysis_tab, clinical_logic_tab, insights_tab, reports_tab, knowledge_tab = st.tabs([
-         "Data Analysis",
-         "Clinical Logic",
-         "Insights",
-         "Reports",
-         "Medical Knowledge"
-     ])
-
-     with analysis_tab:
-         data_analysis_section()
-
-     with clinical_logic_tab:
-         clinical_logic_section()
-
-     with insights_tab:
-         insights_section()
-
-     with reports_tab:
-         reports_section()
-
-     with knowledge_tab:
-         medical_knowledge_section()
-
- def data_analysis_section():
-     """Handles the Data Analysis tab."""
-     selected_data_key = st.session_state.get('selected_data_key', None)
-     if not selected_data_key:
-         st.warning("Please select a dataset from the metadata section.")
-         return
-
-     data = st.session_state.data[selected_data_key]
-     analysis_type = st.selectbox("Select Analysis Mode", [
-         "Exploratory Data Analysis",
-         "Temporal Pattern Analysis",
-         "Comparative Statistics",
-         "Distribution Analysis",
-         "Train Logistic Regression Model"
-     ])
-
-     if analysis_type == "Exploratory Data Analysis":
-         perform_eda(data)
-     elif analysis_type == "Temporal Pattern Analysis":
-         perform_temporal_analysis(data)
-     elif analysis_type == "Comparative Statistics":
-         perform_comparative_statistics(data)
-     elif analysis_type == "Distribution Analysis":
-         perform_distribution_analysis(data)
-     elif analysis_type == "Train Logistic Regression Model":
-         perform_logistic_regression_training(data)
-
- def perform_eda(data: pd.DataFrame):
-     """Performs Exploratory Data Analysis."""
-     analyzer = AdvancedEDA()
-     eda_result = analyzer.invoke(data=data)
-     st.subheader("Data Quality Report")
-     st.json(eda_result)
-
- def perform_temporal_analysis(data: pd.DataFrame):
-     """Performs Temporal Pattern Analysis."""
-     time_cols = data.select_dtypes(include='datetime').columns
-     num_cols = data.select_dtypes(include=np.number).columns
-
-     if len(time_cols) == 0:
-         st.warning("No datetime columns available for temporal analysis.")
-         return
-
-     time_col = st.selectbox("Select Temporal Variable", time_cols)
-     value_col = st.selectbox("Select Analysis Variable", num_cols)
-
-     if time_col and value_col:
-         analyzer = TemporalAnalyzer()
-         result = analyzer.invoke(data=data, time_col=time_col, value_col=value_col)
-         if "visualization" in result and result["visualization"]:
-             st.image(f"data:image/png;base64,{result['visualization']}", use_column_width=True)
-         st.json(result)
-
- def perform_comparative_statistics(data: pd.DataFrame):
-     """Performs Comparative Statistics."""
-     categorical_cols = data.select_dtypes(include=['category', 'object']).columns
-     numeric_cols = data.select_dtypes(include=np.number).columns
-
-     if len(categorical_cols) == 0:
-         st.warning("No categorical columns available for hypothesis testing.")
-         return
-
-     if len(numeric_cols) == 0:
-         st.warning("No numerical columns available for hypothesis testing.")
-         return
-
-     group_col = st.selectbox("Select Grouping Variable", categorical_cols)
-     value_col = st.selectbox("Select Metric Variable", numeric_cols)
-
-     if group_col and value_col:
-         analyzer = HypothesisTester()
-         result = analyzer.invoke(data=data, group_col=group_col, value_col=value_col)
-         st.subheader("Statistical Test Results")
-         st.json(result)
-
- def perform_distribution_analysis(data: pd.DataFrame):
-     """Performs Distribution Analysis."""
-     numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
-     selected_cols = st.multiselect("Select Variables for Distribution Analysis", numeric_cols)
-
-     if selected_cols:
-         analyzer = DistributionVisualizer()
-         img_data = analyzer.invoke(data=data, columns=selected_cols)
-         if not img_data.startswith("Visualization Error"):
-             st.image(f"data:image/png;base64,{img_data}", use_column_width=True)
-         else:
-             st.error(img_data)
-     else:
-         st.info("Please select at least one numerical column to visualize.")
-
- def perform_logistic_regression_training(data: pd.DataFrame):
-     """Trains a Logistic Regression model."""
-     numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
-     target_col = st.selectbox("Select Target Variable", data.columns.tolist())
-     selected_cols = st.multiselect("Select Feature Variables", numeric_cols)
-
-     if selected_cols and target_col:
-         analyzer = LogisticRegressionTrainer()
-         result = analyzer.invoke(data=data, target_col=target_col, columns=selected_cols)
-         st.subheader("Logistic Regression Model Results")
-         st.json(result)
-     else:
-         st.warning("Please select both target and feature variables for model training.")
-
- def clinical_logic_section():
-     """Handles the Clinical Logic tab."""
-     st.header("Clinical Logic")
-
-     # Clinical Rules Management
-     st.subheader("Clinical Rules")
-     rule_name = st.text_input("Enter Rule Name")
-     condition = st.text_area("Enter Rule Condition (use 'df' for DataFrame)",
-                              help="Example: df['blood_pressure'].mean() > 140")
-     action = st.text_area("Enter Action to be Taken on Rule Match")
-     severity = st.selectbox("Enter Severity for the Rule", ["low", "medium", "high"])
-
-     if st.button("Add Clinical Rule"):
-         if rule_name and condition and action and severity:
-             try:
-                 rule = ClinicalRule(
-                     name=rule_name,
-                     condition=condition,
-                     action=action,
-                     severity=severity
-                 )
-                 st.session_state.clinical_rules.add_rule(rule)
-                 st.success("Added Clinical Rule successfully.")
-             except Exception as e:
-                 st.error(f"Error in rule definition: {e}")
-         else:
-             st.error("Please fill in all fields to add a clinical rule.")
-
-     # Clinical KPI Management
-     st.subheader("Clinical KPI Definition")
-     kpi_name = st.text_input("Enter KPI Name")
-     kpi_calculation = st.text_area("Enter KPI Calculation (use 'df' for DataFrame)",
-                                    help="Example: df['patient_count'].sum()")
-     threshold = st.text_input("Enter Threshold for KPI (Optional)", help="Leave blank if not applicable")
-
-     if st.button("Add Clinical KPI"):
-         if kpi_name and kpi_calculation:
-             try:
-                 threshold_value = float(threshold) if threshold else None
-                 kpi = ClinicalKPI(
-                     name=kpi_name,
-                     calculation=kpi_calculation,
-                     threshold=threshold_value
-                 )
-                 st.session_state.kpi_monitoring.add_kpi(kpi)
-                 st.success(f"Added KPI '{kpi_name}' successfully.")
-             except ValueError:
-                 st.error("Threshold must be a numeric value.")
-             except Exception as e:
-                 st.error(f"Error creating KPI: {e}")
-         else:
-             st.error("Please provide both KPI name and calculation.")
-
-     # Execute Clinical Rules and Calculate KPIs
-     selected_data_key = st.selectbox("Select Dataset for Clinical Logic", list(st.session_state.data.keys()))
-     if selected_data_key:
-         data = st.session_state.data[selected_data_key]
-         if st.button("Execute Clinical Rules"):
-             with st.spinner("Executing Clinical Rules..."):
-                 result = st.session_state.clinical_rules.execute_rules(data)
-                 st.json(result)
-         if st.button("Calculate Clinical KPIs"):
-             with st.spinner("Calculating Clinical KPIs..."):
-                 result = st.session_state.kpi_monitoring.calculate_kpis(data)
-                 st.json(result)
-     else:
-         st.warning("Please ingest data to execute clinical rules and calculate KPIs.")
-
- def insights_section():
-     """Handles the Insights tab."""
-     st.header("Automated Insights")
-
-     selected_data_key = st.selectbox("Select Dataset for Insights", list(st.session_state.data.keys()))
-     if not selected_data_key:
-         st.warning("Please select a dataset to generate insights.")
-         return
-
-     data = st.session_state.data[selected_data_key]
-     available_analyses = ["EDA", "temporal", "distribution", "hypothesis", "model"]
-     selected_analyses = st.multiselect("Select Analyses for Insights", available_analyses)
-
-     if st.button("Generate Automated Insights"):
-         if selected_analyses:
-             with st.spinner("Generating Insights..."):
-                 results = st.session_state.automated_insights.generate_insights(
-                     data, analysis_names=selected_analyses
-                 )
-                 st.json(results)
-         else:
-             st.warning("Please select at least one analysis to generate insights.")
-
-     # Diagnosis Support
-     st.subheader("Diagnosis Support")
-     target_col = st.selectbox("Select Target Variable for Diagnosis", data.columns.tolist())
-     numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
-     selected_feature_cols = st.multiselect("Select Feature Variables for Diagnosis", numeric_cols)
-
-     if st.button("Generate Diagnosis"):
-         if target_col and selected_feature_cols:
-             with st.spinner("Generating Diagnosis..."):
-                 result = st.session_state.diagnosis_support.diagnose(
-                     data, target_col=target_col, columns=selected_feature_cols, diagnosis_key="diagnosis_result"
-                 )
-                 st.json(result)
-         else:
-             st.error("Please select both target and feature variables for diagnosis.")
-
-     # Treatment Recommendation
-     st.subheader("Treatment Recommendation")
-     condition_col = st.selectbox("Select Condition Column for Treatment Recommendation", data.columns.tolist())
-     treatment_col = st.selectbox("Select Treatment Column for Treatment Recommendation", data.columns.tolist())
-
-     if st.button("Generate Treatment Recommendation"):
-         if condition_col and treatment_col:
-             with st.spinner("Generating Treatment Recommendation..."):
-                 result = st.session_state.treatment_recommendation.recommend(
-                     data, condition_col=condition_col, treatment_col=treatment_col, recommendation_key="treatment_recommendation"
-                 )
-                 st.json(result)
-         else:
-             st.error("Please select both condition and treatment columns.")
-
- def reports_section():
-     """Handles the Reports tab."""
-     st.header("Automated Reports")
-
-     # Create Report Definition
-     st.subheader("Create Report Definition")
-     report_name = st.text_input("Report Name")
-     report_def = st.text_area("Report Definition", help="Describe the structure and content of the report.")
-
-     if st.button("Create Report Definition"):
-         if report_name and report_def:
-             st.session_state.automated_reports.create_report_definition(report_name, report_def)
-             st.success("Report definition created successfully.")
-         else:
-             st.error("Please provide both report name and definition.")
-
-     # Generate Report
-     st.subheader("Generate Report")
-     report_names = list(st.session_state.automated_reports.report_definitions.keys())
-     if report_names:
-         report_name_to_generate = st.selectbox("Select Report to Generate", report_names)
-         if st.button("Generate Report"):
-             with st.spinner("Generating Report..."):
-                 report = st.session_state.automated_reports.generate_report(report_name_to_generate, st.session_state.data)
-                 if "error" not in report:
-                     st.header(f"Report: {report['Report Name']}")
-                     st.markdown(f"**Definition:** {report['Report Definition']}")
-                     for df_name, df_content in report["Data"].items():
-                         st.subheader(f"Data: {df_name}")
-                         st.dataframe(pd.DataFrame(df_content))
-                 else:
-                     st.error(report["error"])
-     else:
-         st.info("No report definitions found. Please create a report definition first.")
-
- def medical_knowledge_section():
-     """Handles the Medical Knowledge tab."""
-     st.header("Medical Knowledge")
-     query = st.text_input("Enter your medical question here:")
-
-     if st.button("Search"):
-         if query.strip():
-             with st.spinner("Searching..."):
-                 result = st.session_state.knowledge_base.search_medical_info(
-                     query, pub_email=st.session_state.pub_email
-                 )
-                 st.markdown(result)
-         else:
-             st.error("Please enter a medical question to search.")
-
- if __name__ == "__main__":
-     main()
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import io
5
+ import ast
6
+ import logging
7
+ from abc import ABC, abstractmethod
8
+ from typing import Dict, List, Optional, Any
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ import streamlit as st
15
+ import spacy
16
+
17
+ from scipy.stats import ttest_ind, f_oneway
18
+ from sklearn.model_selection import train_test_split
19
+ from sklearn.linear_model import LogisticRegression
20
+ from sklearn.metrics import accuracy_score
21
+
22
+ from statsmodels.tsa.seasonal import seasonal_decompose
23
+ from statsmodels.tsa.stattools import adfuller
24
+
25
+ from pydantic import BaseModel, Field
26
+ from Bio import Entrez # Ensure BioPython is installed
27
+
28
+ from dotenv import load_dotenv
29
+ import requests
30
+ import openai # Updated for OpenAI SDK v1.0.0+
31
+ from openai.error import APIError, RateLimitError, InvalidRequestError
32
+
33
+ # ---------------------- Load Environment Variables ---------------------------
34
+ load_dotenv()
35
+
36
+ # ---------------------- Logging Configuration ---------------------------
37
+ logging.basicConfig(
38
+ filename='app.log',
39
+ filemode='a',
40
+ format='%(asctime)s - %(levelname)s - %(message)s',
41
+ level=logging.INFO
42
+ )
43
+ logger = logging.getLogger()
44
+
45
+ # ---------------------- Streamlit Page Configuration ---------------------------
46
+ st.set_page_config(page_title="AI Clinical Intelligence Hub", layout="wide")
47
+
48
+ # ---------------------- Initialize OpenAI SDK ---------------------------
49
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
50
+ PUB_EMAIL = os.getenv("PUB_EMAIL", "")
51
+
52
+ if not OPENAI_API_KEY:
53
+ st.error("OpenAI API key must be set as an environment variable (OPENAI_API_KEY).")
54
+ st.stop()
55
+
56
+ # Set the OpenAI API key
57
+ openai.api_key = OPENAI_API_KEY
58
+
59
+ # ---------------------- Load spaCy Model ---------------------------
60
+ try:
61
+ nlp = spacy.load("en_core_web_sm")
62
+ except OSError:
63
+ # Avoid using Streamlit commands before set_page_config()
64
+ import subprocess
65
+ import sys
66
+ subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
67
+ nlp = spacy.load("en_core_web_sm")
68
+
69
+ # ---------------------- Base Classes and Schemas ---------------------------
70
+
71
+ class ResearchInput(BaseModel):
72
+ """Base schema for research tool inputs."""
73
+ data_key: str = Field(..., description="Session state key containing DataFrame")
74
+ columns: Optional[List[str]] = Field(None, description="List of columns to analyze")
75
+
76
+ class TemporalAnalysisInput(ResearchInput):
77
+ """Schema for temporal analysis."""
78
+ time_col: str = Field(..., description="Name of timestamp column")
79
+ value_col: str = Field(..., description="Name of value column to analyze")
80
+
81
+ class HypothesisInput(ResearchInput):
82
+ """Schema for hypothesis testing."""
83
+ group_col: str = Field(..., description="Categorical column defining groups")
84
+ value_col: str = Field(..., description="Numerical column to compare")
85
+
86
+ class ModelTrainingInput(ResearchInput):
87
+ """Schema for model training."""
88
+ target_col: str = Field(..., description="Name of target column")
89
+
90
+ class DataAnalyzer(ABC):
91
+ """Abstract base class for data analysis modules."""
92
+ @abstractmethod
93
+ def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
94
+ pass
95
+
96
+ # ---------------------- Concrete Analyzer Implementations ---------------------------
97
+
98
+ class AdvancedEDA(DataAnalyzer):
99
+ """Comprehensive Exploratory Data Analysis."""
100
+ def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
101
+ try:
102
+ analysis = {
103
+ "dimensionality": {
104
+ "rows": len(data),
105
+ "columns": list(data.columns),
106
+ "memory_usage_MB": f"{data.memory_usage().sum() / 1e6:.2f} MB"
107
+ },
108
+ "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
109
+ "temporal_analysis": {
110
+ "date_ranges": {
111
+ col: {
112
+ "min": data[col].min(),
113
+ "max": data[col].max()
114
+ } for col in data.select_dtypes(include='datetime').columns
115
+ }
116
+ },
117
+ "data_quality": {
118
+ "missing_values": data.isnull().sum().to_dict(),
119
+ "duplicates": data.duplicated().sum(),
120
+ "cardinality": {
121
+ col: data[col].nunique() for col in data.columns
122
+ }
123
+ }
124
+ }
125
+ return analysis
126
+ except Exception as e:
127
+ logger.error(f"EDA Failed: {str(e)}")
128
+ return {"error": f"EDA Failed: {str(e)}"}
129
+
130
+ class DistributionVisualizer(DataAnalyzer):
131
+ """Distribution visualizations."""
132
+ def invoke(self, data: pd.DataFrame, columns: List[str], **kwargs) -> str:
133
+ try:
134
+ plt.figure(figsize=(12, 6))
135
+ for i, col in enumerate(columns, 1):
136
+ plt.subplot(1, len(columns), i)
137
+ sns.histplot(data[col], kde=True, stat="density")
138
+ plt.title(f'Distribution of {col}', fontsize=10)
139
+ plt.xticks(fontsize=8)
140
+ plt.yticks(fontsize=8)
141
+ plt.tight_layout()
142
+
143
+ buf = io.BytesIO()
144
+ plt.savefig(buf, format='png', dpi=300, bbox_inches='tight')
145
+ plt.close()
146
+ return base64.b64encode(buf.getvalue()).decode()
147
+ except Exception as e:
148
+ logger.error(f"Visualization Error: {str(e)}")
149
+ return f"Visualization Error: {str(e)}"
150
+
151
+class TemporalAnalyzer(DataAnalyzer):
+    """Time series analysis."""
+    def invoke(self, data: pd.DataFrame, time_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
+        try:
+            # Drop missing values so decomposition and the ADF test do not fail on NaNs.
+            ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col].dropna()
+            # period=365 assumes daily observations with yearly seasonality.
+            decomposition = seasonal_decompose(ts_data, period=365)
+
+            # decomposition.plot() creates its own figure; resize it rather than
+            # opening a separate, unused figure.
+            fig = decomposition.plot()
+            fig.set_size_inches(12, 8)
+            fig.tight_layout()
+
+            buf = io.BytesIO()
+            fig.savefig(buf, format='png')
+            plt.close(fig)
+            plot_data = base64.b64encode(buf.getvalue()).decode()
+
+            stationarity_p_value = adfuller(ts_data)[1]
+
+            return {
+                "trend_statistics": {
+                    "stationarity_p_value": stationarity_p_value,
+                    "seasonality_strength": float(max(decomposition.seasonal))
+                },
+                "visualization": plot_data
+            }
+        except Exception as e:
+            logger.error(f"Temporal Analysis Failed: {str(e)}")
+            return {"error": f"Temporal Analysis Failed: {str(e)}"}
+
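+# Illustrative usage sketch, assuming a daily series long enough for two full
+# seasonal periods (seasonal_decompose with period=365 needs >= 730 points):
+#
+#   ts = pd.DataFrame({"date": pd.date_range("2020-01-01", periods=730),
+#                      "count": np.random.poisson(20, 730)})
+#   TemporalAnalyzer().invoke(ts, time_col="date", value_col="count")
+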
+class HypothesisTester(DataAnalyzer):
+    """Statistical hypothesis testing."""
+    def invoke(self, data: pd.DataFrame, group_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
+        try:
+            groups = data[group_col].unique()
+
+            if len(groups) < 2:
+                return {"error": "Insufficient groups for comparison"}
+
+            # Drop missing values so the test statistics are not silently NaN.
+            group_data = [data[data[group_col] == g][value_col].dropna() for g in groups]
+
+            if len(groups) == 2:
+                stat, p = ttest_ind(*group_data)
+                test_type = "Independent t-test"
+                effect_size = self.calculate_cohens_d(group_data[0], group_data[1])
+            else:
+                stat, p = f_oneway(*group_data)
+                test_type = "ANOVA"
+                effect_size = None
+
+            return {
+                "test_type": test_type,
+                "test_statistic": stat,
+                "p_value": p,
+                "effect_size": effect_size,
+                "interpretation": self.interpret_p_value(p)
+            }
+        except Exception as e:
+            logger.error(f"Hypothesis Testing Failed: {str(e)}")
+            return {"error": f"Hypothesis Testing Failed: {str(e)}"}
+
+    @staticmethod
+    def calculate_cohens_d(x: pd.Series, y: pd.Series) -> Optional[float]:
+        """Calculate Cohen's d for effect size."""
+        try:
+            mean_diff = abs(x.mean() - y.mean())
+            pooled_std = np.sqrt((x.var() + y.var()) / 2)
+            return mean_diff / pooled_std
+        except Exception as e:
+            logger.error(f"Error calculating Cohen's d: {str(e)}")
+            return None
+
+    @staticmethod
+    def interpret_p_value(p: float) -> str:
+        """Interpret the p-value."""
+        if p < 0.001:
+            return "Very strong evidence against H0"
+        elif p < 0.01:
+            return "Strong evidence against H0"
+        elif p < 0.05:
+            return "Evidence against H0"
+        elif p < 0.1:
+            return "Weak evidence against H0"
+        else:
+            return "No significant evidence against H0"
+
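+# Illustrative usage sketch (toy data; exactly two groups exercise the t-test branch):
+#
+#   demo = pd.DataFrame({"group": ["A"] * 5 + ["B"] * 5,
+#                        "value": [1, 2, 1, 3, 2, 4, 5, 4, 6, 5]})
+#   HypothesisTester().invoke(demo, group_col="group", value_col="value")
+#   # -> {"test_type": "Independent t-test", "p_value": ..., "effect_size": ...}
+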
+class LogisticRegressionTrainer(DataAnalyzer):
+    """Logistic Regression Model Trainer."""
+    def invoke(self, data: pd.DataFrame, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
+        try:
+            X = data[columns]
+            y = data[target_col]
+            X_train, X_test, y_train, y_test = train_test_split(
+                X, y, test_size=0.2, random_state=42
+            )
+            model = LogisticRegression(max_iter=1000)
+            model.fit(X_train, y_train)
+            y_pred = model.predict(X_test)
+            accuracy = accuracy_score(y_test, y_pred)
+            return {
+                "model_type": "Logistic Regression",
+                "accuracy": accuracy,
+                "model_params": model.get_params()
+            }
+        except Exception as e:
+            logger.error(f"Logistic Regression Model Error: {str(e)}")
+            return {"error": f"Logistic Regression Model Error: {str(e)}"}
+
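+# Illustrative usage sketch (hypothetical binary target; features must be numeric):
+#
+#   demo = pd.DataFrame({"bp": [120, 145, 160, 110, 150],
+#                        "age": [40, 62, 70, 35, 66],
+#                        "at_risk": [0, 1, 1, 0, 1]})
+#   LogisticRegressionTrainer().invoke(demo, target_col="at_risk", columns=["bp", "age"])
+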
+# ---------------------- Business Logic Layer ---------------------------
+
+class ClinicalRule(BaseModel):
+    """Defines a clinical rule."""
+    name: str
+    condition: str
+    action: str
+    severity: str  # low, medium, or high
+
+class ClinicalRulesEngine:
+    """Executes rules against patient data."""
+    def __init__(self):
+        self.rules: Dict[str, ClinicalRule] = {}
+
+    def add_rule(self, rule: ClinicalRule):
+        self.rules[rule.name] = rule
+
+    def execute_rules(self, data: pd.DataFrame) -> Dict[str, Any]:
+        results = {}
+        for rule_name, rule in self.rules.items():
+            try:
+                # Using safe_eval instead of eval for security
+                rule_matched = self.safe_eval(rule.condition, {"df": data})
+                results[rule_name] = {
+                    "rule_matched": rule_matched,
+                    "action": rule.action if rule_matched else None,
+                    "severity": rule.severity if rule_matched else None
+                }
+            except Exception as e:
+                logger.error(f"Error executing rule '{rule_name}': {str(e)}")
+                results[rule_name] = {
+                    "rule_matched": False,
+                    "error": str(e),
+                    "severity": None
+                }
+        return results
+
+    @staticmethod
+    def safe_eval(expr, variables):
+        """
+        Safely evaluate an expression using AST parsing.
+        Only allows certain node types to prevent execution of arbitrary code.
+        """
+        # Subscript/Attribute/keyword must be allowed, otherwise conditions such
+        # as df['blood_pressure'] > 140 (the form the UI help text suggests)
+        # are rejected as unsupported expressions.
+        allowed_nodes = (
+            ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
+            ast.Call, ast.Name, ast.Load, ast.Constant, ast.Num, ast.Str,
+            ast.List, ast.Tuple, ast.Dict, ast.Subscript, ast.Attribute,
+            ast.Index, ast.keyword
+        )
+        try:
+            node = ast.parse(expr, mode='eval')
+            for subnode in ast.walk(node):
+                if not isinstance(subnode, allowed_nodes):
+                    raise ValueError(f"Unsupported expression: {expr}")
+            return eval(compile(node, '<string>', mode='eval'), {"__builtins__": None}, variables)
+        except Exception as e:
+            logger.error(f"safe_eval error: {str(e)}")
+            raise ValueError(f"Invalid expression: {e}")
+
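+# Illustrative usage sketch (hypothetical rule; safe_eval exposes only the
+# DataFrame as `df`, with no builtins):
+#
+#   engine = ClinicalRulesEngine()
+#   engine.add_rule(ClinicalRule(name="high_bp",
+#                                condition="df['blood_pressure'].mean() > 140",
+#                                action="Flag for review", severity="high"))
+#   engine.execute_rules(pd.DataFrame({"blood_pressure": [150, 160]}))
+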
+class ClinicalKPI(BaseModel):
+    """Define a clinical KPI."""
+    name: str
+    calculation: str
+    threshold: Optional[float] = None
+
+class ClinicalKPIMonitoring:
+    """Calculates KPIs based on data."""
+    def __init__(self):
+        self.kpis: Dict[str, ClinicalKPI] = {}
+
+    def add_kpi(self, kpi: ClinicalKPI):
+        self.kpis[kpi.name] = kpi
+
+    def calculate_kpis(self, data: pd.DataFrame) -> Dict[str, Any]:
+        results = {}
+        for kpi_name, kpi in self.kpis.items():
+            try:
+                # Using safe_eval instead of eval for security
+                kpi_value = self.safe_eval(kpi.calculation, {"df": data})
+                status = self.evaluate_threshold(kpi_value, kpi.threshold)
+                results[kpi_name] = {
+                    "value": kpi_value,
+                    "threshold": kpi.threshold,
+                    "status": status
+                }
+            except Exception as e:
+                logger.error(f"Error calculating KPI '{kpi_name}': {str(e)}")
+                results[kpi_name] = {"error": str(e)}
+        return results
+
+    @staticmethod
+    def evaluate_threshold(value: Any, threshold: Optional[float]) -> Optional[str]:
+        if threshold is None:
+            return None
+        try:
+            return "Above Threshold" if value > threshold else "Below Threshold"
+        except TypeError:
+            return "Threshold Evaluation Not Applicable"
+
+    @staticmethod
+    def safe_eval(expr, variables):
+        """
+        Safely evaluate an expression using AST parsing.
+        Only allows certain node types to prevent execution of arbitrary code.
+        Mirrors ClinicalRulesEngine.safe_eval so KPI calculations such as
+        df['patient_count'].sum() parse correctly.
+        """
+        allowed_nodes = (
+            ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
+            ast.Call, ast.Name, ast.Load, ast.Constant, ast.Num, ast.Str,
+            ast.List, ast.Tuple, ast.Dict, ast.Subscript, ast.Attribute,
+            ast.Index, ast.keyword
+        )
+        try:
+            node = ast.parse(expr, mode='eval')
+            for subnode in ast.walk(node):
+                if not isinstance(subnode, allowed_nodes):
+                    raise ValueError(f"Unsupported expression: {expr}")
+            return eval(compile(node, '<string>', mode='eval'), {"__builtins__": None}, variables)
+        except Exception as e:
+            logger.error(f"safe_eval error: {str(e)}")
+            raise ValueError(f"Invalid expression: {e}")
+
+class DiagnosisSupport(ABC):
+    """Abstract class for implementing clinical diagnoses."""
+    @abstractmethod
+    def diagnose(
+        self,
+        data: pd.DataFrame,
+        target_col: str,
+        columns: List[str],
+        diagnosis_key: str = "diagnosis",
+        **kwargs
+    ) -> pd.DataFrame:
+        pass
+
+class SimpleDiagnosis(DiagnosisSupport):
+    """Provides a simple diagnosis example based on the logistic regression model."""
+    def __init__(self):
+        self.model_trainer: LogisticRegressionTrainer = LogisticRegressionTrainer()
+
+    def diagnose(
+        self,
+        data: pd.DataFrame,
+        target_col: str,
+        columns: List[str],
+        diagnosis_key: str = "diagnosis",
+        **kwargs
+    ) -> pd.DataFrame:
+        try:
+            result = self.model_trainer.invoke(data, target_col=target_col, columns=columns)
+            if "accuracy" in result:
+                return pd.DataFrame({
+                    diagnosis_key: [f"Model Accuracy: {result['accuracy']:.2%}"],
+                    "model": [result["model_type"]]
+                })
+            else:
+                return pd.DataFrame({
+                    diagnosis_key: [f"Diagnosis failed: {result.get('error', 'Unknown error')}"]
+                })
+        except Exception as e:
+            logger.error(f"Error during diagnosis: {str(e)}")
+            return pd.DataFrame({
+                diagnosis_key: [f"Error during diagnosis: {e}"]
+            })
+
+class TreatmentRecommendation(ABC):
+    """Abstract class for treatment recommendations."""
+    @abstractmethod
+    def recommend(
+        self,
+        data: pd.DataFrame,
+        condition_col: str,
+        treatment_col: str,
+        recommendation_key: str = "recommendation",
+        **kwargs
+    ) -> pd.DataFrame:
+        pass
+
+class BasicTreatmentRecommendation(TreatmentRecommendation):
+    """A placeholder class for basic treatment recommendations."""
+    def recommend(
+        self,
+        data: pd.DataFrame,
+        condition_col: str,
+        treatment_col: str,
+        recommendation_key: str = "recommendation",
+        **kwargs
+    ) -> pd.DataFrame:
+        if condition_col not in data.columns or treatment_col not in data.columns:
+            logger.warning(f"Condition or Treatment columns not found: {condition_col}, {treatment_col}")
+            return pd.DataFrame({
+                recommendation_key: ["Condition or Treatment columns not found!"]
+            })
+
+        treatment = data[data[condition_col] == "High"][treatment_col].to_list()
+        if treatment:
+            return pd.DataFrame({
+                recommendation_key: [f"Treatment recommended for High risk patients: {treatment}"]
+            })
+        else:
+            return pd.DataFrame({
+                recommendation_key: ["No treatment recommendation found!"]
+            })
+
+# ---------------------- Medical Knowledge Base ---------------------------
+
+class MedicalKnowledgeBase(ABC):
+    """Abstract class for Medical Knowledge."""
+    @abstractmethod
+    def search_medical_info(self, query: str, pub_email: str = "") -> str:
+        pass
+
+class SimpleMedicalKnowledge(MedicalKnowledgeBase):
+    """Enhanced Medical Knowledge Class using OpenAI GPT-4."""
+    def __init__(self, nlp_model):
+        self.nlp = nlp_model  # Using the loaded spaCy model
+
+    def search_medical_info(self, query: str, pub_email: str = "") -> str:
+        """
+        Uses OpenAI's GPT-4 to fetch medical information based on the user's query.
+        """
+        logger.info(f"Received medical query: {query}")
+        try:
+            # Preprocess the query (e.g., entity recognition)
+            doc = self.nlp(query.lower())
+            entities = [ent.text for ent in doc.ents]
+            processed_query = " ".join(entities) if entities else query.lower()
+
+            logger.info(f"Processed query: {processed_query}")
+
+            # Create a prompt for GPT-4
+            prompt = f"""
+            You are a medical assistant. Provide a comprehensive and accurate response to the following medical query:
+
+            Query: {processed_query}
+
+            Please ensure the information is clear, concise, and evidence-based.
+            """
+
+            # Make the API request to OpenAI GPT-4
+            response = openai.ChatCompletion.create(
+                model="gpt-4",
+                messages=[
+                    {"role": "system", "content": "You are a helpful medical assistant."},
+                    {"role": "user", "content": prompt}
+                ],
+                max_tokens=500,
+                temperature=0.7,
+            )
+
+            # Extract the answer from the response
+            answer = response.choices[0].message['content'].strip()
+
+            logger.info("Successfully retrieved data from OpenAI GPT-4.")
+
+            # Fetch PubMed abstract related to the query
+            pubmed_abstract = self.fetch_pubmed_abstract(processed_query, pub_email)
+
+            # Format the response
+            return f"**Based on your query:** {answer}\n\n**PubMed Abstract:**\n\n{pubmed_abstract}"
+
+        except RateLimitError as e:
+            logger.error(f"Rate Limit Exceeded: {str(e)}")
+            return "Rate limit exceeded. Please try again later."
+        except InvalidRequestError as e:
+            logger.error(f"Invalid Request: {str(e)}")
+            return f"Invalid request: {str(e)}"
+        except APIError as e:
+            logger.error(f"OpenAI API Error: {str(e)}")
+            return f"OpenAI API Error: {str(e)}"
+        except Exception as e:
+            logger.error(f"Medical Knowledge Search Failed: {str(e)}")
+            return f"Medical Knowledge Search Failed: {str(e)}"
+
+    def fetch_pubmed_abstract(self, query: str, email: str) -> str:
+        """
+        Searches PubMed for abstracts related to the query.
+        """
+        try:
+            if not email:
+                logger.warning("PubMed abstract retrieval skipped: Email not provided.")
+                return "No PubMed abstract available: Email not provided."
+
+            Entrez.email = email
+            handle = Entrez.esearch(db="pubmed", term=query, retmax=1, sort='relevance')
+            record = Entrez.read(handle)
+            handle.close()
+            logger.info(f"PubMed search for query '{query}' returned IDs: {record['IdList']}")
+
+            if record["IdList"]:
+                handle = Entrez.efetch(db="pubmed", id=record["IdList"][0], rettype="abstract", retmode="text")
+                abstract = handle.read()
+                handle.close()
+                logger.info(f"Fetched PubMed abstract for ID {record['IdList'][0]}")
+                return abstract
+            else:
+                logger.info(f"No PubMed abstracts found for query '{query}'.")
+                return "No abstracts found for this query on PubMed."
+        except Exception as e:
+            logger.error(f"Error searching PubMed: {e}")
+            return f"Error searching PubMed: {e}"
+
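+# Illustrative sketch of the Entrez flow above, assuming a valid contact email
+# (NCBI requires one for E-utilities requests; the address here is hypothetical):
+#
+#   kb = SimpleMedicalKnowledge(nlp_model=nlp)
+#   kb.fetch_pubmed_abstract("hypertension treatment", "researcher@example.org")
+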
+# ---------------------- Forecasting Engine ---------------------------
+
+class ForecastingEngine(ABC):
+    """Abstract class for forecasting."""
+    @abstractmethod
+    def predict(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
+        pass
+
+class SimpleForecasting(ForecastingEngine):
+    """Simple forecasting engine."""
+    def predict(self, data: pd.DataFrame, period: int = 7, **kwargs) -> pd.DataFrame:
+        # Placeholder for actual forecasting logic
+        return pd.DataFrame({"forecast": [f"Forecast for the next {period} days"]})
+
+# ---------------------- Insights and Reporting Layer ---------------------------
+
+class AutomatedInsights:
+    """Generates automated insights based on selected analyses."""
+    def __init__(self):
+        self.analyses: Dict[str, DataAnalyzer] = {
+            "EDA": AdvancedEDA(),
+            "temporal": TemporalAnalyzer(),
+            "distribution": DistributionVisualizer(),
+            "hypothesis": HypothesisTester(),
+            "model": LogisticRegressionTrainer()
+        }
+
+    def generate_insights(self, data: pd.DataFrame, analysis_names: List[str], **kwargs) -> Dict[str, Any]:
+        results = {}
+        for name in analysis_names:
+            analyzer = self.analyses.get(name)
+            if analyzer:
+                try:
+                    results[name] = analyzer.invoke(data=data, **kwargs)
+                except Exception as e:
+                    logger.error(f"Error in analysis '{name}': {str(e)}")
+                    results[name] = {"error": str(e)}
+            else:
+                logger.warning(f"Analysis '{name}' not found.")
+                results[name] = {"error": "Analysis not found"}
+        return results
+
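+# Illustrative usage sketch; analyses that need extra arguments (e.g. "temporal"
+# requires time_col/value_col) receive them through **kwargs:
+#
+#   AutomatedInsights().generate_insights(demo, analysis_names=["EDA"])
+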
+class Dashboard:
+    """Handles the creation and display of the dashboard."""
+    def __init__(self):
+        self.layout: Dict[str, str] = {}
+
+    def add_visualisation(self, vis_name: str, vis_type: str):
+        self.layout[vis_name] = vis_type
+
+    def display_dashboard(self, data_dict: Dict[str, pd.DataFrame]):
+        st.header("Dashboard")
+        for vis_name, vis_type in self.layout.items():
+            st.subheader(vis_name)
+            df = data_dict.get(vis_name)
+            if df is not None:
+                if vis_type == "table":
+                    st.table(df)
+                elif vis_type == "plot":
+                    if len(df.columns) > 1:
+                        fig = plt.figure()
+                        sns.lineplot(data=df)
+                        st.pyplot(fig)
+                    else:
+                        st.write("Please select a DataFrame with more than 1 column for plotting.")
+            else:
+                st.write("Data Not Found")
+
+class AutomatedReports:
+    """Manages automated report definitions and generation."""
+    def __init__(self):
+        self.report_definitions: Dict[str, str] = {}
+
+    def create_report_definition(self, report_name: str, definition: str):
+        self.report_definitions[report_name] = definition
+
+    def generate_report(self, report_name: str, data: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
+        if report_name not in self.report_definitions:
+            return {"error": "Report name not found"}
+        report_content = {
+            "Report Name": report_name,
+            "Report Definition": self.report_definitions[report_name],
+            "Data": {df_name: df.to_dict() for df_name, df in data.items()}
+        }
+        return report_content
+
+# ---------------------- Data Acquisition Layer ---------------------------
+
+class DataSource(ABC):
+    """Base class for data sources."""
+    @abstractmethod
+    def connect(self) -> None:
+        """Connect to the data source."""
+        pass
+
+    @abstractmethod
+    def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
+        """Fetch the data based on a specific query."""
+        pass
+
+class CSVDataSource(DataSource):
+    """Data source for CSV files."""
+    def __init__(self, file_path: io.BytesIO):
+        self.file_path = file_path
+        self.data: Optional[pd.DataFrame] = None
+
+    def connect(self):
+        self.data = pd.read_csv(self.file_path)
+
+    def fetch_data(self, query: str = None, **kwargs) -> pd.DataFrame:
+        if self.data is None:
+            raise Exception("Not connected. Call connect() first.")
+        return self.data
+
+class DatabaseSource(DataSource):
+    """Data source for SQL Databases."""
+    def __init__(self, connection_string: str, database_type: str):
+        self.connection_string = connection_string
+        self.database_type = database_type.lower()
+        self.connection = None
+
+    def connect(self):
+        if self.database_type == "sql":
+            # Placeholder for actual SQL connection logic
+            self.connection = "Connected to SQL Database"
+        else:
+            raise Exception(f"Database type '{self.database_type}' is not supported.")
+
+    def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
+        if self.connection is None:
+            raise Exception("Not connected. Call connect() first.")
+        # Placeholder for data fetching logic
+        return pd.DataFrame({"result": [f"Fetched data based on query: {query}"]})
+
+class DataIngestion:
+    """Handles data ingestion from various sources."""
+    def __init__(self):
+        self.sources: Dict[str, DataSource] = {}
+
+    def add_source(self, source_name: str, source: DataSource):
+        self.sources[source_name] = source
+
+    def ingest_data(self, source_name: str, query: str = None, **kwargs) -> pd.DataFrame:
+        if source_name not in self.sources:
+            raise Exception(f"Source '{source_name}' not found.")
+        source = self.sources[source_name]
+        source.connect()
+        return source.fetch_data(query, **kwargs)
+
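+# Illustrative wiring sketch (hypothetical source name and file):
+#
+#   ingestion = DataIngestion()
+#   with open("trial.csv", "rb") as fh:
+#       ingestion.add_source("trial_csv", CSVDataSource(file_path=fh))
+#       df = ingestion.ingest_data("trial_csv")
+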
+class DataModel(BaseModel):
+    """Defines a data model."""
+    name: str
+    kpis: List[str] = Field(default_factory=list)
+    dimensions: List[str] = Field(default_factory=list)
+    custom_calculations: Optional[Dict[str, str]] = None
+    relations: Optional[Dict[str, str]] = None  # Example: {"table1": "table2"}
+
+    def to_json(self) -> str:
+        return json.dumps(self.dict())
+
+    @staticmethod
+    def from_json(json_str: str) -> 'DataModel':
+        return DataModel(**json.loads(json_str))
+
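+# Illustrative round-trip sketch for DataModel serialization:
+#
+#   model = DataModel(name="cardiology", kpis=["total_patients"], dimensions=["ward"])
+#   DataModel.from_json(model.to_json()).name  # -> "cardiology"
+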
+class DataModelling:
+    """Manages data models."""
+    def __init__(self):
+        self.models: Dict[str, DataModel] = {}
+
+    def add_model(self, model: DataModel):
+        self.models[model.name] = model
+
+    def get_model(self, model_name: str) -> DataModel:
+        if model_name not in self.models:
+            raise Exception(f"Model '{model_name}' not found.")
+        return self.models[model_name]
+
+# ---------------------- Main Streamlit Application ---------------------------
+
+def main():
+    """Main function to run the Streamlit app."""
+    st.title("🏥 AI-Powered Clinical Intelligence Hub")
+
+    # Initialize Session State
+    initialize_session_state()
+
+    # Sidebar for Data Management
+    with st.sidebar:
+        data_management_section()
+
+    # Main Content
+    if st.session_state.data:
+        col1, col2 = st.columns([1, 3])
+
+        with col1:
+            dataset_metadata_section()
+
+        with col2:
+            main_tabs_section()
+
+def initialize_session_state():
+    """Initialize necessary components in Streamlit's session state."""
+    if 'data' not in st.session_state:
+        st.session_state.data = {}  # Store pd.DataFrame under a name
+    if 'data_ingestion' not in st.session_state:
+        st.session_state.data_ingestion = DataIngestion()
+    if 'data_modelling' not in st.session_state:
+        st.session_state.data_modelling = DataModelling()
+    if 'clinical_rules' not in st.session_state:
+        st.session_state.clinical_rules = ClinicalRulesEngine()
+    if 'kpi_monitoring' not in st.session_state:
+        st.session_state.kpi_monitoring = ClinicalKPIMonitoring()
+    if 'forecasting_engine' not in st.session_state:
+        st.session_state.forecasting_engine = SimpleForecasting()
+    if 'automated_insights' not in st.session_state:
+        st.session_state.automated_insights = AutomatedInsights()
+    if 'dashboard' not in st.session_state:
+        st.session_state.dashboard = Dashboard()
+    if 'automated_reports' not in st.session_state:
+        st.session_state.automated_reports = AutomatedReports()
+    if 'diagnosis_support' not in st.session_state:
+        st.session_state.diagnosis_support = SimpleDiagnosis()
+    if 'treatment_recommendation' not in st.session_state:
+        st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
+    if 'knowledge_base' not in st.session_state:
+        st.session_state.knowledge_base = SimpleMedicalKnowledge(nlp_model=nlp)
+    if 'pub_email' not in st.session_state:
+        st.session_state.pub_email = PUB_EMAIL  # Load PUB_EMAIL from environment variables
+
+def data_management_section():
+    """Handles the data management section in the sidebar."""
+    st.header("⚙️ Data Management")
+    data_source_selection = st.selectbox("Select Data Source Type", ["CSV", "SQL Database"])
+
+    if data_source_selection == "CSV":
+        handle_csv_upload()
+    elif data_source_selection == "SQL Database":
+        handle_sql_database()
+
+    # Render the ingestion controls unconditionally; the "Ingest Data" button
+    # lives inside ingest_data_action so its selectbox persists across reruns
+    # instead of disappearing after a single button press.
+    ingest_data_action()
+
+def handle_csv_upload():
+    """Handles CSV file uploads."""
+    uploaded_file = st.file_uploader("Upload research dataset (CSV)", type=["csv"])
+    if uploaded_file:
+        source_name = st.text_input("Data Source Name")
+        if source_name:
+            try:
+                csv_source = CSVDataSource(file_path=uploaded_file)
+                st.session_state.data_ingestion.add_source(source_name, csv_source)
+                st.success(f"Uploaded {uploaded_file.name} as '{source_name}'.")
+            except Exception as e:
+                st.error(f"Error loading dataset: {e}")
+
+def handle_sql_database():
+    """Handles SQL database connections."""
+    conn_str = st.text_input("Enter connection string for SQL DB")
+    if conn_str:
+        source_name = st.text_input("Data Source Name")
+        if source_name:
+            try:
+                sql_source = DatabaseSource(connection_string=conn_str, database_type="sql")
+                st.session_state.data_ingestion.add_source(source_name, sql_source)
+                st.success(f"Added SQL DB Source '{source_name}'.")
+            except Exception as e:
+                st.error(f"Error loading database source: {e}")
+
+def ingest_data_action():
+    """Performs data ingestion from the selected source."""
+    if st.session_state.data_ingestion.sources:
+        source_name_to_fetch = st.selectbox("Select Data Source to Ingest", list(st.session_state.data_ingestion.sources.keys()))
+        query = st.text_area("Optional Query to Fetch data")
+        if st.button("Ingest Data") and source_name_to_fetch:
+            with st.spinner("Ingesting data..."):
+                try:
+                    data = st.session_state.data_ingestion.ingest_data(source_name_to_fetch, query)
+                    st.session_state.data[source_name_to_fetch] = data
+                    st.success(f"Ingested data from '{source_name_to_fetch}'.")
+                except Exception as e:
+                    st.error(f"Ingestion failed: {e}")
+    else:
+        # Shown whenever no source exists yet (informational, not an error).
+        st.info("No data source added. Please add a data source.")
+
+def dataset_metadata_section():
+    """Displays metadata for the selected dataset."""
+    st.subheader("Dataset Metadata")
+    data_source_keys = list(st.session_state.data.keys())
+    selected_data_key = st.selectbox("Select Dataset", data_source_keys)
+
+    if selected_data_key:
+        data = st.session_state.data[selected_data_key]
+        metadata = {
+            "Variables": list(data.columns),
+            "Time Range": {
+                col: {
+                    "min": data[col].min(),
+                    "max": data[col].max()
+                } for col in data.select_dtypes(include='datetime').columns
+            },
+            "Size": f"{data.memory_usage().sum() / 1e6:.2f} MB"
+        }
+        st.json(metadata)
+        # Store the selected dataset key in session state for use in analysis
+        st.session_state.selected_data_key = selected_data_key
+
+def main_tabs_section():
+    """Creates and manages the main tabs in the application."""
+    analysis_tab, clinical_logic_tab, insights_tab, reports_tab, knowledge_tab = st.tabs([
+        "Data Analysis",
+        "Clinical Logic",
+        "Insights",
+        "Reports",
+        "Medical Knowledge"
+    ])
+
+    with analysis_tab:
+        data_analysis_section()
+
+    with clinical_logic_tab:
+        clinical_logic_section()
+
+    with insights_tab:
+        insights_section()
+
+    with reports_tab:
+        reports_section()
+
+    with knowledge_tab:
+        medical_knowledge_section()
+
+def data_analysis_section():
+    """Handles the Data Analysis tab."""
+    selected_data_key = st.session_state.get('selected_data_key', None)
+    if not selected_data_key:
+        st.warning("Please select a dataset from the metadata section.")
+        return
+
+    data = st.session_state.data[selected_data_key]
+    analysis_type = st.selectbox("Select Analysis Mode", [
+        "Exploratory Data Analysis",
+        "Temporal Pattern Analysis",
+        "Comparative Statistics",
+        "Distribution Analysis",
+        "Train Logistic Regression Model"
+    ])
+
+    if analysis_type == "Exploratory Data Analysis":
+        perform_eda(data)
+    elif analysis_type == "Temporal Pattern Analysis":
+        perform_temporal_analysis(data)
+    elif analysis_type == "Comparative Statistics":
+        perform_comparative_statistics(data)
+    elif analysis_type == "Distribution Analysis":
+        perform_distribution_analysis(data)
+    elif analysis_type == "Train Logistic Regression Model":
+        perform_logistic_regression_training(data)
+
+def perform_eda(data: pd.DataFrame):
+    """Performs Exploratory Data Analysis."""
+    analyzer = AdvancedEDA()
+    eda_result = analyzer.invoke(data=data)
+    st.subheader("Data Quality Report")
+    st.json(eda_result)
+
+def perform_temporal_analysis(data: pd.DataFrame):
+    """Performs Temporal Pattern Analysis."""
+    time_cols = data.select_dtypes(include='datetime').columns
+    num_cols = data.select_dtypes(include=np.number).columns
+
+    if len(time_cols) == 0:
+        st.warning("No datetime columns available for temporal analysis.")
+        return
+    if len(num_cols) == 0:
+        st.warning("No numerical columns available for temporal analysis.")
+        return
+
+    time_col = st.selectbox("Select Temporal Variable", time_cols)
+    value_col = st.selectbox("Select Analysis Variable", num_cols)
+
+    if time_col and value_col:
+        analyzer = TemporalAnalyzer()
+        result = analyzer.invoke(data=data, time_col=time_col, value_col=value_col)
+        if "visualization" in result and result["visualization"]:
+            st.image(f"data:image/png;base64,{result['visualization']}", use_column_width=True)
+        st.json(result)
+
+def perform_comparative_statistics(data: pd.DataFrame):
+    """Performs Comparative Statistics."""
+    categorical_cols = data.select_dtypes(include=['category', 'object']).columns
+    numeric_cols = data.select_dtypes(include=np.number).columns
+
+    if len(categorical_cols) == 0:
+        st.warning("No categorical columns available for hypothesis testing.")
+        return
+
+    if len(numeric_cols) == 0:
+        st.warning("No numerical columns available for hypothesis testing.")
+        return
+
+    group_col = st.selectbox("Select Grouping Variable", categorical_cols)
+    value_col = st.selectbox("Select Metric Variable", numeric_cols)
+
+    if group_col and value_col:
+        analyzer = HypothesisTester()
+        result = analyzer.invoke(data=data, group_col=group_col, value_col=value_col)
+        st.subheader("Statistical Test Results")
+        st.json(result)
+
+def perform_distribution_analysis(data: pd.DataFrame):
+    """Performs Distribution Analysis."""
+    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
+    selected_cols = st.multiselect("Select Variables for Distribution Analysis", numeric_cols)
+
+    if selected_cols:
+        analyzer = DistributionVisualizer()
+        img_data = analyzer.invoke(data=data, columns=selected_cols)
+        if not img_data.startswith("Visualization Error"):
+            st.image(f"data:image/png;base64,{img_data}", use_column_width=True)
+        else:
+            st.error(img_data)
+    else:
+        st.info("Please select at least one numerical column to visualize.")
+
+def perform_logistic_regression_training(data: pd.DataFrame):
+    """Trains a Logistic Regression model."""
+    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
+    target_col = st.selectbox("Select Target Variable", data.columns.tolist())
+    selected_cols = st.multiselect("Select Feature Variables", numeric_cols)
+
+    if selected_cols and target_col:
+        analyzer = LogisticRegressionTrainer()
+        result = analyzer.invoke(data=data, target_col=target_col, columns=selected_cols)
+        st.subheader("Logistic Regression Model Results")
+        st.json(result)
+    else:
+        st.warning("Please select both target and feature variables for model training.")
+
+def clinical_logic_section():
+    """Handles the Clinical Logic tab."""
+    st.header("Clinical Logic")
+
+    # Clinical Rules Management
+    st.subheader("Clinical Rules")
+    rule_name = st.text_input("Enter Rule Name")
+    condition = st.text_area("Enter Rule Condition (use 'df' for DataFrame)",
+                             help="Example: df['blood_pressure'] > 140")
+    action = st.text_area("Enter Action to be Taken on Rule Match")
+    severity = st.selectbox("Enter Severity for the Rule", ["low", "medium", "high"])
+
+    if st.button("Add Clinical Rule"):
+        if rule_name and condition and action and severity:
+            try:
+                rule = ClinicalRule(
+                    name=rule_name,
+                    condition=condition,
+                    action=action,
+                    severity=severity
+                )
+                st.session_state.clinical_rules.add_rule(rule)
+                st.success("Added Clinical Rule successfully.")
+            except Exception as e:
+                st.error(f"Error in rule definition: {e}")
+        else:
+            st.error("Please fill in all fields to add a clinical rule.")
+
+    # Clinical KPI Management
+    st.subheader("Clinical KPI Definition")
+    kpi_name = st.text_input("Enter KPI Name")
+    kpi_calculation = st.text_area("Enter KPI Calculation (use 'df' for DataFrame)",
+                                   help="Example: df['patient_count'].sum()")
+    threshold = st.text_input("Enter Threshold for KPI (Optional)", help="Leave blank if not applicable")
+
+    if st.button("Add Clinical KPI"):
+        if kpi_name and kpi_calculation:
+            try:
+                threshold_value = float(threshold) if threshold else None
+                kpi = ClinicalKPI(
+                    name=kpi_name,
+                    calculation=kpi_calculation,
+                    threshold=threshold_value
+                )
+                st.session_state.kpi_monitoring.add_kpi(kpi)
+                st.success(f"Added KPI '{kpi_name}' successfully.")
+            except ValueError:
+                st.error("Threshold must be a numeric value.")
+            except Exception as e:
+                st.error(f"Error creating KPI: {e}")
+        else:
+            st.error("Please provide both KPI name and calculation.")
+
+    # Execute Clinical Rules and Calculate KPIs
+    selected_data_key = st.selectbox("Select Dataset for Clinical Logic", list(st.session_state.data.keys()))
+    if selected_data_key:
+        data = st.session_state.data[selected_data_key]
+        if st.button("Execute Clinical Rules"):
+            with st.spinner("Executing Clinical Rules..."):
+                result = st.session_state.clinical_rules.execute_rules(data)
+                st.json(result)
+        if st.button("Calculate Clinical KPIs"):
+            with st.spinner("Calculating Clinical KPIs..."):
+                result = st.session_state.kpi_monitoring.calculate_kpis(data)
+                st.json(result)
+    else:
+        st.warning("Please ingest data to execute clinical rules and calculate KPIs.")
+
+def insights_section():
+    """Handles the Insights tab."""
+    st.header("Automated Insights")
+
+    selected_data_key = st.selectbox("Select Dataset for Insights", list(st.session_state.data.keys()))
+    if not selected_data_key:
+        st.warning("Please select a dataset to generate insights.")
+        return
+
+    data = st.session_state.data[selected_data_key]
+    available_analyses = ["EDA", "temporal", "distribution", "hypothesis", "model"]
+    selected_analyses = st.multiselect("Select Analyses for Insights", available_analyses)
+
+    if st.button("Generate Automated Insights"):
+        if selected_analyses:
+            with st.spinner("Generating Insights..."):
+                results = st.session_state.automated_insights.generate_insights(
+                    data, analysis_names=selected_analyses
+                )
+                st.json(results)
+        else:
+            st.warning("Please select at least one analysis to generate insights.")
+
+    # Diagnosis Support
+    st.subheader("Diagnosis Support")
+    target_col = st.selectbox("Select Target Variable for Diagnosis", data.columns.tolist())
+    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
+    selected_feature_cols = st.multiselect("Select Feature Variables for Diagnosis", numeric_cols)
+
+    if st.button("Generate Diagnosis"):
+        if target_col and selected_feature_cols:
+            with st.spinner("Generating Diagnosis..."):
+                result = st.session_state.diagnosis_support.diagnose(
+                    data, target_col=target_col, columns=selected_feature_cols, diagnosis_key="diagnosis_result"
+                )
+                # diagnose() returns a DataFrame, so render it as a table rather than JSON.
+                st.dataframe(result)
+        else:
+            st.error("Please select both target and feature variables for diagnosis.")
+
+    # Treatment Recommendation
+    st.subheader("Treatment Recommendation")
+    condition_col = st.selectbox("Select Condition Column for Treatment Recommendation", data.columns.tolist())
+    treatment_col = st.selectbox("Select Treatment Column for Treatment Recommendation", data.columns.tolist())
+
+    if st.button("Generate Treatment Recommendation"):
+        if condition_col and treatment_col:
+            with st.spinner("Generating Treatment Recommendation..."):
+                result = st.session_state.treatment_recommendation.recommend(
+                    data, condition_col=condition_col, treatment_col=treatment_col, recommendation_key="treatment_recommendation"
+                )
+                # recommend() also returns a DataFrame.
+                st.dataframe(result)
+        else:
+            st.error("Please select both condition and treatment columns.")
+
+def reports_section():
+    """Handles the Reports tab."""
+    st.header("Automated Reports")
+
+    # Create Report Definition
+    st.subheader("Create Report Definition")
+    report_name = st.text_input("Report Name")
+    report_def = st.text_area("Report Definition", help="Describe the structure and content of the report.")
+
+    if st.button("Create Report Definition"):
+        if report_name and report_def:
+            st.session_state.automated_reports.create_report_definition(report_name, report_def)
+            st.success("Report definition created successfully.")
+        else:
+            st.error("Please provide both report name and definition.")
+
+    # Generate Report
+    st.subheader("Generate Report")
+    report_names = list(st.session_state.automated_reports.report_definitions.keys())
+    if report_names:
+        report_name_to_generate = st.selectbox("Select Report to Generate", report_names)
+        if st.button("Generate Report"):
+            with st.spinner("Generating Report..."):
+                report = st.session_state.automated_reports.generate_report(report_name_to_generate, st.session_state.data)
+                if "error" not in report:
+                    st.header(f"Report: {report['Report Name']}")
+                    st.markdown(f"**Definition:** {report['Report Definition']}")
+                    for df_name, df_content in report["Data"].items():
+                        st.subheader(f"Data: {df_name}")
+                        st.dataframe(pd.DataFrame(df_content))
+                else:
+                    st.error(report["error"])
+    else:
+        st.info("No report definitions found. Please create a report definition first.")
+
+def medical_knowledge_section():
+    """Handles the Medical Knowledge tab."""
+    st.header("Medical Knowledge")
+    query = st.text_input("Enter your medical question here:")
+
+    if st.button("Search"):
+        if query.strip():
+            with st.spinner("Searching..."):
+                result = st.session_state.knowledge_base.search_medical_info(
+                    query, pub_email=st.session_state.pub_email
+                )
+                st.markdown(result)
+        else:
+            st.error("Please enter a medical question to search.")
+
+if __name__ == "__main__":
+    main()