Update app.py
Browse files
app.py
CHANGED
@@ -25,8 +25,12 @@ from statsmodels.tsa.stattools import adfuller
|
|
25 |
from pydantic import BaseModel, Field
|
26 |
from Bio import Entrez # Ensure BioPython is installed
|
27 |
|
28 |
-
from
|
29 |
-
|
|
|
|
|
|
|
|
|
30 |
|
31 |
# ---------------------- Streamlit Page Configuration ---------------------------
|
32 |
# This must be the first Streamlit command in the script
|
@@ -34,7 +38,18 @@ st.set_page_config(page_title="AI Clinical Intelligence Hub", layout="wide")
|
|
34 |
|
35 |
# ---------------------- Initialize External Clients ---------------------------
|
36 |
# Initialize Groq Client with API Key from environment variables
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
# Load spaCy model with error handling
|
40 |
try:
|
@@ -250,9 +265,8 @@ class ClinicalRulesEngine:
|
|
250 |
results = {}
|
251 |
for rule_name, rule in self.rules.items():
|
252 |
try:
|
253 |
-
#
|
254 |
-
|
255 |
-
rule_matched = eval(rule.condition, {"__builtins__": None}, {"df": data})
|
256 |
results[rule_name] = {
|
257 |
"rule_matched": rule_matched,
|
258 |
"action": rule.action if rule_matched else None,
|
@@ -266,6 +280,26 @@ class ClinicalRulesEngine:
|
|
266 |
}
|
267 |
return results
|
268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
class ClinicalKPI(BaseModel):
|
270 |
"""Define a clinical KPI."""
|
271 |
name: str
|
@@ -284,8 +318,8 @@ class ClinicalKPIMonitoring:
|
|
284 |
results = {}
|
285 |
for kpi_name, kpi in self.kpis.items():
|
286 |
try:
|
287 |
-
#
|
288 |
-
kpi_value =
|
289 |
status = self.evaluate_threshold(kpi_value, kpi.threshold)
|
290 |
results[kpi_name] = {
|
291 |
"value": kpi_value,
|
@@ -305,6 +339,26 @@ class ClinicalKPIMonitoring:
|
|
305 |
except TypeError:
|
306 |
return "Threshold Evaluation Not Applicable"
|
307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
class DiagnosisSupport(ABC):
|
309 |
"""Abstract class for implementing clinical diagnoses."""
|
310 |
@abstractmethod
|
@@ -394,21 +448,67 @@ class MedicalKnowledgeBase(ABC):
|
|
394 |
pass
|
395 |
|
396 |
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
397 |
-
"""
|
398 |
def __init__(self):
|
399 |
-
self.
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
"
|
405 |
-
"
|
406 |
}
|
407 |
-
|
408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
|
410 |
-
|
411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
try:
|
413 |
Entrez.email = email
|
414 |
handle = Entrez.esearch(db="pubmed", term=query, retmax=1, sort='relevance')
|
@@ -424,68 +524,6 @@ class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
|
424 |
except Exception as e:
|
425 |
return f"Error searching PubMed: {e}"
|
426 |
|
427 |
-
def search_medical_info(self, query: str, pub_email: str = "") -> str:
|
428 |
-
"""Search the medical knowledge base and PubMed for relevant information."""
|
429 |
-
try:
|
430 |
-
query_lower = query.lower()
|
431 |
-
doc = nlp(query_lower)
|
432 |
-
entities = [ent.text for ent in doc.ents]
|
433 |
-
|
434 |
-
if entities:
|
435 |
-
best_match_keyword = ""
|
436 |
-
best_match_score = -1
|
437 |
-
for entity in entities:
|
438 |
-
query_vector = self.vectorizer.transform([entity])
|
439 |
-
similarities = cosine_similarity(query_vector, self.tfidf_matrix)
|
440 |
-
current_best_match_index = np.argmax(similarities)
|
441 |
-
current_best_score = np.max(similarities)
|
442 |
-
if current_best_score > best_match_score:
|
443 |
-
best_match_keyword = list(self.knowledge_base.keys())[current_best_match_index]
|
444 |
-
best_match_score = current_best_score
|
445 |
-
else:
|
446 |
-
query_vector = self.vectorizer.transform([query_lower])
|
447 |
-
similarities = cosine_similarity(query_vector, self.tfidf_matrix)
|
448 |
-
best_match_index = np.argmax(similarities)
|
449 |
-
best_match_keyword = list(self.knowledge_base.keys())[best_match_index]
|
450 |
-
|
451 |
-
best_match_info = self.knowledge_base.get(best_match_keyword, "No specific information is available based on the query provided.")
|
452 |
-
|
453 |
-
# Enhanced PubMed Search: Combine query and best_match_keyword for better relevance
|
454 |
-
pubmed_query = f"{query_lower} AND {best_match_keyword}"
|
455 |
-
pubmed_result = self.search_pubmed(pubmed_query, pub_email)
|
456 |
-
|
457 |
-
feedback_key = f"feedback_{query_lower}" # Creating a unique key for feedback
|
458 |
-
|
459 |
-
response = f"**Based on your query:** {best_match_info}\n\n"
|
460 |
-
|
461 |
-
if "Error searching PubMed" not in pubmed_result and "No abstracts found" not in pubmed_result:
|
462 |
-
# Format the PubMed abstract with proper markdown
|
463 |
-
abstract_title = pubmed_result.split('\n')[0] # Assuming the first line is the title
|
464 |
-
abstract_body = '\n'.join(pubmed_result.split('\n')[2:]) # Skipping authors and affiliations
|
465 |
-
response += f"**PubMed Abstract:**\n\n**{abstract_title}**\n\n{abstract_body}"
|
466 |
-
else:
|
467 |
-
response += f"{pubmed_result}"
|
468 |
-
|
469 |
-
# Initialize feedback in session state
|
470 |
-
if feedback_key not in st.session_state:
|
471 |
-
st.session_state[feedback_key] = {"feedback": None}
|
472 |
-
|
473 |
-
# Display feedback buttons only if a valid response is generated
|
474 |
-
if "Error searching PubMed" not in pubmed_result:
|
475 |
-
col1, col2 = st.columns([1, 1])
|
476 |
-
with col1:
|
477 |
-
if st.button("Good Result", key=f"good_{feedback_key}"):
|
478 |
-
st.session_state[feedback_key]["feedback"] = "positive"
|
479 |
-
st.success("Thank you for the feedback!")
|
480 |
-
with col2:
|
481 |
-
if st.button("Bad Result", key=f"bad_{feedback_key}"):
|
482 |
-
st.session_state[feedback_key]["feedback"] = "negative"
|
483 |
-
st.error("Thank you for the feedback!")
|
484 |
-
|
485 |
-
return response
|
486 |
-
except Exception as e:
|
487 |
-
return f"Medical Knowledge Search Failed: {e}"
|
488 |
-
|
489 |
# ---------------------- Forecasting Engine ---------------------------
|
490 |
|
491 |
class ForecastingEngine(ABC):
|
@@ -711,7 +749,7 @@ def initialize_session_state():
|
|
711 |
if 'knowledge_base' not in st.session_state:
|
712 |
st.session_state.knowledge_base = SimpleMedicalKnowledge()
|
713 |
if 'pub_email' not in st.session_state:
|
714 |
-
st.session_state.pub_email =
|
715 |
|
716 |
def data_management_section():
|
717 |
"""Handles the data management section in the sidebar."""
|
|
|
25 |
from pydantic import BaseModel, Field
|
26 |
from Bio import Entrez # Ensure BioPython is installed
|
27 |
|
28 |
+
from dotenv import load_dotenv
|
29 |
+
import requests
|
30 |
+
import ast
|
31 |
+
|
32 |
+
# ---------------------- Load Environment Variables ---------------------------
|
33 |
+
load_dotenv()
|
34 |
|
35 |
# ---------------------- Streamlit Page Configuration ---------------------------
|
36 |
# This must be the first Streamlit command in the script
|
|
|
38 |
|
39 |
# ---------------------- Initialize External Clients ---------------------------
|
40 |
# Initialize Groq Client with API Key from environment variables
|
41 |
+
# Groq connection settings and the PubMed contact address are read from the
# process environment; PUB_EMAIL falls back to an empty string when unset.
GROQ_API_ENDPOINT = os.getenv("GROQ_API_ENDPOINT")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
PUB_EMAIL = os.getenv("PUB_EMAIL", "")

# Both Groq settings are mandatory: report the problem in the UI and call
# st.stop() when either one is missing or empty.
if not (GROQ_API_ENDPOINT and GROQ_API_KEY):
    st.error("Groq API endpoint and key must be set as environment variables.")
    st.stop()

# Request headers built once from the key read above.
headers = {
    "Authorization": f"Bearer {GROQ_API_KEY}",
    "Content-Type": "application/json"
}
|
53 |
|
54 |
# Load spaCy model with error handling
|
55 |
try:
|
|
|
265 |
results = {}
|
266 |
for rule_name, rule in self.rules.items():
|
267 |
try:
|
268 |
+
# Using safe_eval instead of eval for security
|
269 |
+
rule_matched = self.safe_eval(rule.condition, {"df": data})
|
|
|
270 |
results[rule_name] = {
|
271 |
"rule_matched": rule_matched,
|
272 |
"action": rule.action if rule_matched else None,
|
|
|
280 |
}
|
281 |
return results
|
282 |
|
283 |
+
@staticmethod
|
284 |
+
def safe_eval(expr, variables):
    """
    Evaluate one Python expression after checking its AST against a whitelist.

    Args:
        expr: Source text of a single expression, e.g. ``"df['age'] > 65"``.
        variables: Mapping of the names the expression is allowed to reference.

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: If the text does not parse, contains a node type outside
            the whitelist, references an underscore-prefixed name or
            attribute, or fails while being evaluated.

    NOTE(review): this is a guard, not a sandbox. Attribute access and calls
    are permitted so that expressions can use the objects in ``variables``;
    any method those objects expose is therefore reachable. Expressions must
    still come from a trusted source.
    """
    allowed_nodes = (
        ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
        ast.Call, ast.keyword, ast.Name, ast.Load, ast.Constant,
        ast.List, ast.Tuple, ast.Dict,
        # Needed to index into and call methods on the supplied variables.
        ast.Attribute, ast.Subscript, ast.Slice,
        # ast.walk() also yields the operator nodes themselves (ast.Gt,
        # ast.Add, ast.And, ...); without these base classes every expression
        # that contains an operator is rejected.
        ast.boolop, ast.operator, ast.unaryop, ast.cmpop,
    )
    # ast.Num / ast.Str are intentionally not referenced: they are deprecated
    # aliases of ast.Constant and are absent from newer Python versions.
    try:
        tree = ast.parse(expr, mode='eval')
        for subnode in ast.walk(tree):
            if not isinstance(subnode, allowed_nodes):
                raise ValueError(f"Unsupported expression: {expr}")
            if isinstance(subnode, ast.Attribute):
                identifier = subnode.attr
            elif isinstance(subnode, ast.Name):
                identifier = subnode.id
            else:
                continue
            # Block dunder/private access such as ().__class__.__subclasses__.
            if identifier.startswith("_"):
                raise ValueError(f"Unsupported expression: {expr}")
        # Empty builtins: only the names in `variables` are resolvable.
        return eval(compile(tree, '<string>', mode='eval'), {"__builtins__": {}}, variables)
    except Exception as e:
        # Callers catch a single exception type; keep the cause attached.
        raise ValueError(f"Invalid expression: {e}") from e
|
302 |
+
|
303 |
class ClinicalKPI(BaseModel):
|
304 |
"""Define a clinical KPI."""
|
305 |
name: str
|
|
|
318 |
results = {}
|
319 |
for kpi_name, kpi in self.kpis.items():
|
320 |
try:
|
321 |
+
# Using safe_eval instead of eval for security
|
322 |
+
kpi_value = self.safe_eval(kpi.calculation, {"df": data})
|
323 |
status = self.evaluate_threshold(kpi_value, kpi.threshold)
|
324 |
results[kpi_name] = {
|
325 |
"value": kpi_value,
|
|
|
339 |
except TypeError:
|
340 |
return "Threshold Evaluation Not Applicable"
|
341 |
|
342 |
+
@staticmethod
|
343 |
+
def safe_eval(expr, variables):
    """
    Evaluate one Python expression after checking its AST against a whitelist.

    Args:
        expr: Source text of a single expression, e.g. ``"df['los'].mean()"``.
        variables: Mapping of the names the expression is allowed to reference.

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: If the text does not parse, contains a node type outside
            the whitelist, references an underscore-prefixed name or
            attribute, or fails while being evaluated.

    NOTE(review): this is a guard, not a sandbox. Attribute access and calls
    are permitted so that expressions can use the objects in ``variables``;
    any method those objects expose is therefore reachable. Expressions must
    still come from a trusted source.
    """
    allowed_nodes = (
        ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
        ast.Call, ast.keyword, ast.Name, ast.Load, ast.Constant,
        ast.List, ast.Tuple, ast.Dict,
        # Needed to index into and call methods on the supplied variables.
        ast.Attribute, ast.Subscript, ast.Slice,
        # ast.walk() also yields the operator nodes themselves (ast.Gt,
        # ast.Add, ast.And, ...); without these base classes every expression
        # that contains an operator is rejected.
        ast.boolop, ast.operator, ast.unaryop, ast.cmpop,
    )
    # ast.Num / ast.Str are intentionally not referenced: they are deprecated
    # aliases of ast.Constant and are absent from newer Python versions.
    try:
        tree = ast.parse(expr, mode='eval')
        for subnode in ast.walk(tree):
            if not isinstance(subnode, allowed_nodes):
                raise ValueError(f"Unsupported expression: {expr}")
            if isinstance(subnode, ast.Attribute):
                identifier = subnode.attr
            elif isinstance(subnode, ast.Name):
                identifier = subnode.id
            else:
                continue
            # Block dunder/private access such as ().__class__.__subclasses__.
            if identifier.startswith("_"):
                raise ValueError(f"Unsupported expression: {expr}")
        # Empty builtins: only the names in `variables` are resolvable.
        return eval(compile(tree, '<string>', mode='eval'), {"__builtins__": {}}, variables)
    except Exception as e:
        # Callers catch a single exception type; keep the cause attached.
        raise ValueError(f"Invalid expression: {e}") from e
|
361 |
+
|
362 |
class DiagnosisSupport(ABC):
|
363 |
"""Abstract class for implementing clinical diagnoses."""
|
364 |
@abstractmethod
|
|
|
448 |
pass
|
449 |
|
450 |
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
451 |
+
"""Enhanced Medical Knowledge Class using Groq API."""
|
452 |
def __init__(self):
    """
    Store the Groq/PubMed settings and load the spaCy English pipeline.

    The endpoint, key and contact e-mail are copied from the module-level
    GROQ_API_ENDPOINT, GROQ_API_KEY and PUB_EMAIL values, and the request
    headers are derived from the key.

    Raises:
        subprocess.CalledProcessError: If the fallback model download exits
            with a non-zero status.
        OSError: If the model still cannot be loaded after the download.
    """
    self.api_endpoint = GROQ_API_ENDPOINT
    self.api_key = GROQ_API_KEY
    self.pub_email = PUB_EMAIL

    self.headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json"
    }

    # Initialize spaCy model for entity recognition if needed
    try:
        self.nlp = spacy.load("en_core_web_sm")
    except OSError:
        # First load failed (presumably the model is not installed -- confirm
        # against spaCy docs); download it with the running interpreter.
        import subprocess
        import sys
        # check=True surfaces a failed download here instead of letting it
        # show up later as a second, less informative load failure.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )
        self.nlp = spacy.load("en_core_web_sm")
|
470 |
+
|
471 |
+
def search_medical_info(self, query: str, pub_email: str = "") -> str:
    """
    Uses the Groq API to fetch medical information based on the user's query.

    The lower-cased query is run through ``self.nlp``; when entities are
    found they are joined with spaces and sent instead of the full query.
    On an HTTP 200 reply, the "answer" field is combined with the result of
    ``self.fetch_pubmed_abstract`` for the same processed query.

    Args:
        query: Free-text question from the user.
        pub_email: Contact e-mail forwarded to the PubMed lookup. When empty,
            the instance's ``self.pub_email`` is used instead.

    Returns:
        A markdown-formatted answer, or a human-readable error string. Any
        exception is caught and reported in the returned text rather than
        raised.
    """
    try:
        # Preprocess the query if necessary (e.g., entity recognition)
        normalized_query = query.lower()
        doc = self.nlp(normalized_query)
        entities = [ent.text for ent in doc.ents]
        processed_query = " ".join(entities) if entities else normalized_query

        # Prepare the payload for the Groq API
        payload = {
            "query": processed_query,
            "context": "medical"  # Assuming the API can handle context specification
        }

        # Without a timeout a stalled connection would block this call
        # indefinitely; json= lets requests serialise the payload itself.
        response = requests.post(
            self.api_endpoint,
            headers=self.headers,
            json=payload,
            timeout=30,
        )

        if response.status_code != 200:
            return f"Error: Received status code {response.status_code} from Groq API."

        data = response.json()
        answer = data.get("answer", "I'm sorry, I couldn't find relevant information.")
        # Fall back to the address stored at construction time so the PubMed
        # lookup is not sent an empty e-mail when the caller omits it.
        contact_email = pub_email or self.pub_email
        pubmed_abstract = self.fetch_pubmed_abstract(processed_query, contact_email)
        return f"**Based on your query:** {answer}\n\n**PubMed Abstract:**\n\n{pubmed_abstract}"

    except Exception as e:
        # Best-effort boundary: the UI shows the message instead of crashing.
        return f"Medical Knowledge Search Failed: {str(e)}"
|
507 |
+
|
508 |
+
def fetch_pubmed_abstract(self, query: str, email: str) -> str:
|
509 |
+
"""
|
510 |
+
Searches PubMed for abstracts related to the query.
|
511 |
+
"""
|
512 |
try:
|
513 |
Entrez.email = email
|
514 |
handle = Entrez.esearch(db="pubmed", term=query, retmax=1, sort='relevance')
|
|
|
524 |
except Exception as e:
|
525 |
return f"Error searching PubMed: {e}"
|
526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
# ---------------------- Forecasting Engine ---------------------------
|
528 |
|
529 |
class ForecastingEngine(ABC):
|
|
|
749 |
if 'knowledge_base' not in st.session_state:
|
750 |
st.session_state.knowledge_base = SimpleMedicalKnowledge()
|
751 |
if 'pub_email' not in st.session_state:
|
752 |
+
st.session_state.pub_email = PUB_EMAIL # Load PUB_EMAIL from environment variables
|
753 |
|
754 |
def data_management_section():
|
755 |
"""Handles the data management section in the sidebar."""
|