Spaces:
Sleeping
Sleeping
from dotenv import load_dotenv | |
import os | |
import psycopg2 | |
from contextlib import contextmanager | |
from typing import List, Tuple, Optional | |
import bcrypt | |
import logging | |
from datetime import datetime # <--- ADD THIS IMPORT | |
# Configure process-wide logging: INFO level with timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Populate the process environment from the local dotenv file before any
# setting is read below.
load_dotenv(dotenv_path=".env.local")

# PostgreSQL connection string consumed by get_db_connection(). Importing this
# module fails immediately when it is missing, so misconfiguration surfaces at
# startup instead of at the first query.
DB_CONNECT = os.environ.get("DB_CONNECT")
if not DB_CONNECT:
    raise ValueError("DB_CONNECT environment variable not set. Please check .env.local.")
@contextmanager
def get_db_connection():
    """Open a PostgreSQL connection for the duration of a ``with`` block.

    Usage::

        with get_db_connection() as conn:
            ...

    The connection is always closed when the block exits. Nothing is
    committed automatically: callers that write must call ``conn.commit()``
    themselves, and any uncommitted work is discarded when the connection
    is closed.

    Yields:
        An open ``psycopg2`` connection built from ``DB_CONNECT``.

    Raises:
        Exception: if the connection cannot be established; the original
            ``psycopg2.Error`` is attached as ``__cause__``.

    Exceptions raised inside the caller's ``with`` body are NOT re-wrapped.
    They propagate unchanged, so callers can still catch specific driver
    errors such as ``psycopg2.IntegrityError``.
    """
    # Only the connect step is guarded. Wrapping the ``yield`` as well would
    # turn every query error in the caller's body into a misleading
    # "connection failed" message.
    try:
        conn = psycopg2.connect(DB_CONNECT)
    except psycopg2.Error as e:
        logging.error(f"Database connection failed: {str(e)}")
        raise Exception(f"Database connection failed: {str(e)}") from e
    try:
        yield conn
    finally:
        conn.close()
def get_symptoms() -> List[str]:
    """Return every symptom name in ``mb_symptoms``, ordered by ``symptom_name``.

    Logs an info line with the count, or a warning when the table is empty.

    Raises:
        Exception: wrapping any failure, with a message prefixed by
            ``"Error fetching symptoms: "``.
    """
    try:
        with get_db_connection() as connection, connection.cursor() as cursor:
            cursor.execute("SELECT symptom_name FROM mb_symptoms ORDER BY symptom_name")
            names = [record[0] for record in cursor.fetchall()]
            if names:
                logging.info(f"Loaded {len(names)} symptoms from mb_symptoms.")
            else:
                logging.warning("No symptoms found in mb_symptoms table.")
            return names
    except Exception as err:
        logging.error(f"Error fetching symptoms: {err!s}")
        raise Exception(f"Error fetching symptoms: {err!s}")
def get_disease_info(symptoms: List[str]) -> List[Tuple[str, str, List[str]]]:
    """Rank diseases linked to the given symptoms and return the top five.

    Diseases are joined to ``symptoms`` through ``mb_disease_symptoms``. They
    are ordered by the number of matching symptom links, most first, with
    ties broken by disease name so the result is deterministic.

    Args:
        symptoms: Symptom names as stored in ``mb_symptoms.symptom_name``.

    Returns:
        Up to five ``(disease_name, description, precautions)`` rows.
        ``precautions`` holds the four ``precaution_*`` columns, with NULLs
        replaced by empty strings. Returns ``[]`` when ``symptoms`` is empty
        or none of the names exist in ``mb_symptoms``.

    Raises:
        Exception: wrapping any database failure, with a message prefixed by
            ``"Error fetching disease info: "``.
    """
    if not symptoms:
        # SQL has no empty ``IN ()`` list, so querying with no symptoms would
        # be a syntax error. The answer is trivially "no diseases".
        logging.warning("No symptoms provided; skipping disease lookup.")
        return []
    try:
        logging.info(f"Querying diseases for symptoms: {symptoms}")
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Check which of the requested symptoms exist in mb_symptoms.
                cur.execute(
                    "SELECT symptom_name FROM mb_symptoms WHERE symptom_name IN %s",
                    (tuple(symptoms),),
                )
                found_symptoms = {row[0] for row in cur.fetchall()}
                if not found_symptoms:
                    logging.warning(f"No matching symptoms found in mb_symptoms for: {symptoms}")
                    return []
                # Compare as sets, so duplicate inputs don't trigger a false
                # "not found" warning.
                missing = set(symptoms) - found_symptoms
                if missing:
                    logging.warning(f"Some symptoms not found in mb_symptoms: {missing}")

                # Only %s placeholders are interpolated; values are bound by
                # the driver.
                placeholders = ','.join(['%s'] * len(symptoms))
                query = f"""
                    SELECT d.disease_name, d.description,
                           ARRAY[COALESCE(d.precaution_1, ''), COALESCE(d.precaution_2, ''),
                                 COALESCE(d.precaution_3, ''), COALESCE(d.precaution_4, '')] as precautions
                    FROM mb_diseases d
                    JOIN mb_disease_symptoms ds ON d.disease_id = ds.disease_id
                    JOIN mb_symptoms s ON ds.symptom_id = s.symptom_id
                    WHERE s.symptom_name IN ({placeholders})
                    GROUP BY d.disease_id, d.disease_name, d.description, d.precaution_1, d.precaution_2, d.precaution_3, d.precaution_4
                    ORDER BY COUNT(*) DESC, d.disease_name
                    LIMIT 5
                """
                cur.execute(query, list(symptoms))
                results = cur.fetchall()
                if not results:
                    logging.warning(f"No diseases found for symptoms: {symptoms}. Check mb_disease_symptoms mappings.")
                else:
                    logging.info(f"Found {len(results)} diseases for symptoms: {symptoms}")
                return results
    except Exception as e:
        logging.error(f"Error fetching disease info: {str(e)}")
        raise Exception(f"Error fetching disease info: {str(e)}") from e
def save_user_history(user_id: str, symptoms: str, predicted_diseases: str) -> None:
    """Insert one symptom-check record for ``user_id`` into ``mb_history`` and commit it.

    Raises:
        Exception: wrapping any failure, with a message prefixed by
            ``"Error saving user history: "``.
    """
    sql = "INSERT INTO mb_history (user_id, symptoms, predicted_diseases) VALUES (%s, %s, %s)"
    params = (user_id, symptoms, predicted_diseases)
    try:
        with get_db_connection() as connection:
            with connection.cursor() as cursor:
                cursor.execute(sql, params)
            connection.commit()
    except Exception as err:
        logging.error(f"Error saving user history: {err!s}")
        raise Exception(f"Error saving user history: {err!s}")
def get_user_history(user_id: str) -> List[Tuple[int, str, str, datetime]]:
    """Fetch the stored symptom-check history for ``user_id``, newest first.

    Returns:
        Every matching ``mb_history`` row as
        ``(history_id, symptoms, predicted_diseases, query_timestamp)``,
        ordered by ``query_timestamp`` descending.

    Raises:
        Exception: wrapping any failure, with a message prefixed by
            ``"Error fetching user history: "``.
    """
    sql = (
        "SELECT history_id, symptoms, predicted_diseases, query_timestamp "
        "FROM mb_history WHERE user_id = %s ORDER BY query_timestamp DESC"
    )
    try:
        with get_db_connection() as connection, connection.cursor() as cursor:
            cursor.execute(sql, (user_id,))
            history = cursor.fetchall()
        return history
    except Exception as err:
        logging.error(f"Error fetching user history: {err!s}")
        raise Exception(f"Error fetching user history: {err!s}")
def register_user(user_id: str, full_name: str, dob: str, email: Optional[str], password: str) -> bool:
    """Insert a new ``mb_users`` row, storing only a bcrypt hash of the password.

    Args:
        user_id: Value for the ``user_id`` column.
        full_name: Value for the ``full_name`` column.
        dob: Value for the ``date_of_birth`` column (presumably a date string
            PostgreSQL can parse; confirm with callers).
        email: Optional e-mail. Any falsy value (``None`` or ``""``) is
            stored as NULL.
        password: Plain-text password, hashed with a fresh bcrypt salt
            before storage.

    Returns:
        True once the row is committed. False if the INSERT or COMMIT raised
        ``psycopg2.IntegrityError`` (e.g. a duplicate ``user_id``, assuming
        that column is unique).

    Raises:
        Exception: for any other failure, with a message prefixed by
            ``"Error registering user: "``.
    """
    try:
        hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
        with get_db_connection() as conn:
            # Catch constraint violations inside the connection block. If the
            # connection context manager converts driver errors raised in its
            # body into a generic Exception, an outer
            # ``except psycopg2.IntegrityError`` would never match, and a
            # duplicate user would raise instead of returning False.
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        "INSERT INTO mb_users (user_id, full_name, date_of_birth, email, password) VALUES (%s, %s, %s, %s, %s)",
                        (user_id, full_name, dob, email or None, hashed_password),
                    )
                conn.commit()
            except psycopg2.IntegrityError as e:
                # Any integrity violation lands here (not only duplicate ids),
                # so include the driver's detail in the log.
                logging.error(f"User ID {user_id} already exists or violates a constraint: {str(e)}")
                return False
        return True
    except Exception as e:
        logging.error(f"Error registering user: {str(e)}")
        raise Exception(f"Error registering user: {str(e)}") from e
def user_exists(user_id: str) -> bool:
    """Report whether ``mb_users`` already has a row for ``user_id``.

    Raises:
        Exception: wrapping any failure, with a message prefixed by
            ``"Error checking user existence: "``.
    """
    try:
        with get_db_connection() as connection, connection.cursor() as cursor:
            cursor.execute("SELECT 1 FROM mb_users WHERE user_id = %s", (user_id,))
            found = cursor.fetchone() is not None
            return found
    except Exception as err:
        logging.error(f"Error checking user existence: {err!s}")
        raise Exception(f"Error checking user existence: {err!s}")
def check_user_credentials(user_id: str, password: str) -> bool:
    """Return True when ``password`` matches the bcrypt hash stored for ``user_id``.

    An unknown ``user_id`` yields False.

    Raises:
        Exception: wrapping any failure, including errors from
            ``bcrypt.checkpw``, with a message prefixed by
            ``"Error checking user credentials: "``.
    """
    try:
        with get_db_connection() as connection, connection.cursor() as cursor:
            cursor.execute("SELECT password FROM mb_users WHERE user_id = %s", (user_id,))
            row = cursor.fetchone()
            if not row:
                return False
            # Encode the candidate first, then the stored hash (same order as
            # before, so error messages are unchanged).
            candidate = password.encode('utf-8')
            stored_hash = row[0].encode('utf-8')
            return bcrypt.checkpw(candidate, stored_hash)
    except Exception as err:
        logging.error(f"Error checking user credentials: {err!s}")
        raise Exception(f"Error checking user credentials: {err!s}")