"""Database access helpers: symptom/disease lookups, user history and accounts.

Importing this module configures root logging, loads ``.env.local`` and
requires the ``DB_CONNECT`` environment variable to be set.
"""

import logging
import os
from contextlib import contextmanager
from datetime import datetime
from typing import List, Optional, Tuple

import bcrypt
import psycopg2
from dotenv import load_dotenv

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load environment variables from .env.local
load_dotenv(dotenv_path=".env.local")

# Database connection string from environment variable
DB_CONNECT = os.getenv("DB_CONNECT")
if not DB_CONNECT:
    raise ValueError("DB_CONNECT environment variable not set. Please check .env.local.")


@contextmanager
def get_db_connection():
    """Yield a psycopg2 connection built from ``DB_CONNECT`` and close it on exit.

    Only a failure to *connect* is reported as "Database connection failed".
    Errors raised by the caller's ``with`` body (for example
    ``psycopg2.IntegrityError``) propagate unchanged so callers can handle
    them specifically.

    Raises:
        Exception: if ``psycopg2.connect`` fails; chained to the psycopg2 error.
    """
    try:
        conn = psycopg2.connect(DB_CONNECT)
    except psycopg2.Error as e:
        logging.error("Database connection failed: %s", e)
        raise Exception(f"Database connection failed: {e}") from e

    # The yield is deliberately outside the try/except above: wrapping it
    # would relabel every query error from the with-body as a connection
    # failure and hide its exception type from the caller.
    try:
        yield conn
    finally:
        conn.close()


def get_symptoms() -> List[str]:
    """Fetch all symptom names from ``mb_symptoms``, ordered by name.

    Returns:
        The symptom names; an empty list (with a logged warning) if the
        table has no rows.

    Raises:
        Exception: on any failure, chained to the underlying error.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT symptom_name FROM mb_symptoms ORDER BY symptom_name")
                symptoms = [row[0] for row in cur.fetchall()]
        if not symptoms:
            logging.warning("No symptoms found in mb_symptoms table.")
        else:
            logging.info("Loaded %d symptoms from mb_symptoms.", len(symptoms))
        return symptoms
    except Exception as e:
        logging.error("Error fetching symptoms: %s", e)
        raise Exception(f"Error fetching symptoms: {e}") from e


def get_disease_info(symptoms: List[str]) -> List[Tuple[str, str, List[str]]]:
    """Get possible diseases and precautions based on symptoms.

    Diseases are ranked by how many of the given symptoms map to them and
    at most five are returned.

    Args:
        symptoms: Symptom names to match against ``mb_symptoms.symptom_name``.

    Returns:
        Rows of ``(disease_name, description, precautions)`` where
        ``precautions`` holds the four precaution columns with NULLs
        replaced by empty strings. Empty if ``symptoms`` is empty or none
        of them exist in ``mb_symptoms``.

    Raises:
        Exception: on any failure, chained to the underlying error.
    """
    # An empty tuple would render as "IN ()", which is invalid SQL.
    if not symptoms:
        logging.warning("No symptoms supplied; nothing to query.")
        return []

    try:
        logging.info("Querying diseases for symptoms: %s", symptoms)
        # psycopg2 adapts a Python tuple to a parenthesised SQL value list.
        symptom_names = tuple(symptoms)
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Check if symptoms exist in mb_symptoms
                cur.execute(
                    "SELECT symptom_name FROM mb_symptoms WHERE symptom_name IN %s",
                    (symptom_names,),
                )
                found_symptoms = {row[0] for row in cur.fetchall()}
                if not found_symptoms:
                    logging.warning("No matching symptoms found in mb_symptoms for: %s", symptoms)
                    return []

                # Compare as sets so duplicate inputs do not trigger a
                # misleading warning with an empty "missing" set.
                missing = set(symptoms) - found_symptoms
                if missing:
                    logging.warning("Some symptoms not found in mb_symptoms: %s", missing)

                # Check mappings in mb_disease_symptoms
                query = """
                    SELECT d.disease_name, d.description,
                           ARRAY[COALESCE(d.precaution_1, ''),
                                 COALESCE(d.precaution_2, ''),
                                 COALESCE(d.precaution_3, ''),
                                 COALESCE(d.precaution_4, '')] as precautions
                    FROM mb_diseases d
                    JOIN mb_disease_symptoms ds ON d.disease_id = ds.disease_id
                    JOIN mb_symptoms s ON ds.symptom_id = s.symptom_id
                    WHERE s.symptom_name IN %s
                    GROUP BY d.disease_id, d.disease_name, d.description,
                             d.precaution_1, d.precaution_2, d.precaution_3, d.precaution_4
                    ORDER BY COUNT(*) DESC
                    LIMIT 5
                """
                cur.execute(query, (symptom_names,))
                results = cur.fetchall()

        if not results:
            logging.warning(
                "No diseases found for symptoms: %s. Check mb_disease_symptoms mappings.",
                symptoms,
            )
        else:
            logging.info("Found %d diseases for symptoms: %s", len(results), symptoms)
        return results
    except Exception as e:
        logging.error("Error fetching disease info: %s", e)
        raise Exception(f"Error fetching disease info: {e}") from e


def save_user_history(user_id: str, symptoms: str, predicted_diseases: str) -> None:
    """Insert one row into ``mb_history`` and commit it.

    Args:
        user_id: Value stored in the ``user_id`` column.
        symptoms: Value stored in the ``symptoms`` column.
        predicted_diseases: Value stored in the ``predicted_diseases`` column.

    Raises:
        Exception: on any failure, chained to the underlying error.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO mb_history (user_id, symptoms, predicted_diseases) VALUES (%s, %s, %s)",
                    (user_id, symptoms, predicted_diseases),
                )
            conn.commit()
    except Exception as e:
        logging.error("Error saving user history: %s", e)
        raise Exception(f"Error saving user history: {e}") from e


def get_user_history(user_id: str) -> List[Tuple[int, str, str, datetime]]:
    """Retrieve a user's ``mb_history`` rows, newest first.

    Args:
        user_id: Value matched against the ``user_id`` column.

    Returns:
        Rows of ``(history_id, symptoms, predicted_diseases, query_timestamp)``
        ordered by ``query_timestamp`` descending.

    Raises:
        Exception: on any failure, chained to the underlying error.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT history_id, symptoms, predicted_diseases, query_timestamp "
                    "FROM mb_history WHERE user_id = %s ORDER BY query_timestamp DESC",
                    (user_id,),
                )
                return cur.fetchall()
    except Exception as e:
        logging.error("Error fetching user history: %s", e)
        raise Exception(f"Error fetching user history: {e}") from e


def register_user(user_id: str, full_name: str, dob: str, email: Optional[str], password: str) -> bool:
    """Register a new user, storing a bcrypt hash of the password.

    Args:
        user_id: Value stored in ``mb_users.user_id``.
        full_name: Value stored in ``full_name``.
        dob: Value stored in ``date_of_birth``.
        email: Value stored in ``email``; an empty string is stored as NULL.
        password: Plain-text password; only its bcrypt hash is stored.

    Returns:
        True once the row is committed; False if the insert violates an
        integrity constraint (logged as the user ID already existing).

    Raises:
        Exception: on any other failure, chained to the underlying error.
    """
    try:
        hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO mb_users (user_id, full_name, date_of_birth, email, password) "
                    "VALUES (%s, %s, %s, %s, %s)",
                    (user_id, full_name, dob, email or None, hashed_password),
                )
            conn.commit()
        return True
    except psycopg2.IntegrityError:
        # The uncommitted insert is discarded when the connection closes.
        logging.error("User ID %s already exists.", user_id)
        return False
    except Exception as e:
        logging.error("Error registering user: %s", e)
        raise Exception(f"Error registering user: {e}") from e


def user_exists(user_id: str) -> bool:
    """Check whether a row with this ``user_id`` exists in ``mb_users``.

    Raises:
        Exception: on any failure, chained to the underlying error.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT 1 FROM mb_users WHERE user_id = %s", (user_id,))
                return cur.fetchone() is not None
    except Exception as e:
        logging.error("Error checking user existence: %s", e)
        raise Exception(f"Error checking user existence: {e}") from e


def check_user_credentials(user_id: str, password: str) -> bool:
    """Verify a password against the bcrypt hash stored for ``user_id``.

    Returns:
        True if the user exists and the password matches the stored hash;
        False if the user does not exist or the password does not match.

    Raises:
        Exception: on any failure, chained to the underlying error.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT password FROM mb_users WHERE user_id = %s", (user_id,))
                result = cur.fetchone()
        if result is None:
            return False
        stored_password = result[0]
        return bcrypt.checkpw(password.encode('utf-8'), stored_password.encode('utf-8'))
    except Exception as e:
        logging.error("Error checking user credentials: %s", e)
        raise Exception(f"Error checking user credentials: {e}") from e