mgbam committed on
Commit
330fc43
·
verified ·
1 Parent(s): 7ab03d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -57
app.py CHANGED
@@ -18,10 +18,14 @@ import os
18
  import numpy as np
19
  from scipy.stats import ttest_ind, f_oneway
20
  import json
 
 
 
21
 
22
  # Initialize Groq Client
23
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
24
 
 
25
  # ---------------------- Base Classes and Schemas ---------------------------
26
  class ResearchInput(BaseModel):
27
  """Base schema for research tool inputs"""
@@ -287,17 +291,50 @@ class MedicalKnowledgeBase():
287
  pass
288
 
289
  class SimpleMedicalKnowledge(MedicalKnowledgeBase):
290
- """Simple Medical Knowledge Class"""
291
- def search_medical_info(self, query: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  try:
293
- if "diabetes treatment" in query.lower():
294
- return "The recommended treatment for diabetes includes lifestyle changes, medication, and monitoring"
295
- elif "heart disease risk factors" in query.lower():
296
- return "Risk factors for heart disease include high blood pressure, high cholesterol, and smoking"
 
 
 
 
 
297
  else:
298
- return "No specific information is available"
299
  except Exception as e:
300
- return f"Medical Knowledge Search Failed {e}"
301
 
302
 
303
  class ForecastingEngine(ABC):
@@ -494,8 +531,9 @@ def main():
494
  st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
495
  if 'knowledge_base' not in st.session_state:
496
  st.session_state.knowledge_base = SimpleMedicalKnowledge()
 
 
497
 
498
-
499
  # Sidebar for Data Management
500
  with st.sidebar:
501
  st.header("⚙️ Data Management")
@@ -668,51 +706,4 @@ def main():
668
  st.json(result)
669
  with insights_tab:
670
  if selected_data_key:
671
- data = st.session_state.data[selected_data_key]
672
- available_analysis = ["EDA", "temporal", "distribution", "hypothesis", "model"]
673
- selected_analysis = st.multiselect("Select Analysis", available_analysis)
674
- if st.button("Generate Automated Insights"):
675
- with st.spinner("Generating Insights"):
676
- results = st.session_state.automated_insights.generate_insights(data, analysis_names=selected_analysis)
677
- st.json(results)
678
- st.subheader("Diagnosis Support")
679
- target_col = st.selectbox("Select Target Variable for Diagnosis", data.columns.tolist())
680
- num_cols = data.select_dtypes(include=np.number).columns.tolist()
681
- selected_cols_diagnosis = st.multiselect("Select Feature Variables for Diagnosis", num_cols)
682
- if st.button("Generate Diagnosis"):
683
- if target_col and selected_cols_diagnosis:
684
- with st.spinner("Generating Diagnosis"):
685
- result = st.session_state.diagnosis_support.diagnose(data, target_col=target_col, columns=selected_cols_diagnosis, diagnosis_key="diagnosis_result")
686
- st.json(result)
687
-
688
- st.subheader("Treatment Recommendation")
689
- condition_col = st.selectbox("Select Condition Column for Treatment Recommendation", data.columns.tolist())
690
- treatment_col = st.selectbox("Select Treatment Column for Treatment Recommendation", data.columns.tolist())
691
- if st.button("Generate Treatment Recommendation"):
692
- if condition_col and treatment_col:
693
- with st.spinner("Generating Treatment Recommendation"):
694
- result = st.session_state.treatment_recommendation.recommend(data, condition_col = condition_col, treatment_col = treatment_col, recommendation_key="treatment_recommendation")
695
- st.json(result)
696
-
697
- with reports_tab:
698
- st.header("Reports")
699
- report_name = st.text_input("Report Name")
700
- report_def = st.text_area("Report definition")
701
- if st.button("Create Report Definition"):
702
- st.session_state.automated_reports.create_report_definition(report_name, report_def)
703
- st.success("Report definition created")
704
- if selected_data_key:
705
- data = st.session_state.data
706
- if st.button("Generate Report"):
707
- with st.spinner("Generating Report..."):
708
- report = st.session_state.automated_reports.generate_report(report_name, data)
709
- with knowledge_tab:
710
- st.header("Medical Knowledge")
711
- query = st.text_input("Enter your medical question here:")
712
- if st.button("Search"):
713
- with st.spinner("Searching..."):
714
- result = st.session_state.knowledge_base.search_medical_info(query)
715
- st.write(result)
716
-
717
- if __name__ == "__main__":
718
- main()
 
18
  import numpy as np
19
  from scipy.stats import ttest_ind, f_oneway
20
  import json
21
+ from Bio import Entrez
22
+ from sklearn.feature_extraction.text import TfidfVectorizer
23
+ from sklearn.metrics.pairwise import cosine_similarity
24
 
25
  # Initialize Groq Client
26
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
27
 
28
+
29
  # ---------------------- Base Classes and Schemas ---------------------------
30
  class ResearchInput(BaseModel):
31
  """Base schema for research tool inputs"""
 
291
  pass
292
 
293
  class SimpleMedicalKnowledge(MedicalKnowledgeBase):
294
+ """Simple Medical Knowledge Class with TF-IDF and PubMed"""
295
+ def __init__(self):
296
+ self.knowledge_base = {
297
+ "diabetes": "The recommended treatment for diabetes includes lifestyle changes, medication, and monitoring.",
298
+ "heart disease": "Risk factors for heart disease include high blood pressure, high cholesterol, and smoking.",
299
+ "fever": "For a fever, you can consider over-the-counter medications like acetaminophen or ibuprofen. Rest and hydration are also important.",
300
+ "headache": "For a headache, try rest, hydration, and over-the-counter pain relievers. Consult a doctor if it is severe or persistent.",
301
+ "cold": "For a cold, get rest, drink plenty of fluids, and use over-the-counter remedies like decongestants."
302
+ }
303
+ self.vectorizer = TfidfVectorizer()
304
+ self.tfidf_matrix = self.vectorizer.fit_transform(self.knowledge_base.values())
305
+
306
+ def search_pubmed(self, query: str, email: str) -> str:
307
+ try:
308
+ Entrez.email = email
309
+ handle = Entrez.esearch(db="pubmed", term=query, retmax=1)
310
+ record = Entrez.read(handle)
311
+ handle.close()
312
+ if record["IdList"]:
313
+ handle = Entrez.efetch(db="pubmed", id=record["IdList"][0], rettype="abstract", retmode="text")
314
+ abstract = handle.read()
315
+ handle.close()
316
+ return abstract
317
+ else:
318
+ return "No abstracts found for this query on PubMed"
319
+ except Exception as e:
320
+ return f"Error searching pubmed {e}"
321
+
322
+
323
+ def search_medical_info(self, query: str, pub_email: str = "") -> str:
324
  try:
325
+ query_vector = self.vectorizer.transform([query])
326
+ similarities = cosine_similarity(query_vector, self.tfidf_matrix)
327
+ best_match_index = np.argmax(similarities)
328
+ best_match_keyword = list(self.knowledge_base.keys())[best_match_index]
329
+ best_match_info = list(self.knowledge_base.values())[best_match_index]
330
+
331
+ pubmed_result = self.search_pubmed(query, pub_email)
332
+ if "No abstracts found for this query on PubMed" not in pubmed_result:
333
+ return f"Based on the query provided, I found this: {best_match_info} \n\nFrom Pubmed I also found the following abstract: \n {pubmed_result}"
334
  else:
335
+ return f"Based on the query provided, I found this: {best_match_info} \n\n{pubmed_result}"
336
  except Exception as e:
337
+ return f"Medical Knowledge Search Failed {e}"
338
 
339
 
340
  class ForecastingEngine(ABC):
 
531
  st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
532
  if 'knowledge_base' not in st.session_state:
533
  st.session_state.knowledge_base = SimpleMedicalKnowledge()
534
+ if 'pub_email' not in st.session_state:
535
+ st.session_state.pub_email = st.secrets.get("PUB_EMAIL", "") # Load PUB_EMAIL from secrets
536
 
 
537
  # Sidebar for Data Management
538
  with st.sidebar:
539
  st.header("⚙️ Data Management")
 
706
  st.json(result)
707
  with insights_tab:
708
  if selected_data_key:
709
+