|
from dotenv import load_dotenv |
|
import os |
|
import psycopg2 |
|
from contextlib import contextmanager |
|
from typing import List, Tuple, Optional |
|
import bcrypt |
|
import logging |
|
from datetime import datetime |
|
|
|
|
|
# Configure the root logger once at import: timestamp, level, then message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Load variables from the local dotenv file before the environment is read.
load_dotenv(dotenv_path=".env.local")

# Connection string later passed to psycopg2.connect() by get_db_connection().
DB_CONNECT = os.environ.get("DB_CONNECT")

# Fail fast at import time when the connection string is missing or empty.
if not DB_CONNECT:
    raise ValueError("DB_CONNECT environment variable not set. Please check .env.local.")
|
|
|
@contextmanager
def get_db_connection():
    """Yield an open psycopg2 connection and always close it afterwards.

    The connection is opened from the module-level ``DB_CONNECT`` string.
    Committing is left to the caller; the connection is closed on exit
    whether or not the ``with`` body raised.

    Yields:
        The connection returned by ``psycopg2.connect(DB_CONNECT)``.

    Raises:
        Exception: If the connection cannot be established. The underlying
            ``psycopg2.Error`` is attached as ``__cause__``.
    """
    # Only the connect call is guarded. Errors raised inside the caller's
    # ``with`` body (for example psycopg2.IntegrityError from an INSERT) must
    # reach the caller with their original type instead of being logged and
    # re-raised as connection failures.
    try:
        conn = psycopg2.connect(DB_CONNECT)
    except psycopg2.Error as e:
        logging.error(f"Database connection failed: {str(e)}")
        raise Exception(f"Database connection failed: {str(e)}") from e

    try:
        yield conn
    finally:
        conn.close()
|
|
|
def get_symptoms() -> List[str]:
    """Return every symptom name stored in ``mb_symptoms``, ordered by name.

    Returns:
        The first column of each row returned by the query. An empty list
        is returned (and a warning logged) when the table has no rows.

    Raises:
        Exception: On any failure; the message embeds the original error text.
    """
    sql = "SELECT symptom_name FROM mb_symptoms ORDER BY symptom_name"
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(sql)
            names = []
            for record in cur.fetchall():
                names.append(record[0])

        if names:
            logging.info(f"Loaded {len(names)} symptoms from mb_symptoms.")
        else:
            logging.warning("No symptoms found in mb_symptoms table.")
        return names
    except Exception as e:
        logging.error(f"Error fetching symptoms: {str(e)}")
        raise Exception(f"Error fetching symptoms: {str(e)}")
|
|
|
def get_disease_info(symptoms: List[str]) -> List[Tuple[str, str, List[str]]]:
    """Return up to five diseases linked to the given symptoms.

    Diseases are joined to symptoms through ``mb_disease_symptoms`` and
    ordered by the number of matching join rows, highest first.

    Args:
        symptoms: Symptom names to look up. Duplicate names are ignored.

    Returns:
        Rows of ``(disease_name, description, precautions)``, where
        ``precautions`` holds the four precaution columns with NULLs replaced
        by empty strings. The list is empty when ``symptoms`` is empty, when
        none of the names exist in ``mb_symptoms``, or when no disease is
        mapped to them.

    Raises:
        Exception: On any database failure; the original error is attached
            as ``__cause__``.
    """
    # An empty list would render as "IN ()", which PostgreSQL rejects as a
    # syntax error, so answer without touching the database.
    if not symptoms:
        logging.warning("No symptoms supplied; skipping disease lookup.")
        return []

    disease_query = """
        SELECT d.disease_name, d.description,
               ARRAY[COALESCE(d.precaution_1, ''), COALESCE(d.precaution_2, ''),
                     COALESCE(d.precaution_3, ''), COALESCE(d.precaution_4, '')] AS precautions
        FROM mb_diseases d
        JOIN mb_disease_symptoms ds ON d.disease_id = ds.disease_id
        JOIN mb_symptoms s ON ds.symptom_id = s.symptom_id
        WHERE s.symptom_name IN %s
        GROUP BY d.disease_id, d.disease_name, d.description,
                 d.precaution_1, d.precaution_2, d.precaution_3, d.precaution_4
        ORDER BY COUNT(*) DESC
        LIMIT 5
    """

    try:
        logging.info(f"Querying diseases for symptoms: {symptoms}")

        # De-duplicate while keeping the caller's order. psycopg2 adapts a
        # Python tuple to a parenthesised SQL list, so both queries can bind
        # the same value to "IN %s".
        wanted = tuple(dict.fromkeys(symptoms))

        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT symptom_name FROM mb_symptoms WHERE symptom_name IN %s",
                    (wanted,),
                )
                found = {row[0] for row in cur.fetchall()}
                if not found:
                    logging.warning(f"No matching symptoms found in mb_symptoms for: {symptoms}")
                    return []

                # Compare as sets so repeated input names do not trigger a
                # spurious "not found" warning.
                missing = set(wanted) - found
                if missing:
                    logging.warning(f"Some symptoms not found in mb_symptoms: {missing}")

                cur.execute(disease_query, (wanted,))
                results = cur.fetchall()

        if not results:
            logging.warning(f"No diseases found for symptoms: {symptoms}. Check mb_disease_symptoms mappings.")
        else:
            logging.info(f"Found {len(results)} diseases for symptoms: {symptoms}")
        return results
    except Exception as e:
        logging.error(f"Error fetching disease info: {str(e)}")
        raise Exception(f"Error fetching disease info: {str(e)}") from e
|
|
|
def save_user_history(user_id: str, symptoms: str, predicted_diseases: str) -> None:
    """Insert one row into ``mb_history`` and commit it.

    Args:
        user_id: Value stored in the ``user_id`` column.
        symptoms: Value stored in the ``symptoms`` column.
        predicted_diseases: Value stored in the ``predicted_diseases`` column.

    Raises:
        Exception: On any failure; the message embeds the original error text.
    """
    insert_sql = "INSERT INTO mb_history (user_id, symptoms, predicted_diseases) VALUES (%s, %s, %s)"
    record = (user_id, symptoms, predicted_diseases)
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(insert_sql, record)
            conn.commit()
    except Exception as e:
        logging.error(f"Error saving user history: {str(e)}")
        raise Exception(f"Error saving user history: {str(e)}")
|
|
|
def get_user_history(user_id: str) -> List[Tuple[int, str, str, datetime]]:
    """Fetch all ``mb_history`` rows for a user, newest first.

    Args:
        user_id: Value matched against the ``user_id`` column.

    Returns:
        Rows of ``(history_id, symptoms, predicted_diseases, query_timestamp)``
        ordered by ``query_timestamp`` descending.

    Raises:
        Exception: On any failure; the message embeds the original error text.
    """
    select_sql = "SELECT history_id, symptoms, predicted_diseases, query_timestamp FROM mb_history WHERE user_id = %s ORDER BY query_timestamp DESC"
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(select_sql, (user_id,))
            history_rows = cur.fetchall()
            return history_rows
    except Exception as e:
        logging.error(f"Error fetching user history: {str(e)}")
        raise Exception(f"Error fetching user history: {str(e)}")
|
|
|
def _caused_by_integrity_error(exc: Optional[BaseException]) -> bool:
    """Return True if ``exc`` or anything in its cause/context chain is a psycopg2.IntegrityError."""
    seen = set()
    while exc is not None and id(exc) not in seen:
        if isinstance(exc, psycopg2.IntegrityError):
            return True
        seen.add(id(exc))  # guard against a cyclic exception chain
        exc = exc.__cause__ or exc.__context__
    return False


def register_user(user_id: str, full_name: str, dob: str, email: Optional[str], password: str) -> bool:
    """Insert a new row into ``mb_users`` with a bcrypt-hashed password.

    Args:
        user_id: Value stored in the ``user_id`` column.
        full_name: Value stored in the ``full_name`` column.
        dob: Value stored in the ``date_of_birth`` column, passed through as given.
        email: Optional e-mail; an empty string is stored as NULL.
        password: Plain-text password; only its bcrypt hash is stored.

    Returns:
        True when the row was inserted and committed, False when the insert
        violated an integrity constraint (logged as an existing user ID).

    Raises:
        Exception: On any other failure; the original error is attached as
            ``__cause__``.
    """
    try:
        hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO mb_users (user_id, full_name, date_of_birth, email, password) VALUES (%s, %s, %s, %s, %s)",
                    (user_id, full_name, dob, email or None, hashed_password),
                )
            conn.commit()
        return True
    except Exception as e:
        # The integrity error may arrive directly, or wrapped in a plain
        # Exception by a layer around the connection. Walk the exception
        # chain so a duplicate is reported as False in both cases.
        # NOTE(review): any integrity violation (not only a duplicate
        # user_id) is reported with this message; confirm against the schema.
        if _caused_by_integrity_error(e):
            logging.error(f"User ID {user_id} already exists.")
            return False
        logging.error(f"Error registering user: {str(e)}")
        raise Exception(f"Error registering user: {str(e)}") from e
|
|
|
def user_exists(user_id: str) -> bool:
    """Report whether ``mb_users`` contains a row with the given ``user_id``.

    Args:
        user_id: Value matched against the ``user_id`` column.

    Returns:
        True if the lookup returns a row, otherwise False.

    Raises:
        Exception: On any failure; the message embeds the original error text.
    """
    lookup_sql = "SELECT 1 FROM mb_users WHERE user_id = %s"
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(lookup_sql, (user_id,))
            match = cur.fetchone()
            return match is not None
    except Exception as e:
        logging.error(f"Error checking user existence: {str(e)}")
        raise Exception(f"Error checking user existence: {str(e)}")
|
|
|
def check_user_credentials(user_id: str, password: str) -> bool:
    """Check a plain-text password against the bcrypt hash stored for a user.

    Args:
        user_id: Value matched against the ``user_id`` column of ``mb_users``.
        password: Plain-text password to verify.

    Returns:
        The result of ``bcrypt.checkpw`` when a row is found for ``user_id``;
        False when no row is found.

    Raises:
        Exception: On any failure; the message embeds the original error text.
    """
    lookup_sql = "SELECT password FROM mb_users WHERE user_id = %s"
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(lookup_sql, (user_id,))
            row = cur.fetchone()
            # Guard clause: no row for this user means nothing to compare.
            if not row:
                return False
            candidate = password.encode('utf-8')
            stored_hash = row[0].encode('utf-8')
            return bcrypt.checkpw(candidate, stored_hash)
    except Exception as e:
        logging.error(f"Error checking user credentials: {str(e)}")
        raise Exception(f"Error checking user credentials: {str(e)}")