File size: 7,422 Bytes
4b02284
6081f39
 
 
 
2559b31
 
73d90ba
2559b31
 
 
6081f39
4b02284
 
 
6081f39
 
 
4b02284
 
 
6081f39
 
 
4b02284
6081f39
4b02284
6081f39
4b02284
2559b31
4b02284
6081f39
4b02284
 
6081f39
 
 
4b02284
 
 
 
2559b31
 
 
 
 
 
4b02284
2559b31
4b02284
6081f39
 
 
4b02284
2559b31
4b02284
 
2559b31
 
 
 
 
 
 
 
 
 
4b02284
73d90ba
4b02284
 
 
 
 
 
73d90ba
4b02284
 
 
73d90ba
4b02284
2559b31
 
 
 
 
 
4b02284
2559b31
4b02284
6081f39
 
 
4b02284
 
 
 
 
 
 
 
 
2559b31
4b02284
6081f39
73d90ba
6081f39
4b02284
 
 
 
 
 
 
 
 
2559b31
4b02284
6081f39
2559b31
 
6081f39
2559b31
6081f39
 
 
2559b31
 
6081f39
 
 
 
2559b31
6081f39
4b02284
2559b31
4b02284
6081f39
 
 
4b02284
 
 
 
 
 
2559b31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from dotenv import load_dotenv
import os
import psycopg2
from contextlib import contextmanager
from typing import List, Tuple, Optional
import bcrypt
import logging
from datetime import datetime # <--- ADD THIS IMPORT

# Configure root logging: INFO and above, with a timestamp and level prefix.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Populate the process environment from the local env file.
load_dotenv(dotenv_path=".env.local")

# PostgreSQL connection string consumed by get_db_connection(); the module
# refuses to import without it.
DB_CONNECT = os.getenv("DB_CONNECT")
if not DB_CONNECT:
    raise ValueError("DB_CONNECT environment variable not set. Please check .env.local.")

@contextmanager
def get_db_connection():
    """Open a PostgreSQL connection for the duration of a ``with`` block.

    Yields:
        A psycopg2 connection created from ``DB_CONNECT``. It is always closed
        when the block exits. Per psycopg2, closing without a commit discards
        any pending changes.

    Raises:
        Exception: if the connection cannot be established. The original
            ``psycopg2.Error`` is chained as ``__cause__``.

    Exceptions raised inside the caller's ``with`` body propagate unchanged,
    with their original type, so callers can handle e.g.
    ``psycopg2.IntegrityError`` themselves.
    """
    # Only the connect step is wrapped. Previously the ``yield`` was inside
    # the ``try``, so query errors from the caller's body were relabelled as
    # "Database connection failed" and lost their psycopg2 exception type.
    try:
        conn = psycopg2.connect(DB_CONNECT)
    except psycopg2.Error as e:
        logging.error(f"Database connection failed: {str(e)}")
        raise Exception(f"Database connection failed: {str(e)}") from e
    try:
        yield conn
    finally:
        conn.close()

def get_symptoms() -> List[str]:
    """Return every symptom name in ``mb_symptoms``, ordered by ``symptom_name``.

    Logs a warning if the table is empty, otherwise an info line with the count.

    Raises:
        Exception: wrapping any failure, with the message prefixed by
            "Error fetching symptoms: ".
    """
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute("SELECT symptom_name FROM mb_symptoms ORDER BY symptom_name")
            # Each row holds exactly one selected column.
            names = [name for (name,) in cur.fetchall()]
            if names:
                logging.info(f"Loaded {len(names)} symptoms from mb_symptoms.")
            else:
                logging.warning("No symptoms found in mb_symptoms table.")
            return names
    except Exception as err:
        message = f"Error fetching symptoms: {err}"
        logging.error(message)
        raise Exception(message)

def get_disease_info(symptoms: List[str]) -> List[Tuple[str, str, List[str]]]:
    """Rank up to five diseases by how many of the given symptoms map to them.

    Args:
        symptoms: Symptom names as stored in ``mb_symptoms.symptom_name``.
            Duplicates are ignored.

    Returns:
        Up to five ``(disease_name, description, precautions)`` rows, ordered
        by the number of matching symptoms (descending), with ties broken by
        ``disease_name``. ``precautions`` is built from the four precaution
        columns, with NULLs replaced by ``''``.

        Returns ``[]`` when ``symptoms`` is empty or when none of the names
        exist in ``mb_symptoms``.

    Raises:
        Exception: wrapping any database error, with the message prefixed by
            "Error fetching disease info: ". The original error is chained
            as ``__cause__``.
    """
    try:
        # An empty IN () list is a SQL syntax error, so bail out before
        # touching the database.
        if not symptoms:
            logging.warning("No symptoms provided; skipping disease lookup.")
            return []
        # Dedupe while preserving order so the "missing" check below is
        # set-based rather than length-based.
        unique_symptoms = list(dict.fromkeys(symptoms))
        logging.info(f"Querying diseases for symptoms: {symptoms}")
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Check which of the requested symptoms exist in mb_symptoms.
                cur.execute(
                    "SELECT symptom_name FROM mb_symptoms WHERE symptom_name IN %s",
                    (tuple(unique_symptoms),),
                )
                found_symptoms = [row[0] for row in cur.fetchall()]
                if not found_symptoms:
                    logging.warning(f"No matching symptoms found in mb_symptoms for: {symptoms}")
                    return []
                missing = set(unique_symptoms) - set(found_symptoms)
                if missing:
                    logging.warning(f"Some symptoms not found in mb_symptoms: {missing}")

                # Only the placeholder list is interpolated; every value is
                # still passed as a bound parameter. Unknown names could not
                # match the join anyway, so querying found_symptoms is
                # equivalent to querying the raw input.
                placeholders = ','.join(['%s'] * len(found_symptoms))
                query = f"""
                    SELECT d.disease_name, d.description, 
                           ARRAY[COALESCE(d.precaution_1, ''), COALESCE(d.precaution_2, ''), 
                                 COALESCE(d.precaution_3, ''), COALESCE(d.precaution_4, '')] as precautions
                    FROM mb_diseases d
                    JOIN mb_disease_symptoms ds ON d.disease_id = ds.disease_id
                    JOIN mb_symptoms s ON ds.symptom_id = s.symptom_id
                    WHERE s.symptom_name IN ({placeholders})
                    GROUP BY d.disease_id, d.disease_name, d.description, d.precaution_1, d.precaution_2, d.precaution_3, d.precaution_4
                    ORDER BY COUNT(*) DESC, d.disease_name
                    LIMIT 5
                """
                cur.execute(query, found_symptoms)
                results = cur.fetchall()
                if not results:
                    logging.warning(f"No diseases found for symptoms: {symptoms}. Check mb_disease_symptoms mappings.")
                else:
                    logging.info(f"Found {len(results)} diseases for symptoms: {symptoms}")
                return results
    except Exception as e:
        logging.error(f"Error fetching disease info: {str(e)}")
        raise Exception(f"Error fetching disease info: {str(e)}") from e

def save_user_history(user_id: str, symptoms: str, predicted_diseases: str) -> None:
    """Insert one row into ``mb_history`` and commit it.

    Args:
        user_id: Owner of the history entry.
        symptoms: Symptom text to record, stored as given.
        predicted_diseases: Prediction text to record, stored as given.

    Raises:
        Exception: wrapping any failure, with the message prefixed by
            "Error saving user history: ".
    """
    insert_sql = "INSERT INTO mb_history (user_id, symptoms, predicted_diseases) VALUES (%s, %s, %s)"
    row = (user_id, symptoms, predicted_diseases)
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(insert_sql, row)
            conn.commit()
    except Exception as err:
        message = f"Error saving user history: {err}"
        logging.error(message)
        raise Exception(message)

def get_user_history(user_id: str) -> List[Tuple[int, str, str, datetime]]:
    """Fetch every ``mb_history`` row for one user, newest first.

    Returns:
        Rows of ``(history_id, symptoms, predicted_diseases, query_timestamp)``,
        ordered by ``query_timestamp`` descending.

    Raises:
        Exception: wrapping any failure, with the message prefixed by
            "Error fetching user history: ".
    """
    select_sql = (
        "SELECT history_id, symptoms, predicted_diseases, query_timestamp "
        "FROM mb_history WHERE user_id = %s ORDER BY query_timestamp DESC"
    )
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute(select_sql, (user_id,))
            return cur.fetchall()
    except Exception as err:
        message = f"Error fetching user history: {err}"
        logging.error(message)
        raise Exception(message)

def register_user(user_id: str, full_name: str, dob: str, email: Optional[str], password: str) -> bool:
    """Register a new user, storing a bcrypt hash instead of the plain password.

    Args:
        user_id: Value for ``mb_users.user_id``.
        full_name: Value for ``mb_users.full_name``.
        dob: Value for ``mb_users.date_of_birth``, passed to the driver as-is.
        email: Optional email. An empty string is stored as NULL.
        password: Plain-text password. Only its bcrypt hash is persisted.

    Returns:
        True if the row was inserted and committed. False if the INSERT or
        its commit violated an integrity constraint, which is logged as a
        duplicate user ID.

    Raises:
        Exception: for any other failure (hashing, connection, SQL), with the
            message prefixed by "Error registering user: ". The original
            error is chained as ``__cause__``.
    """
    try:
        hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Handle IntegrityError here, inside the with-body, so it reaches
                # this handler directly. It must not escape into the connection
                # context manager, which may translate errors from the body.
                try:
                    cur.execute(
                        "INSERT INTO mb_users (user_id, full_name, date_of_birth, email, password) VALUES (%s, %s, %s, %s, %s)",
                        (user_id, full_name, dob, email or None, hashed_password)
                    )
                    conn.commit()
                except psycopg2.IntegrityError:
                    # Clear the aborted transaction before the connection closes.
                    # NOTE(review): any integrity violation lands here, not only
                    # a duplicate user_id (e.g. a unique email, if the schema
                    # has one).
                    conn.rollback()
                    logging.error(f"User ID {user_id} already exists.")
                    return False
                return True
    except Exception as e:
        logging.error(f"Error registering user: {str(e)}")
        raise Exception(f"Error registering user: {str(e)}") from e

def user_exists(user_id: str) -> bool:
    """Return True if ``mb_users`` already contains a row with this ``user_id``.

    Raises:
        Exception: wrapping any failure, with the message prefixed by
            "Error checking user existence: ".
    """
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute("SELECT 1 FROM mb_users WHERE user_id = %s", (user_id,))
            first_row = cur.fetchone()
            return first_row is not None
    except Exception as err:
        message = f"Error checking user existence: {err}"
        logging.error(message)
        raise Exception(message)

def check_user_credentials(user_id: str, password: str) -> bool:
    """Check a plain-text password against the stored bcrypt hash for a user.

    Returns:
        True if the user exists and ``bcrypt.checkpw`` accepts the password.
        False if the user is unknown or the password does not match.

    Raises:
        Exception: wrapping any failure, with the message prefixed by
            "Error checking user credentials: ".
    """
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute("SELECT password FROM mb_users WHERE user_id = %s", (user_id,))
            row = cur.fetchone()
            if row is None:
                # Unknown user.
                return False
            candidate = password.encode('utf-8')
            stored_hash = row[0].encode('utf-8')
            return bcrypt.checkpw(candidate, stored_hash)
    except Exception as err:
        message = f"Error checking user credentials: {err}"
        logging.error(message)
        raise Exception(message)