mgbam committed on
Commit
55ef016
·
verified ·
1 Parent(s): 211e3a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -19
app.py CHANGED
@@ -55,7 +55,7 @@ if not OPENAI_API_KEY:
55
 
56
  # Instantiate the OpenAI client
57
  try:
58
- client = OpenAI(api_key=OPENAI_API_KEY) # Instantiating the client right here
59
  except Exception as e:
60
  st.error(f"Failed to initialize OpenAI client: {e}")
61
  logger.error(f"Failed to initialize OpenAI client: {e}")
@@ -239,35 +239,48 @@ class HypothesisTester(DataAnalyzer):
239
  return "No significant evidence against H0"
240
 
241
  from sklearn.impute import SimpleImputer
 
242
 
243
  class LogisticRegressionTrainer(DataAnalyzer):
244
- """Logistic Regression Model Trainer with Missing Value Handling."""
245
  def invoke(self, data: pd.DataFrame, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
246
  try:
247
- X = data[columns]
248
- y = data[target_col]
 
 
249
 
250
- # Check for missing values in X
 
 
 
251
  if X.isnull().values.any():
252
  logger.info("Missing values detected in feature variables. Applying imputation.")
253
- imputer = SimpleImputer(strategy='mean') # You can choose 'median', 'most_frequent', etc.
254
  X_imputed = imputer.fit_transform(X)
255
  X = pd.DataFrame(X_imputed, columns=columns)
256
  logger.info("Imputation completed for feature variables.")
257
  else:
258
  logger.info("No missing values detected in feature variables.")
259
 
260
- # Check for missing values in y
261
  if y.isnull().values.any():
262
- logger.info("Missing values detected in target variable. Applying imputation.")
263
- # For classification, it's common to impute with the mode
264
- y_imputer = SimpleImputer(strategy='most_frequent')
265
- y_imputed = y_imputer.fit_transform(y.values.reshape(-1, 1))
266
- y = pd.Series(y_imputer.ravel())
267
- logger.info("Imputation completed for target variable.")
268
  else:
269
  logger.info("No missing values detected in target variable.")
270
 
 
 
 
 
 
 
 
271
  # Split the data
272
  X_train, X_test, y_train, y_test = train_test_split(
273
  X, y, test_size=0.2, random_state=42
@@ -275,7 +288,7 @@ class LogisticRegressionTrainer(DataAnalyzer):
275
  logger.info("Data split into training and testing sets.")
276
 
277
  # Initialize and train the model
278
- model = LogisticRegression(max_iter=1000)
279
  model.fit(X_train, y_train)
280
  logger.info("Logistic Regression model training completed.")
281
 
@@ -293,7 +306,6 @@ class LogisticRegressionTrainer(DataAnalyzer):
293
  logger.error(f"Logistic Regression Model Error: {str(e)}")
294
  return {"error": f"Logistic Regression Model Error: {str(e)}"}
295
 
296
-
297
  # ---------------------- Business Logic Layer ---------------------------
298
 
299
  class ClinicalRule(BaseModel):
@@ -544,7 +556,7 @@ class SimpleMedicalKnowledge(MedicalKnowledgeBase):
544
  )
545
 
546
  # Extract the answer from the response
547
- answer = response.choices[0].message.content.strip() # Corrected access
548
 
549
  logger.info("Successfully retrieved data from OpenAI GPT-4.")
550
 
@@ -800,7 +812,7 @@ def initialize_session_state():
800
 
801
  if 'openai_client' not in st.session_state:
802
  # Instantiate the OpenAI client only if it doesn't exist in session state
803
- st.session_state.openai_client = client # The one created earlier
804
 
805
  if 'data' not in st.session_state:
806
  st.session_state.data = {} # Store pd.DataFrame under a name
@@ -826,7 +838,7 @@ def initialize_session_state():
826
  if 'knowledge_base' not in st.session_state:
827
  st.session_state.knowledge_base = SimpleMedicalKnowledge(nlp_model=nlp, client=st.session_state.openai_client)
828
  if 'pub_email' not in st.session_state:
829
- st.session_state.pub_email = PUB_EMAIL # Load PUB_EMAIL from environment variables
830
  if 'treatment_recommendation' not in st.session_state:
831
  st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
832
 
@@ -1209,4 +1221,4 @@ def medical_knowledge_section():
1209
  st.error("Please enter a medical question to search.")
1210
 
1211
  if __name__ == "__main__":
1212
- main()
 
55
 
56
  # Instantiate the OpenAI client
57
  try:
58
+ client = OpenAI(api_key=OPENAI_API_KEY) # Instantiating the client right here
59
  except Exception as e:
60
  st.error(f"Failed to initialize OpenAI client: {e}")
61
  logger.error(f"Failed to initialize OpenAI client: {e}")
 
239
  return "No significant evidence against H0"
240
 
241
  from sklearn.impute import SimpleImputer
242
+ from sklearn.preprocessing import LabelEncoder
243
 
244
  class LogisticRegressionTrainer(DataAnalyzer):
245
+ """Logistic Regression Model Trainer with Missing Value Handling and Target Encoding."""
246
  def invoke(self, data: pd.DataFrame, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
247
  try:
248
+ # Prevent data leakage by removing target_col from features if present
249
+ if target_col in columns:
250
+ columns.remove(target_col)
251
+ logger.warning(f"Removed target column '{target_col}' from feature list to prevent data leakage.")
252
 
253
+ X = data[columns].copy()
254
+ y = data[target_col].copy()
255
+
256
+ # Handle missing values in X
257
  if X.isnull().values.any():
258
  logger.info("Missing values detected in feature variables. Applying imputation.")
259
+ imputer = SimpleImputer(strategy='mean') # Choose strategy as needed
260
  X_imputed = imputer.fit_transform(X)
261
  X = pd.DataFrame(X_imputed, columns=columns)
262
  logger.info("Imputation completed for feature variables.")
263
  else:
264
  logger.info("No missing values detected in feature variables.")
265
 
266
+ # Handle missing values in y
267
  if y.isnull().values.any():
268
+ logger.info("Missing values detected in target variable. Dropping missing targets.")
269
+ # For classification, it's common to impute with the mode or drop missing targets
270
+ data = data.dropna(subset=[target_col])
271
+ y = data[target_col]
272
+ X = data[columns]
273
+ logger.info("Dropped rows with missing target values.")
274
  else:
275
  logger.info("No missing values detected in target variable.")
276
 
277
+ # Encode target if it's categorical and not numeric
278
+ if y.dtype == 'object' or y.dtype.name == 'category':
279
+ logger.info("Encoding categorical target variable.")
280
+ label_encoder = LabelEncoder()
281
+ y = label_encoder.fit_transform(y)
282
+ logger.info("Encoding completed.")
283
+
284
  # Split the data
285
  X_train, X_test, y_train, y_test = train_test_split(
286
  X, y, test_size=0.2, random_state=42
 
288
  logger.info("Data split into training and testing sets.")
289
 
290
  # Initialize and train the model
291
+ model = LogisticRegression(max_iter=1000, multi_class='auto', solver='lbfgs')
292
  model.fit(X_train, y_train)
293
  logger.info("Logistic Regression model training completed.")
294
 
 
306
  logger.error(f"Logistic Regression Model Error: {str(e)}")
307
  return {"error": f"Logistic Regression Model Error: {str(e)}"}
308
 
 
309
  # ---------------------- Business Logic Layer ---------------------------
310
 
311
  class ClinicalRule(BaseModel):
 
556
  )
557
 
558
  # Extract the answer from the response
559
+ answer = response.choices[0].message.content.strip() # Corrected access
560
 
561
  logger.info("Successfully retrieved data from OpenAI GPT-4.")
562
 
 
812
 
813
  if 'openai_client' not in st.session_state:
814
  # Instantiate the OpenAI client only if it doesn't exist in session state
815
+ st.session_state.openai_client = client # The one created earlier
816
 
817
  if 'data' not in st.session_state:
818
  st.session_state.data = {} # Store pd.DataFrame under a name
 
838
  if 'knowledge_base' not in st.session_state:
839
  st.session_state.knowledge_base = SimpleMedicalKnowledge(nlp_model=nlp, client=st.session_state.openai_client)
840
  if 'pub_email' not in st.session_state:
841
+ st.session_state.pub_email = PUB_EMAIL # Load PUB_EMAIL from environment variables
842
  if 'treatment_recommendation' not in st.session_state:
843
  st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
844
 
 
1221
  st.error("Please enter a medical question to search.")
1222
 
1223
  if __name__ == "__main__":
1224
+ main()