medibot / database.py
abidkh's picture
Final app version before submission.
73d90ba
from dotenv import load_dotenv
import os
import psycopg2
from contextlib import contextmanager
from typing import List, Tuple, Optional
import bcrypt
import logging
from datetime import datetime # <--- ADD THIS IMPORT
# Root logging for the whole app: INFO and above, each line timestamped.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Pull settings (including DB_CONNECT) from the local env file.
# NOTE(review): the relative path resolves against the current working
# directory, not this module's directory.
load_dotenv(dotenv_path=".env.local")

# Connection string handed to psycopg2.connect(); importing this module
# fails fast when it is missing or empty.
DB_CONNECT = os.getenv("DB_CONNECT")
if not DB_CONNECT:
    raise ValueError("DB_CONNECT environment variable not set. Please check .env.local.")
@contextmanager
def get_db_connection():
    """Yield a psycopg2 connection opened from ``DB_CONNECT`` and always close it.

    Only a failure to *open* the connection is logged and re-raised as a
    generic ``Exception("Database connection failed: ...")``. Errors raised by
    the caller inside the ``with`` body, such as ``psycopg2.IntegrityError``
    from an INSERT, propagate with their original type. Callers can therefore
    catch specific psycopg2 exceptions.

    The connection is closed without an implicit commit, so callers must call
    ``conn.commit()`` themselves to persist writes.

    Raises:
        Exception: if ``psycopg2.connect`` raises ``psycopg2.Error``. The
            original error is chained as ``__cause__``.
    """
    try:
        conn = psycopg2.connect(DB_CONNECT)
    except psycopg2.Error as e:
        logging.error(f"Database connection failed: {str(e)}")
        raise Exception(f"Database connection failed: {str(e)}") from e
    # The yield is deliberately outside the try/except above. @contextmanager
    # re-raises body exceptions at the yield, and wrapping them would mislabel
    # query errors as connection failures and hide their type (e.g. the
    # IntegrityError that register_user checks for).
    try:
        yield conn
    finally:
        # close() without commit() discards any uncommitted transaction.
        conn.close()
def get_symptoms() -> List[str]:
    """Return every ``symptom_name`` in ``mb_symptoms``, ordered by name.

    Logs a warning if the table is empty. Otherwise logs how many names were
    loaded.

    Raises:
        Exception: ``"Error fetching symptoms: ..."`` wrapping any failure.
    """
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute("SELECT symptom_name FROM mb_symptoms ORDER BY symptom_name")
            # Each row is a 1-tuple; unpack it directly.
            names = [name for (name,) in cur.fetchall()]
            if names:
                logging.info(f"Loaded {len(names)} symptoms from mb_symptoms.")
            else:
                logging.warning("No symptoms found in mb_symptoms table.")
            return names
    except Exception as err:
        logging.error(f"Error fetching symptoms: {str(err)}")
        raise Exception(f"Error fetching symptoms: {str(err)}")
def get_disease_info(symptoms: List[str]) -> List[Tuple[str, str, List[str]]]:
    """Return up to five candidate diseases for the given symptom names.

    Diseases are ranked by how many of the supplied symptoms map to them
    through ``mb_disease_symptoms``, most matches first. Ties are broken by
    disease name so the top five are deterministic.

    Args:
        symptoms: Symptom names as stored in ``mb_symptoms.symptom_name``.

    Returns:
        A list of ``(disease_name, description, precautions)`` rows. The
        ``precautions`` value is an array of ``precaution_1``..``precaution_4``
        with NULLs replaced by ``''``. The list is empty if ``symptoms`` is
        empty, if none of the symptoms exist, or if no disease is mapped to
        them.

    Raises:
        Exception: ``"Error fetching disease info: ..."`` wrapping any
            database failure. The original error is chained as ``__cause__``.
    """
    if not symptoms:
        # psycopg2 adapts an empty tuple to ``IN ()``, which PostgreSQL rejects
        # as a syntax error, so answer directly instead of querying.
        logging.warning("No symptoms provided; skipping disease lookup.")
        return []
    try:
        logging.info(f"Querying diseases for symptoms: {symptoms}")
        # psycopg2 expands a Python tuple into a parenthesised list for IN.
        symptom_tuple = tuple(symptoms)
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Check which requested symptoms exist in mb_symptoms.
                cur.execute(
                    "SELECT symptom_name FROM mb_symptoms WHERE symptom_name IN %s",
                    (symptom_tuple,),
                )
                found_symptoms = [row[0] for row in cur.fetchall()]
                if not found_symptoms:
                    logging.warning(f"No matching symptoms found in mb_symptoms for: {symptoms}")
                    return []
                # Compare as sets so duplicate inputs don't trigger a spurious warning.
                missing = set(symptoms) - set(found_symptoms)
                if missing:
                    logging.warning(f"Some symptoms not found in mb_symptoms: {missing}")
                # Rank diseases by the number of matched symptom mappings.
                query = """
                    SELECT d.disease_name, d.description,
                           ARRAY[COALESCE(d.precaution_1, ''), COALESCE(d.precaution_2, ''),
                                 COALESCE(d.precaution_3, ''), COALESCE(d.precaution_4, '')] as precautions
                    FROM mb_diseases d
                    JOIN mb_disease_symptoms ds ON d.disease_id = ds.disease_id
                    JOIN mb_symptoms s ON ds.symptom_id = s.symptom_id
                    WHERE s.symptom_name IN %s
                    GROUP BY d.disease_id, d.disease_name, d.description,
                             d.precaution_1, d.precaution_2, d.precaution_3, d.precaution_4
                    ORDER BY COUNT(*) DESC, d.disease_name
                    LIMIT 5
                """
                cur.execute(query, (symptom_tuple,))
                results = cur.fetchall()
        if not results:
            logging.warning(f"No diseases found for symptoms: {symptoms}. Check mb_disease_symptoms mappings.")
        else:
            logging.info(f"Found {len(results)} diseases for symptoms: {symptoms}")
        return results
    except Exception as e:
        logging.error(f"Error fetching disease info: {str(e)}")
        raise Exception(f"Error fetching disease info: {str(e)}") from e
def save_user_history(user_id: str, symptoms: str, predicted_diseases: str) -> None:
    """Insert one symptom-check record into ``mb_history`` and commit it.

    Args:
        user_id: Owner of the record.
        symptoms: Symptoms text as supplied by the caller.
        predicted_diseases: Prediction text as supplied by the caller.

    Raises:
        Exception: ``"Error saving user history: ..."`` wrapping any failure.
    """
    insert_sql = "INSERT INTO mb_history (user_id, symptoms, predicted_diseases) VALUES (%s, %s, %s)"
    values = (user_id, symptoms, predicted_diseases)
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(insert_sql, values)
            # The connection helper never commits, so persist explicitly.
            conn.commit()
    except Exception as err:
        logging.error(f"Error saving user history: {str(err)}")
        raise Exception(f"Error saving user history: {str(err)}")
def get_user_history(user_id: str) -> List[Tuple[int, str, str, datetime]]:
    """Return all ``mb_history`` rows for ``user_id``, newest first.

    Each row is ``(history_id, symptoms, predicted_diseases, query_timestamp)``,
    ordered by ``query_timestamp`` descending.

    Raises:
        Exception: ``"Error fetching user history: ..."`` wrapping any failure.
    """
    select_sql = (
        "SELECT history_id, symptoms, predicted_diseases, query_timestamp "
        "FROM mb_history WHERE user_id = %s ORDER BY query_timestamp DESC"
    )
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(select_sql, (user_id,))
            rows = cur.fetchall()
            return rows
    except Exception as err:
        logging.error(f"Error fetching user history: {str(err)}")
        raise Exception(f"Error fetching user history: {str(err)}")
def register_user(user_id: str, full_name: str, dob: str, email: Optional[str], password: str) -> bool:
    """Create an ``mb_users`` row whose password is stored as a bcrypt hash.

    A falsy ``email`` (``None`` or ``""``) is stored as NULL.

    Returns:
        True after the INSERT is committed. False if ``psycopg2.IntegrityError``
        reaches this function; that case is logged as a duplicate user ID even
        though any integrity violation would land there.
        NOTE(review): that branch only fires if ``get_db_connection`` lets
        psycopg2 errors raised inside its ``with`` body propagate unchanged.

    Raises:
        Exception: ``"Error registering user: ..."`` wrapping any other failure.
    """
    try:
        salt = bcrypt.gensalt()
        password_hash = bcrypt.hashpw(password.encode('utf-8'), salt).decode('utf-8')
        new_row = (user_id, full_name, dob, email if email else None, password_hash)
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(
                "INSERT INTO mb_users (user_id, full_name, date_of_birth, email, password) VALUES (%s, %s, %s, %s, %s)",
                new_row,
            )
            conn.commit()
        return True
    except psycopg2.IntegrityError:
        logging.error(f"User ID {user_id} already exists.")
        return False
    except Exception as err:
        logging.error(f"Error registering user: {str(err)}")
        raise Exception(f"Error registering user: {str(err)}")
def user_exists(user_id: str) -> bool:
    """Return True if ``mb_users`` already contains a row for ``user_id``.

    Raises:
        Exception: ``"Error checking user existence: ..."`` wrapping any failure.
    """
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute("SELECT 1 FROM mb_users WHERE user_id = %s", (user_id,))
            first_match = cur.fetchone()
            return first_match is not None
    except Exception as err:
        logging.error(f"Error checking user existence: {str(err)}")
        raise Exception(f"Error checking user existence: {str(err)}")
def check_user_credentials(user_id: str, password: str) -> bool:
    """Return True if ``password`` matches the bcrypt hash stored for ``user_id``.

    Returns False when no ``mb_users`` row exists for ``user_id``.

    Raises:
        Exception: ``"Error checking user credentials: ..."`` wrapping any
            failure, including errors from ``bcrypt.checkpw``.
    """
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute("SELECT password FROM mb_users WHERE user_id = %s", (user_id,))
            row = cur.fetchone()
            if row is None:
                # Unknown user: report a plain mismatch.
                return False
            candidate = password.encode('utf-8')
            stored_hash = row[0].encode('utf-8')
            return bcrypt.checkpw(candidate, stored_hash)
    except Exception as err:
        logging.error(f"Error checking user credentials: {str(err)}")
        raise Exception(f"Error checking user credentials: {str(err)}")