Cuong2004 committed
Commit ac0f906 · Parent(s): 188c1cd

first commit
.dockerignore ADDED
@@ -0,0 +1,46 @@
+ # Git
+ .git
+ .gitignore
+ .gitattributes
+
+ # Environment files
+ .env
+ .env.*
+ !.env.example
+
+ # Python cache files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ .pytest_cache/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Logs
+ *.log
+
+ # Tests
+ tests/
+
+ # Docker related
+ Dockerfile
+ docker-compose.yml
+ .dockerignore
+
+ # Other files
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ .DS_Store
+ .coverage
+ htmlcov/
+ .mypy_cache/
+ .tox/
+ .nox/
+ instance/
+ .webassets-cache
+ main.py
.env.example ADDED
@@ -0,0 +1,26 @@
+ # PostgreSQL Configuration
+ DB_CONNECTION_MODE=aiven
+ AIVEN_DB_URL=postgresql://username:password@host:port/dbname?sslmode=require
+
+ # MongoDB Configuration
+ MONGODB_URL=mongodb+srv://username:[email protected]/?retryWrites=true&w=majority
+ DB_NAME=Telegram
+ COLLECTION_NAME=session_chat
+
+ # Pinecone configuration
+ PINECONE_API_KEY=your-pinecone-api-key
+ PINECONE_INDEX_NAME=your-pinecone-index-name
+ PINECONE_ENVIRONMENT=gcp-starter
+
+ # Google Gemini API key
+ GOOGLE_API_KEY=your-google-api-key
+
+ # WebSocket configuration
+ WEBSOCKET_SERVER=localhost
+ WEBSOCKET_PORT=7860
+ WEBSOCKET_PATH=/notify
+
+ # Application settings
+ ENVIRONMENT=production
+ DEBUG=false
+ PORT=7860
.gitattributes CHANGED
@@ -1,35 +1,29 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # Auto detect text files and perform LF normalization
+ * text=auto eol=lf
+
+ # Documents
+ *.md text
+ *.txt text
+ *.ini text
+ *.yaml text
+ *.yml text
+ *.json text
+ *.py text
+ *.env.example text
+
+ # Binary files
+ *.png binary
+ *.jpg binary
+ *.jpeg binary
+ *.gif binary
+ *.ico binary
+ *.db binary
+
+ # Git related files
+ .gitignore text
+ .gitattributes text
+
+ # Docker related files
+ Dockerfile text
+ docker-compose.yml text
+ .dockerignore text
.gitignore ADDED
@@ -0,0 +1,79 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ .pytest_cache/
+ htmlcov/
+ .coverage
+ .coverage.*
+ .cache/
+ coverage.xml
+ *.cover
+ .mypy_cache/
+
+ # Environment
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+
+ # VSCode
+ .vscode/
+ *.code-workspace
+ .history/
+
+ # PyCharm
+ .idea/
+ *.iml
+ *.iws
+ *.ipr
+ *.iws
+ out/
+ .idea_modules/
+
+ # Logs and databases
+ *.log
+ *.sql
+ *.sqlite
+ *.db
+
+ # Tests
+ tests/
+
+ Admin_bot/
+
+ # Hugging Face Spaces
+ .gitattributes
+
+ # OS specific
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ Icon?
+ ehthumbs.db
+ Thumbs.db
+
+ # Project specific
+ *.log
+ .env
+ main.py
Dockerfile ADDED
@@ -0,0 +1,31 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # Install the required system packages
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     software-properties-common \
+     git \
+     gcc \
+     python3-dev \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy the requirements file first to take advantage of Docker's layer cache
+ COPY requirements.txt .
+
+ # Install the Python packages
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Ensure langchain-core is installed
+ RUN pip install --no-cache-dir langchain-core==0.1.19
+
+ # Copy the entire codebase into the container
+ COPY . .
+
+ # Expose the port the application runs on
+ EXPOSE 7860
+
+ # Run the application with uvicorn
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,361 @@
  ---
- title: Pix Agent
- emoji: 🦀
- colorFrom: indigo
- colorTo: indigo
+ title: PIX Project Backend
+ emoji: 🤖
+ colorFrom: blue
+ colorTo: green
  sdk: docker
+ sdk_version: "3.0.0"
+ app_file: app.py
  pinned: false
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # PIX Project Backend
+
+ [![FastAPI](https://img.shields.io/badge/FastAPI-0.103.1-009688?style=flat&logo=fastapi&logoColor=white)](https://fastapi.tiangolo.com/)
+ [![Python 3.11](https://img.shields.io/badge/Python-3.11-3776AB?style=flat&logo=python&logoColor=white)](https://www.python.org/)
+ [![HuggingFace Spaces](https://img.shields.io/badge/HuggingFace-Spaces-yellow?style=flat&logo=huggingface&logoColor=white)](https://huggingface.co/spaces)
+
+ Backend API for PIX Project with MongoDB, PostgreSQL and RAG integration. This project provides a comprehensive backend solution for managing FAQ items, emergency contacts, events, and a RAG-based question answering system.
+
+ ## Features
+
+ - **MongoDB Integration**: Store user sessions and conversation history
+ - **PostgreSQL Integration**: Manage FAQ items, emergency contacts, and events
+ - **Pinecone Vector Database**: Store and retrieve vector embeddings for RAG
+ - **RAG Question Answering**: Answer questions using relevant information from the vector database
+ - **WebSocket Notifications**: Real-time notifications for Admin Bot
+ - **API Documentation**: Automatic OpenAPI documentation via Swagger
+ - **Docker Support**: Easy deployment using Docker
+ - **Auto-Debugging**: Built-in debugging, error tracking, and performance monitoring
+
+ ## API Endpoints
+
+ ### MongoDB Endpoints
+
+ - `POST /mongodb/session`: Create a new session record
+ - `PUT /mongodb/session/{session_id}/response`: Update a session with a response
+ - `GET /mongodb/history`: Get user conversation history
+ - `GET /mongodb/health`: Check MongoDB connection health
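+
+ For illustration, a session can be created with a small Python client. This is a minimal sketch, assuming the API runs locally on port 7860; the payload values are examples only (see the `SessionCreate` fields documented in `app/api/mongodb_routes.py`):
+
+ ```python
+ import requests
+
+ payload = {
+     "factor": "user",
+     "action": "asking_freely",
+     "first_name": "John",
+     "user_id": "12345678",
+     "message": "What events are happening this week?"
+ }
+ response = requests.post("http://localhost:7860/mongodb/session", json=payload)
+ response.raise_for_status()
+ print(response.json())
+ ```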
+
+ ### PostgreSQL Endpoints
+
+ - `GET /postgres/health`: Check PostgreSQL connection health
+ - `GET /postgres/faq`: Get FAQ items
+ - `POST /postgres/faq`: Create a new FAQ item
+ - `GET /postgres/faq/{faq_id}`: Get a specific FAQ item
+ - `PUT /postgres/faq/{faq_id}`: Update a specific FAQ item
+ - `DELETE /postgres/faq/{faq_id}`: Delete a specific FAQ item
+ - `GET /postgres/emergency`: Get emergency contact items
+ - `POST /postgres/emergency`: Create a new emergency contact item
+ - `GET /postgres/emergency/{emergency_id}`: Get a specific emergency contact
+ - `GET /postgres/events`: Get event items
+
+ ### RAG Endpoints
+
+ - `POST /rag/chat`: Get answer for a question using RAG
+ - `POST /rag/embedding`: Generate embedding for text
+ - `GET /rag/health`: Check RAG services health
+
+ ### WebSocket Endpoints
+
+ - `WebSocket /notify`: Receive real-time notifications for new sessions
+
+ ### Debug Endpoints (Available in Debug Mode Only)
+
+ - `GET /debug/config`: Get configuration information
+ - `GET /debug/system`: Get system information (CPU, memory, disk usage)
+ - `GET /debug/database`: Check all database connections
+ - `GET /debug/errors`: View recent error logs
+ - `GET /debug/performance`: Get performance metrics
+ - `GET /debug/full`: Get comprehensive debug information
+
+ ## WebSocket API
+
+ ### Notifications for New Sessions
+
+ The backend provides a WebSocket endpoint for receiving notifications about new sessions that match specific criteria.
+
+ #### WebSocket Endpoint Configuration
+
+ The WebSocket endpoint is configured using environment variables:
+
+ ```
+ # WebSocket configuration
+ WEBSOCKET_SERVER=localhost
+ WEBSOCKET_PORT=7860
+ WEBSOCKET_PATH=/notify
+ ```
+
+ The full WebSocket URL will be:
+ ```
+ ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}
+ ```
+
+ For example: `ws://localhost:7860/notify`
+
+ #### Notification Criteria
+
+ A notification is sent when:
+ 1. A new session is created with `factor` set to "RAG"
+ 2. The message content starts with "I don't know"
+
+ #### Notification Format
+
+ ```json
+ {
+   "type": "new_session",
+   "timestamp": "2025-04-15 22:30:45",
+   "data": {
+     "session_id": "123e4567-e89b-12d3-a456-426614174000",
+     "factor": "rag",
+     "action": "asking_freely",
+     "created_at": "2025-04-15 22:30:45",
+     "first_name": "John",
+     "last_name": "Doe",
+     "message": "I don't know how to find emergency contacts",
+     "user_id": "12345678",
+     "username": "johndoe"
+   }
+ }
+ ```
+
+ #### Usage Example
+
+ Admin Bot should establish a WebSocket connection to this endpoint using the configured URL:
+
+ ```python
+ import websocket
+ import json
+ import os
+ import threading
+ import time
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Get WebSocket configuration from environment variables
+ WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
+ WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
+ WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")
+
+ # Create full URL
+ ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
+
+ def on_message(ws, message):
+     data = json.loads(message)
+     print(f"Received notification: {data}")
+     # Forward to Telegram Admin
+
+ def on_error(ws, error):
+     print(f"Error: {error}")
+
+ def on_close(ws, close_status_code, close_msg):
+     print("Connection closed")
+
+ def on_open(ws):
+     print("Connection opened")
+     # Send a keepalive message periodically (30 s is an arbitrary interval)
+     def keepalive():
+         while True:
+             time.sleep(30)
+             ws.send("keepalive")
+     threading.Thread(target=keepalive, daemon=True).start()
+
+ # Connect to WebSocket
+ ws = websocket.WebSocketApp(
+     ws_url,
+     on_open=on_open,
+     on_message=on_message,
+     on_error=on_error,
+     on_close=on_close
+ )
+ ws.run_forever()
+ ```
+
+ When a notification is received, Admin Bot should forward the content to the Telegram Admin.
+
+ ## Environment Variables
+
+ Create a `.env` file with the following variables:
+
+ ```
+ # PostgreSQL Configuration
+ DB_CONNECTION_MODE=aiven
+ AIVEN_DB_URL=postgresql://username:password@host:port/dbname?sslmode=require
+
+ # MongoDB Configuration
+ MONGODB_URL=mongodb+srv://username:[email protected]/?retryWrites=true&w=majority
+ DB_NAME=Telegram
+ COLLECTION_NAME=session_chat
+
+ # Pinecone configuration
+ PINECONE_API_KEY=your-pinecone-api-key
+ PINECONE_INDEX_NAME=your-pinecone-index-name
+ PINECONE_ENVIRONMENT=gcp-starter
+
+ # Google Gemini API key
+ GOOGLE_API_KEY=your-google-api-key
+
+ # WebSocket configuration
+ WEBSOCKET_SERVER=localhost
+ WEBSOCKET_PORT=7860
+ WEBSOCKET_PATH=/notify
+
+ # Application settings
+ ENVIRONMENT=production
+ DEBUG=false
+ PORT=7860
+ ```
+
+ ## Installation and Setup
+
+ ### Local Development
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/ManTT-Data/PixAgent.git
+ cd PixAgent
+ ```
+
+ 2. Create a virtual environment and install dependencies:
+ ```bash
+ python -m venv venv
+ source venv/bin/activate  # On Windows: venv\Scripts\activate
+ pip install -r requirements.txt
+ ```
+
+ 3. Create a `.env` file with your configuration (see above)
+
+ 4. Run the application:
+ ```bash
+ uvicorn app:app --reload --port 7860
+ ```
+
+ 5. Open your browser and navigate to [http://localhost:7860/docs](http://localhost:7860/docs) to see the API documentation
+
+ ### Docker Deployment
+
+ 1. Build the Docker image:
+ ```bash
+ docker build -t pix-project-backend .
+ ```
+
+ 2. Run the Docker container:
+ ```bash
+ docker run -p 7860:7860 --env-file .env pix-project-backend
+ ```
+
+ ## Deployment to HuggingFace Spaces
+
+ 1. Create a new Space on HuggingFace (Dockerfile type)
+ 2. Link your GitHub repository or push directly to the HuggingFace repo
+ 3. Add your environment variables in the Space settings
+ 4. The deployment will use `app.py` as the entry point, which is the standard for HuggingFace Spaces
+
+ ### Important Notes for HuggingFace Deployment
+
+ - The application uses `app.py` with the FastAPI instance named `app` to avoid the "Error loading ASGI app. Attribute 'app' not found in module 'app'" error
+ - Make sure all environment variables are set in the Space settings
+ - The Dockerfile is configured to expose port 7860, which is the default port for HuggingFace Spaces
+
+ ## Project Structure
+
+ ```
+ .
+ ├── app                      # Main application package
+ │   ├── api                  # API endpoints
+ │   │   ├── mongodb_routes.py
+ │   │   ├── postgresql_routes.py
+ │   │   ├── rag_routes.py
+ │   │   └── websocket_routes.py
+ │   ├── database             # Database connections
+ │   │   ├── mongodb.py
+ │   │   ├── pinecone.py
+ │   │   └── postgresql.py
+ │   ├── models               # Pydantic models
+ │   │   ├── mongodb_models.py
+ │   │   ├── postgresql_models.py
+ │   │   └── rag_models.py
+ │   └── utils                # Utility functions
+ │       ├── debug_utils.py
+ │       └── middleware.py
+ ├── tests                    # Test directory
+ │   └── test_api_endpoints.py
+ ├── .dockerignore            # Docker ignore file
+ ├── .env.example             # Example environment file
+ ├── .gitattributes           # Git attributes
+ ├── .gitignore               # Git ignore file
+ ├── app.py                   # Application entry point
+ ├── docker-compose.yml       # Docker compose configuration
+ ├── Dockerfile               # Docker configuration
+ ├── pytest.ini               # Pytest configuration
+ ├── README.md                # Project documentation
+ ├── requirements.txt         # Project dependencies
+ └── api_documentation.txt    # API documentation for frontend engineers
+ ```
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
+
+ # Advanced Retrieval System
+
+ This project now features an enhanced vector retrieval system that improves the quality and relevance of information retrieved from Pinecone using threshold-based filtering and multiple similarity metrics.
+
+ ## Features
+
+ ### 1. Threshold-Based Retrieval
+
+ The system implements a threshold-based approach to vector retrieval, which:
+ - Retrieves a larger candidate set from the vector database
+ - Applies a similarity threshold to filter out less relevant results
+ - Returns only the most relevant documents that exceed the threshold
+
+ ### 2. Multiple Similarity Metrics
+
+ The system supports multiple similarity metrics:
+ - **Cosine Similarity** (default): Measures the cosine of the angle between vectors
+ - **Dot Product**: Calculates the dot product between vectors
+ - **Euclidean Distance**: Measures the straight-line distance between vectors
+
+ Each metric has different characteristics and may perform better for different types of data and queries.
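+
+ For a quick sense of how the three metrics relate, here is a standalone sketch (illustrative only, not code from this project):
+
+ ```python
+ import numpy as np
+
+ a = np.array([1.0, 2.0, 3.0])
+ b = np.array([2.0, 3.0, 4.0])
+
+ dot = float(np.dot(a, b))                               # dot product: unbounded, scale-sensitive
+ cosine = dot / (np.linalg.norm(a) * np.linalg.norm(b))  # cosine similarity: in [-1, 1]
+ euclidean = float(np.linalg.norm(a - b))                # euclidean distance: lower means closer
+ ```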
+
+ ### 3. Score Normalization
+
+ For metrics like Euclidean distance where lower values indicate higher similarity, the system automatically normalizes scores to a 0-1 scale where higher values always indicate higher similarity. This makes it easier to compare results across different metrics.
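+
+ A minimal sketch of such a mapping, assuming a simple `1 / (1 + distance)` transform (the project's exact formula may differ):
+
+ ```python
+ def normalize_score(score: float, metric: str) -> float:
+     """Map a raw score to a 0-1 scale where higher always means more similar."""
+     if metric == "euclidean":
+         # Distance 0 maps to 1.0; larger distances approach 0.0
+         return 1.0 / (1.0 + score)
+     # Cosine and dot-product scores are already "higher is better"
+     return score
+ ```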
+
+ ## Configuration
+
+ The retrieval system can be configured through environment variables:
+
+ ```
+ # Pinecone retrieval configuration
+ PINECONE_DEFAULT_LIMIT_K=10                 # Maximum number of candidates to retrieve
+ PINECONE_DEFAULT_TOP_K=6                    # Number of results to return after filtering
+ PINECONE_DEFAULT_SIMILARITY_METRIC=cosine   # Default similarity metric
+ PINECONE_DEFAULT_SIMILARITY_THRESHOLD=0.75  # Similarity threshold (0-1)
+ PINECONE_ALLOWED_METRICS=cosine,dotproduct,euclidean  # Available metrics
+ ```
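+
+ Conceptually, these parameters fit together as in the sketch below (hypothetical helper and client calls, not the project's actual code):
+
+ ```python
+ def retrieve(query_vector, index, limit_k=10, top_k=6, threshold=0.75, metric="cosine"):
+     # 1. Retrieve a larger candidate set (limit_k) from the vector database
+     candidates = index.query(vector=query_vector, top_k=limit_k)  # hypothetical client call
+     # 2. Normalize scores so that higher always means more similar
+     scored = [(normalize_score(c.score, metric), c) for c in candidates]
+     # 3. Drop candidates below the similarity threshold
+     kept = [(s, c) for s, c in scored if s >= threshold]
+     # 4. Return the best top_k of the surviving candidates
+     kept.sort(key=lambda sc: sc[0], reverse=True)
+     return [c for _, c in kept[:top_k]]
+ ```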
+
+ ## API Usage
+
+ You can customize the retrieval parameters when making API requests:
+
+ ```json
+ {
+   "user_id": "user123",
+   "question": "What are the best restaurants in Da Nang?",
+   "similarity_top_k": 5,
+   "limit_k": 15,
+   "similarity_metric": "cosine",
+   "similarity_threshold": 0.8
+ }
+ ```
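+
+ For example, the request above can be sent with a small Python client (a sketch assuming the service runs locally on port 7860):
+
+ ```python
+ import requests
+
+ body = {
+     "user_id": "user123",
+     "question": "What are the best restaurants in Da Nang?",
+     "similarity_top_k": 5,
+     "limit_k": 15,
+     "similarity_metric": "cosine",
+     "similarity_threshold": 0.8
+ }
+ response = requests.post("http://localhost:7860/rag/chat", json=body)
+ print(response.json())
+ ```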
+
+ ## Benefits
+
+ 1. **Quality Improvement**: Retrieves only the most relevant documents above a certain quality threshold
+ 2. **Flexibility**: Different similarity metrics can be used for different types of queries
+ 3. **Efficiency**: Avoids processing irrelevant documents, improving response time
+ 4. **Configurability**: All parameters can be adjusted via environment variables or at request time
+
+ ## Implementation Details
+
+ The system is implemented as a custom retriever class `ThresholdRetriever` that integrates with LangChain's retrieval infrastructure while providing enhanced functionality.
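+
+ For orientation, a schematic of what such a retriever can look like, written against the `langchain-core` 0.1.x base class pinned in the Dockerfile (a sketch; the project's actual implementation may differ):
+
+ ```python
+ from typing import Any, List
+
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
+ from langchain_core.documents import Document
+ from langchain_core.retrievers import BaseRetriever
+
+ class ThresholdRetriever(BaseRetriever):
+     vectorstore: Any                    # the Pinecone-backed vector store
+     limit_k: int = 10                   # size of the candidate pool
+     top_k: int = 6                      # results returned after filtering
+     similarity_threshold: float = 0.75
+
+     def _get_relevant_documents(
+         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+     ) -> List[Document]:
+         # Over-fetch, then keep only documents whose score clears the threshold
+         docs_and_scores = self.vectorstore.similarity_search_with_score(query, k=self.limit_k)
+         kept = [doc for doc, score in docs_and_scores if score >= self.similarity_threshold]
+         return kept[:self.top_k]
+ ```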
api_documentation.txt ADDED
@@ -0,0 +1,318 @@
+ # Frontend Integration Guide for PixAgent API
+
+ This guide provides instructions for integrating with the optimized PostgreSQL-based API endpoints for Event, FAQ, and Emergency data.
+
+ ## API Endpoints
+
+ ### Events
+
+ | Endpoint | Method | Description |
+ |----------|--------|-------------|
+ | /postgres/events/ | GET | Fetch all events (with optional filtering) |
+ | /postgres/events/{event_id} | GET | Fetch a specific event by ID |
+ | /postgres/events/featured | GET | Fetch featured events |
+ | /postgres/events/ | POST | Create a new event |
+ | /postgres/events/{event_id} | PUT | Update an existing event |
+ | /postgres/events/{event_id} | DELETE | Delete an event |
+
+ ### FAQs
+
+ | Endpoint | Method | Description |
+ |----------|--------|-------------|
+ | /postgres/faqs/ | GET | Fetch all FAQs |
+ | /postgres/faqs/{faq_id} | GET | Fetch a specific FAQ by ID |
+ | /postgres/faqs/ | POST | Create a new FAQ |
+ | /postgres/faqs/{faq_id} | PUT | Update an existing FAQ |
+ | /postgres/faqs/{faq_id} | DELETE | Delete a FAQ |
+
+ ### Emergency Contacts
+
+ | Endpoint | Method | Description |
+ |----------|--------|-------------|
+ | /postgres/emergencies/ | GET | Fetch all emergency contacts |
+ | /postgres/emergencies/{emergency_id} | GET | Fetch a specific emergency contact by ID |
+ | /postgres/emergencies/ | POST | Create a new emergency contact |
+ | /postgres/emergencies/{emergency_id} | PUT | Update an existing emergency contact |
+ | /postgres/emergencies/{emergency_id} | DELETE | Delete an emergency contact |
+
+ ## Response Models
+
+ ### Event Response Model
+
+ interface EventResponse {
+     id: number;
+     name: string;
+     description: string;
+     date_start: string;  // ISO format date
+     date_end: string;    // ISO format date
+     location: string;
+     image_url: string;
+     price: {
+         currency: string;
+         amount: string;
+     };
+     featured: boolean;
+     is_active: boolean;
+     created_at: string;  // ISO format date
+     updated_at: string;  // ISO format date
+ }
+
+ ### FAQ Response Model
+
+ interface FaqResponse {
+     id: number;
+     question: string;
+     answer: string;
+     is_active: boolean;
+     created_at: string;  // ISO format date
+     updated_at: string;  // ISO format date
+ }
+
+ ### Emergency Response Model
+
+ interface EmergencyResponse {
+     id: number;
+     name: string;
+     phone_number: string;
+     description: string;
+     address: string;
+     priority: number;
+     is_active: boolean;
+     created_at: string;  // ISO format date
+     updated_at: string;  // ISO format date
+ }
+
+ ## Example Usage (React)
+
+ ### Fetching Events
+
+ import { useState, useEffect } from 'react';
+ import axios from 'axios';
+
+ const API_BASE_URL = 'http://localhost:8000';
+
+ function EventList() {
+     const [events, setEvents] = useState([]);
+     const [loading, setLoading] = useState(true);
+     const [error, setError] = useState(null);
+
+     useEffect(() => {
+         const fetchEvents = async () => {
+             try {
+                 setLoading(true);
+                 const response = await axios.get(`${API_BASE_URL}/postgres/events/`);
+                 setEvents(response.data);
+                 setLoading(false);
+             } catch (err) {
+                 setError('Failed to fetch events');
+                 setLoading(false);
+                 console.error('Error fetching events:', err);
+             }
+         };
+
+         fetchEvents();
+     }, []);
+
+     if (loading) return <p>Loading events...</p>;
+     if (error) return <p>{error}</p>;
+
+     return (
+         <div>
+             <h1>Events</h1>
+             <div className="event-list">
+                 {events.map(event => (
+                     <div key={event.id} className="event-card">
+                         <h2>{event.name}</h2>
+                         <p>{event.description}</p>
+                         <p>
+                             <strong>When:</strong> {new Date(event.date_start).toLocaleDateString()} - {new Date(event.date_end).toLocaleDateString()}
+                         </p>
+                         <p><strong>Where:</strong> {event.location}</p>
+                         <p><strong>Price:</strong> {event.price.amount} {event.price.currency}</p>
+                         {event.featured && <span className="featured-badge">Featured</span>}
+                     </div>
+                 ))}
+             </div>
+         </div>
+     );
+ }
+
+ ### Creating an Event
+
+ import { useState } from 'react';
+ import axios from 'axios';
+
+ function CreateEvent() {
+     const [eventData, setEventData] = useState({
+         name: '',
+         description: '',
+         date_start: '',
+         date_end: '',
+         location: '',
+         image_url: '',
+         price: {
+             currency: 'USD',
+             amount: '0'
+         },
+         featured: false,
+         is_active: true
+     });
+     const [loading, setLoading] = useState(false);
+     const [error, setError] = useState(null);
+     const [success, setSuccess] = useState(false);
+
+     const handleChange = (e) => {
+         const { name, value, type, checked } = e.target;
+
+         if (name === 'price_amount') {
+             setEventData(prev => ({
+                 ...prev,
+                 price: {
+                     ...prev.price,
+                     amount: value
+                 }
+             }));
+         } else if (name === 'price_currency') {
+             setEventData(prev => ({
+                 ...prev,
+                 price: {
+                     ...prev.price,
+                     currency: value
+                 }
+             }));
+         } else {
+             setEventData(prev => ({
+                 ...prev,
+                 [name]: type === 'checkbox' ? checked : value
+             }));
+         }
+     };
+
+     const handleSubmit = async (e) => {
+         e.preventDefault();
+         try {
+             setLoading(true);
+             setError(null);
+             setSuccess(false);
+
+             const response = await axios.post(`${API_BASE_URL}/postgres/events/`, eventData);
+             setSuccess(true);
+             setEventData({
+                 name: '',
+                 description: '',
+                 date_start: '',
+                 date_end: '',
+                 location: '',
+                 image_url: '',
+                 price: {
+                     currency: 'USD',
+                     amount: '0'
+                 },
+                 featured: false,
+                 is_active: true
+             });
+             setLoading(false);
+         } catch (err) {
+             setError('Failed to create event');
+             setLoading(false);
+             console.error('Error creating event:', err);
+         }
+     };
+
+     return (
+         <div>
+             <h1>Create New Event</h1>
+             {success && <div className="success-message">Event created successfully!</div>}
+             {error && <div className="error-message">{error}</div>}
+             <form onSubmit={handleSubmit}>
+                 {/* Form fields would go here */}
+                 <button type="submit" disabled={loading}>
+                     {loading ? 'Creating...' : 'Create Event'}
+                 </button>
+             </form>
+         </div>
+     );
+ }
+
+ ## Performance Optimizations
+
+ The API now includes several performance optimizations:
+
+ ### Caching
+
+ The server implements caching for read operations, which significantly improves response times for repeated requests. The average cache improvement is over 70%.
+
+ Frontend considerations:
+ - No need to implement client-side caching for data that doesn't change frequently
+ - For real-time data, consider adding a refresh button in the UI
+ - If data might be updated by other users, consider adding a polling mechanism or websocket for updates
+
+ ### Error Handling
+
+ The API returns standardized error responses. Example:
+
+ async function fetchData(url) {
+     try {
+         const response = await fetch(url);
+         if (!response.ok) {
+             const errorData = await response.json();
+             throw new Error(errorData.detail || 'An error occurred');
+         }
+         return await response.json();
+     } catch (error) {
+         console.error('API request failed:', error);
+         // Handle error in UI
+         return null;
+     }
+ }
+
+ ### Price Field Handling
+
+ The price field of events is a JSON object with currency and amount properties. When creating or updating events, ensure this is properly formatted:
+
+ // Correct format for price field
+ const eventData = {
+     // other fields...
+     price: {
+         currency: 'USD',
+         amount: '10.99'
+     }
+ };
+
+ // When displaying price
+ function formatPrice(price) {
+     if (!price) return 'Free';
+     if (typeof price === 'string') {
+         try {
+             price = JSON.parse(price);
+         } catch {
+             return price;
+         }
+     }
+     return `${price.amount} ${price.currency}`;
+ }
+
+ ## CORS Configuration
+
+ The API has CORS enabled for frontend applications. If you're experiencing CORS issues, ensure your frontend domain is allowed in the server configuration.
+
+ For local development, the following origins are typically allowed:
+ - http://localhost:3000
+ - http://localhost:5000
+ - http://localhost:8080
+
+ ## Status Codes
+
+ | Status Code | Description |
+ |-------------|-------------|
+ | 200 | Success - The request was successful |
+ | 201 | Created - A new resource was successfully created |
+ | 400 | Bad Request - The request could not be understood or was missing required parameters |
+ | 404 | Not Found - Resource not found |
+ | 422 | Validation Error - Request data failed validation |
+ | 500 | Internal Server Error - An error occurred on the server |
+
+ ## Questions?
+
+ For further inquiries about the API, please contact the development team.
app.py ADDED
@@ -0,0 +1,197 @@
+ from fastapi import FastAPI, Depends, Request, HTTPException, status
+ from fastapi.middleware.cors import CORSMiddleware
+ from contextlib import asynccontextmanager
+ import uvicorn
+ import os
+ import sys
+ import logging
+ from dotenv import load_dotenv
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.StreamHandler(sys.stdout),
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ load_dotenv()
+ DEBUG = os.getenv("DEBUG", "False").lower() in ("true", "1", "t")
+
+ # Check the required environment variables
+ required_env_vars = [
+     "AIVEN_DB_URL",
+     "MONGODB_URL",
+     "PINECONE_API_KEY",
+     "PINECONE_INDEX_NAME",
+     "GOOGLE_API_KEY"
+ ]
+
+ missing_vars = [var for var in required_env_vars if not os.getenv(var)]
+ if missing_vars:
+     logger.error(f"Missing required environment variables: {', '.join(missing_vars)}")
+     if not DEBUG:  # Only exit when not in debug mode
+         sys.exit(1)
+
+ # Database health checks
+ def check_database_connections():
+     """Check the database connections at startup"""
+     from app.database.postgresql import check_db_connection as check_postgresql
+     from app.database.mongodb import check_db_connection as check_mongodb
+     from app.database.pinecone import check_db_connection as check_pinecone
+
+     db_status = {
+         "postgresql": check_postgresql(),
+         "mongodb": check_mongodb(),
+         "pinecone": check_pinecone()
+     }
+
+     all_ok = all(db_status.values())
+     if not all_ok:
+         failed_dbs = [name for name, status in db_status.items() if not status]
+         logger.error(f"Failed to connect to databases: {', '.join(failed_dbs)}")
+         if not DEBUG:  # Only exit when not in debug mode
+             sys.exit(1)
+
+     return db_status
+
+ # Lifespan handler that checks the database connections at startup
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Startup: check the database connections
+     logger.info("Starting application...")
+     db_status = check_database_connections()
+     if all(db_status.values()):
+         logger.info("All database connections are working")
+
+     # Create the database tables (if they do not exist yet)
+     if DEBUG:  # Only create tables in debug mode
+         from app.database.postgresql import create_tables
+         if create_tables():
+             logger.info("Database tables created or already exist")
+
+     yield
+
+     # Shutdown
+     logger.info("Shutting down application...")
+
+ # Import routers
+ try:
+     from app.api.mongodb_routes import router as mongodb_router
+     from app.api.postgresql_routes import router as postgresql_router
+     from app.api.rag_routes import router as rag_router
+     from app.api.websocket_routes import router as websocket_router
+
+     # Import middlewares
+     from app.utils.middleware import RequestLoggingMiddleware, ErrorHandlingMiddleware, DatabaseCheckMiddleware
+
+     # Import debug utilities
+     from app.utils.debug_utils import debug_view, DebugInfo, error_tracker, performance_monitor
+
+ except ImportError as e:
+     logger.error(f"Error importing routes or middlewares: {e}")
+     raise
+
+ # Create FastAPI app
+ app = FastAPI(
+     title="PIX Project Backend API",
+     description="Backend API for PIX Project with MongoDB, PostgreSQL and RAG integration",
+     version="1.0.0",
+     docs_url="/docs",
+     redoc_url="/redoc",
+     debug=DEBUG,
+     lifespan=lifespan,
+ )
+
+ # Configure CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Add middlewares
+ app.add_middleware(ErrorHandlingMiddleware)
+ app.add_middleware(RequestLoggingMiddleware)
+ if not DEBUG:  # Only add the database check middleware in production
+     app.add_middleware(DatabaseCheckMiddleware)
+
+ # Include routers
+ app.include_router(mongodb_router)
+ app.include_router(postgresql_router)
+ app.include_router(rag_router)
+ app.include_router(websocket_router)
+
+ # Root endpoint
+ @app.get("/")
+ def read_root():
+     return {
+         "message": "Welcome to PIX Project Backend API",
+         "documentation": "/docs",
+     }
+
+ # Health check endpoint
+ @app.get("/health")
+ def health_check():
+     # Check the database connections
+     db_status = check_database_connections()
+     all_db_ok = all(db_status.values())
+
+     return {
+         "status": "healthy" if all_db_ok else "degraded",
+         "version": "1.0.0",
+         "environment": os.environ.get("ENVIRONMENT", "production"),
+         "databases": db_status
+     }
+
+ # Debug endpoints (only available in debug mode)
+ if DEBUG:
+     @app.get("/debug/config")
+     def debug_config():
+         """Show configuration information (debug mode only)"""
+         config = {
+             "environment": os.environ.get("ENVIRONMENT", "production"),
+             "debug": DEBUG,
+             "db_connection_mode": os.environ.get("DB_CONNECTION_MODE", "aiven"),
+             "databases": {
+                 "postgresql": os.environ.get("AIVEN_DB_URL", "").split("@")[1].split("/")[0] if "@" in os.environ.get("AIVEN_DB_URL", "") else "N/A",
+                 "mongodb": os.environ.get("MONGODB_URL", "").split("@")[1].split("/?")[0] if "@" in os.environ.get("MONGODB_URL", "") else "N/A",
+                 "pinecone": os.environ.get("PINECONE_INDEX_NAME", "N/A"),
+             }
+         }
+         return config
+
+     @app.get("/debug/system")
+     def debug_system():
+         """Show system information (debug mode only)"""
+         return DebugInfo.get_system_info()
+
+     @app.get("/debug/database")
+     def debug_database():
+         """Show database status (debug mode only)"""
+         return DebugInfo.get_database_status()
+
+     @app.get("/debug/errors")
+     def debug_errors(limit: int = 10):
+         """Show recent errors (debug mode only)"""
+         return error_tracker.get_errors(limit=limit)
+
+     @app.get("/debug/performance")
+     def debug_performance():
+         """Show performance information (debug mode only)"""
+         return performance_monitor.get_report()
+
+     @app.get("/debug/full")
+     def debug_full_report(request: Request):
+         """Show a full debug report (debug mode only)"""
+         return debug_view(request)
+
+ # Run the app with uvicorn when executed directly
+ if __name__ == "__main__":
+     port = int(os.environ.get("PORT", 8000))
+     uvicorn.run("app:app", host="0.0.0.0", port=port, reload=DEBUG)
app/__init__.py ADDED
@@ -0,0 +1,23 @@
+ # PIX Project Backend
+ # Version: 1.0.0
+
+ __version__ = "1.0.0"
+
+ # Import app from app.py so that tests can find it
+ import sys
+ import os
+
+ # Add the project root directory to sys.path
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ try:
+     from app.py import app
+ except ImportError:
+     # Fall back to loading app.py by file path (the direct import above normally fails)
+     import importlib.util
+     spec = importlib.util.spec_from_file_location("app_module",
+         os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+         "app.py"))
+     app_module = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(app_module)
+     app = app_module.app
app/api/__init__.py ADDED
@@ -0,0 +1 @@
+ # API routes package
app/api/mongodb_routes.py ADDED
@@ -0,0 +1,276 @@
+ from fastapi import APIRouter, HTTPException, Depends, Query, status, Response
+ from typing import List, Optional, Dict
+ from pymongo.errors import PyMongoError
+ import logging
+ from datetime import datetime
+ import traceback
+ import asyncio
+
+ from app.database.mongodb import (
+     save_session,
+     get_user_history,
+     update_session_response,
+     check_db_connection,
+     session_collection
+ )
+ from app.models.mongodb_models import (
+     SessionCreate,
+     SessionResponse,
+     HistoryRequest,
+     HistoryResponse,
+     QuestionAnswer
+ )
+ from app.api.websocket_routes import send_notification
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Create router
+ router = APIRouter(
+     prefix="/mongodb",
+     tags=["MongoDB"],
+ )
+
+ @router.post("/session", response_model=SessionResponse, status_code=status.HTTP_201_CREATED)
+ async def create_session(session: SessionCreate, response: Response):
+     """
+     Create a new session record in MongoDB.
+
+     - **session_id**: Unique identifier for the session (auto-generated if not provided)
+     - **factor**: Factor type (user, rag, etc.)
+     - **action**: Action type (start, events, faq, emergency, help, asking_freely, etc.)
+     - **first_name**: User's first name
+     - **last_name**: User's last name (optional)
+     - **message**: User's message (optional)
+     - **user_id**: User's ID from Telegram
+     - **username**: User's username (optional)
+     - **response**: Response from RAG (optional)
+     """
+     try:
+         # Check the MongoDB connection
+         if not check_db_connection():
+             logger.error("MongoDB connection failed when trying to create session")
+             raise HTTPException(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 detail="MongoDB connection failed"
+             )
+
+         # Create new session in MongoDB
+         result = save_session(
+             session_id=session.session_id,
+             factor=session.factor,
+             action=session.action,
+             first_name=session.first_name,
+             last_name=session.last_name,
+             message=session.message,
+             user_id=session.user_id,
+             username=session.username,
+             response=session.response
+         )
+
+         # Prepare the response object
+         session_response = SessionResponse(
+             **session.model_dump(),
+             created_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         )
+
+         # Check whether this session should trigger a notification (response starts with "I don't know")
+         if session.response and session.response.strip().lower().startswith("i don't know"):
+             # Send a notification via WebSocket
+             try:
+                 notification_data = {
+                     "session_id": session.session_id,
+                     "factor": session.factor,
+                     "action": session.action,
+                     "message": session.message,
+                     "user_id": session.user_id,
+                     "username": session.username,
+                     "first_name": session.first_name,
+                     "last_name": session.last_name,
+                     "response": session.response,
+                     "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                 }
+
+                 # Schedule the notification with asyncio.create_task so the main flow is not blocked
+                 asyncio.create_task(send_notification(notification_data))
+                 logger.info(f"Notification queued for session {session.session_id} - response starts with 'I don't know'")
+             except Exception as e:
+                 logger.error(f"Error queueing notification: {e}")
+                 # Do not abort the main flow when sending the notification fails
+
+         # Return response
+         return session_response
+     except PyMongoError as e:
+         logger.error(f"MongoDB error creating session: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"MongoDB error: {str(e)}"
+         )
+     except Exception as e:
+         logger.error(f"Unexpected error creating session: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to create session: {str(e)}"
+         )
+
+ @router.put("/session/{session_id}/response", status_code=status.HTTP_200_OK)
+ async def update_session_with_response(session_id: str, response_text: str):
+     """
+     Update a session with the response.
+
+     - **session_id**: ID of the session to update
+     - **response_text**: Response to add to the session
+     """
+     try:
+         # Check the MongoDB connection
+         if not check_db_connection():
+             logger.error("MongoDB connection failed when trying to update session response")
+             raise HTTPException(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 detail="MongoDB connection failed"
+             )
+
+         # Update session in MongoDB
+         result = update_session_response(session_id, response_text)
+
+         if not result:
+             raise HTTPException(
+                 status_code=status.HTTP_404_NOT_FOUND,
+                 detail=f"Session with ID {session_id} not found"
+             )
+
+         return {"status": "success", "message": "Response added to session"}
+     except PyMongoError as e:
+         logger.error(f"MongoDB error updating session response: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"MongoDB error: {str(e)}"
+         )
+     except HTTPException:
+         # Re-throw HTTP exceptions
+         raise
+     except Exception as e:
+         logger.error(f"Unexpected error updating session response: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to update session: {str(e)}"
+         )
+
+ @router.get("/history", response_model=HistoryResponse)
+ async def get_history(user_id: str, n: int = Query(3, ge=1, le=10)):
+     """
+     Get user history for a specific user.
+
+     - **user_id**: User's ID from Telegram
+     - **n**: Number of most recent interactions to return (default: 3, min: 1, max: 10)
+     """
+     try:
+         # Check the MongoDB connection
+         if not check_db_connection():
+             logger.error("MongoDB connection failed when trying to get user history")
+             raise HTTPException(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 detail="MongoDB connection failed"
+             )
+
+         # Get user history from MongoDB
+         history_data = get_user_history(user_id=user_id, n=n)
+
+         # Convert to response model
+         return HistoryResponse(history=history_data)
+     except PyMongoError as e:
+         logger.error(f"MongoDB error getting user history: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"MongoDB error: {str(e)}"
+         )
+     except Exception as e:
+         logger.error(f"Unexpected error getting user history: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to get user history: {str(e)}"
+         )
+
+ @router.get("/health")
+ async def health_check():
+     """
+     Check health of MongoDB connection.
+     """
+     try:
+         # Check the MongoDB connection
+         is_connected = check_db_connection()
+
+         if not is_connected:
+             return {
+                 "status": "unhealthy",
+                 "message": "MongoDB connection failed",
+                 "timestamp": datetime.now().isoformat()
+             }
+
+         return {
+             "status": "healthy",
+             "message": "MongoDB connection is working",
+             "timestamp": datetime.now().isoformat()
+         }
+     except Exception as e:
+         logger.error(f"MongoDB health check failed: {e}")
+         logger.error(traceback.format_exc())
+         return {
+             "status": "error",
+             "message": f"MongoDB health check error: {str(e)}",
+             "timestamp": datetime.now().isoformat()
+         }
+
+ @router.get("/session/{session_id}")
+ async def get_session(session_id: str):
+     """
+     Get session information from MongoDB by session_id.
+
+     - **session_id**: ID of the session to fetch
+     """
+     try:
+         # Check the MongoDB connection
+         if not check_db_connection():
+             logger.error("MongoDB connection failed when trying to get session")
+             raise HTTPException(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 detail="MongoDB connection failed"
+             )
+
+         # Fetch the record from MongoDB
+         session_data = session_collection.find_one({"session_id": session_id})
+
+         if not session_data:
+             raise HTTPException(
+                 status_code=status.HTTP_404_NOT_FOUND,
+                 detail=f"Session with ID {session_id} not found"
+             )
+
+         # Convert _id to a string so the document is JSON serializable
+         if "_id" in session_data:
+             session_data["_id"] = str(session_data["_id"])
+
+         return session_data
+     except PyMongoError as e:
+         logger.error(f"MongoDB error getting session: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"MongoDB error: {str(e)}"
+         )
+     except HTTPException:
+         # Re-throw HTTP exceptions
+         raise
+     except Exception as e:
+         logger.error(f"Unexpected error getting session: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to get session: {str(e)}"
+         )
app/api/postgresql_routes.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, Depends, Query, Path, Body
2
+ from sqlalchemy.orm import Session
3
+ from sqlalchemy.exc import SQLAlchemyError
4
+ from typing import List, Optional, Dict, Any
5
+ import logging
6
+ import traceback
7
+ from datetime import datetime
8
+ from sqlalchemy import text, inspect
9
+
10
+ from app.database.postgresql import get_db
11
+ from app.database.models import FAQItem, EmergencyItem, EventItem
12
+ from pydantic import BaseModel, Field, ConfigDict
13
+
14
+ # Configure logging
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Create router
18
+ router = APIRouter(
19
+ prefix="/postgres",
20
+ tags=["PostgreSQL"],
21
+ )
22
+
23
+ # --- Pydantic models for request/response ---
24
+
25
+ # FAQ models
26
+ class FAQBase(BaseModel):
27
+ question: str
28
+ answer: str
29
+ is_active: bool = True
30
+
31
+ class FAQCreate(FAQBase):
32
+ pass
33
+
34
+ class FAQUpdate(BaseModel):
35
+ question: Optional[str] = None
36
+ answer: Optional[str] = None
37
+ is_active: Optional[bool] = None
38
+
39
+ class FAQResponse(FAQBase):
40
+ id: int
41
+ created_at: datetime
42
+ updated_at: datetime
43
+
44
+ # Sử dụng ConfigDict thay vì class Config cho Pydantic V2
45
+ model_config = ConfigDict(from_attributes=True)
46
+
47
+ # Emergency contact models
48
+ class EmergencyBase(BaseModel):
49
+ name: str
50
+ phone_number: str
51
+ description: Optional[str] = None
52
+ address: Optional[str] = None
53
+ location: Optional[str] = None
54
+ priority: int = 0
55
+ is_active: bool = True
56
+
57
+ class EmergencyCreate(EmergencyBase):
58
+ pass
59
+
60
+ class EmergencyUpdate(BaseModel):
61
+ name: Optional[str] = None
62
+ phone_number: Optional[str] = None
63
+ description: Optional[str] = None
64
+ address: Optional[str] = None
65
+ location: Optional[str] = None
66
+ priority: Optional[int] = None
67
+ is_active: Optional[bool] = None
68
+
69
+ class EmergencyResponse(EmergencyBase):
70
+ id: int
71
+ created_at: datetime
72
+ updated_at: datetime
73
+
74
+ # Sử dụng ConfigDict thay vì class Config cho Pydantic V2
75
+ model_config = ConfigDict(from_attributes=True)
76
+
77
+ # Event models
78
+ class EventBase(BaseModel):
79
+ name: str
80
+ description: str
81
+ address: str
82
+ location: Optional[str] = None
83
+ date_start: datetime
84
+ date_end: Optional[datetime] = None
85
+ price: Optional[List[dict]] = None
86
+ is_active: bool = True
87
+ featured: bool = False
88
+
89
+ class EventCreate(EventBase):
90
+ pass
91
+
92
+ class EventUpdate(BaseModel):
93
+ name: Optional[str] = None
94
+ description: Optional[str] = None
95
+ address: Optional[str] = None
96
+ location: Optional[str] = None
97
+ date_start: Optional[datetime] = None
98
+ date_end: Optional[datetime] = None
99
+ price: Optional[List[dict]] = None
100
+ is_active: Optional[bool] = None
101
+ featured: Optional[bool] = None
102
+
103
+ class EventResponse(EventBase):
104
+ id: int
105
+ created_at: datetime
106
+ updated_at: datetime
107
+
108
+ # Sử dụng ConfigDict thay vì class Config cho Pydantic V2
109
+ model_config = ConfigDict(from_attributes=True)
110
+
111
+ # --- FAQ endpoints ---
112
+
113
+ @router.get("/faq", response_model=List[FAQResponse])
114
+ async def get_faqs(
115
+ skip: int = 0,
116
+ limit: int = 100,
117
+ active_only: bool = False,
118
+ db: Session = Depends(get_db)
119
+ ):
120
+ """
121
+ Get all FAQ items.
122
+
123
+ - **skip**: Number of items to skip
124
+ - **limit**: Maximum number of items to return
125
+ - **active_only**: If true, only return active items
126
+ """
127
+ try:
128
+ # Log detailed connection info
129
+ logger.info(f"Attempting to fetch FAQs with skip={skip}, limit={limit}, active_only={active_only}")
130
+
131
+ # Check if the FAQItem table exists
132
+ inspector = inspect(db.bind)
133
+ if not inspector.has_table("faq_item"):
134
+ logger.error("The faq_item table does not exist in the database")
135
+ raise HTTPException(status_code=500, detail="Table 'faq_item' does not exist")
136
+
137
+ # Log table columns
138
+ columns = inspector.get_columns("faq_item")
139
+ logger.info(f"faq_item table columns: {[c['name'] for c in columns]}")
140
+
141
+ # Query the FAQs with detailed logging
142
+ query = db.query(FAQItem)
143
+ if active_only:
144
+ query = query.filter(FAQItem.is_active == True)
145
+
146
+ # Try direct SQL to debug
147
+ try:
148
+ test_result = db.execute(text("SELECT COUNT(*) FROM faq_item")).scalar()
149
+ logger.info(f"SQL test query succeeded, found {test_result} FAQ items")
150
+ except Exception as sql_error:
151
+ logger.error(f"SQL test query failed: {sql_error}")
152
+
153
+ # Execute the ORM query
154
+ faqs = query.offset(skip).limit(limit).all()
155
+ logger.info(f"Successfully fetched {len(faqs)} FAQ items")
156
+
157
+ # Check what we're returning
158
+ for i, faq in enumerate(faqs[:3]): # Log the first 3 items
159
+ logger.info(f"FAQ item {i+1}: id={faq.id}, question={faq.question[:30]}...")
160
+
161
+ # Convert SQLAlchemy models to Pydantic models - updated for Pydantic v2
162
+ result = [FAQResponse.model_validate(faq, from_attributes=True) for faq in faqs]
163
+ return result
164
+ except SQLAlchemyError as e:
165
+ logger.error(f"Database error in get_faqs: {e}")
166
+ logger.error(traceback.format_exc())
167
+ raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
168
+ except Exception as e:
169
+ logger.error(f"Unexpected error in get_faqs: {e}")
170
+ logger.error(traceback.format_exc())
171
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
172
+
173
+ @router.post("/faq", response_model=FAQResponse)
174
+ async def create_faq(
175
+ faq: FAQCreate,
176
+ db: Session = Depends(get_db)
177
+ ):
178
+ """
179
+ Create a new FAQ item.
180
+
181
+ - **question**: Question text
182
+ - **answer**: Answer text
183
+ - **is_active**: Whether the FAQ is active (default: True)
184
+ """
185
+ try:
186
+ # Sử dụng model_dump thay vì dict
187
+ db_faq = FAQItem(**faq.model_dump())
188
+ db.add(db_faq)
189
+ db.commit()
190
+ db.refresh(db_faq)
191
+ return FAQResponse.model_validate(db_faq, from_attributes=True)
192
+ except SQLAlchemyError as e:
193
+ db.rollback()
194
+ logger.error(f"Database error: {e}")
195
+ raise HTTPException(status_code=500, detail="Failed to create FAQ item")
196
+
197
+ @router.get("/faq/{faq_id}", response_model=FAQResponse)
198
+ async def get_faq(
199
+ faq_id: int = Path(..., gt=0),
200
+ db: Session = Depends(get_db)
201
+ ):
202
+ """
203
+ Get a specific FAQ item by ID.
204
+
205
+ - **faq_id**: ID of the FAQ item
206
+ """
207
+ try:
208
+ faq = db.query(FAQItem).filter(FAQItem.id == faq_id).first()
209
+ if not faq:
210
+ raise HTTPException(status_code=404, detail="FAQ item not found")
211
+ return FAQResponse.model_validate(faq, from_attributes=True)
212
+ except SQLAlchemyError as e:
213
+ logger.error(f"Database error: {e}")
214
+ raise HTTPException(status_code=500, detail="Database error")
215
+
216
+ @router.put("/faq/{faq_id}", response_model=FAQResponse)
217
+ async def update_faq(
218
+ faq_id: int = Path(..., gt=0),
219
+ faq_update: FAQUpdate = Body(...),
220
+ db: Session = Depends(get_db)
221
+ ):
222
+ """
223
+ Update a specific FAQ item.
224
+
225
+ - **faq_id**: ID of the FAQ item to update
226
+ - **question**: New question text (optional)
227
+ - **answer**: New answer text (optional)
228
+ - **is_active**: New active status (optional)
229
+ """
230
+ try:
231
+ faq = db.query(FAQItem).filter(FAQItem.id == faq_id).first()
232
+ if not faq:
233
+ raise HTTPException(status_code=404, detail="FAQ item not found")
234
+
235
+ # Use model_dump instead of dict (Pydantic v2)
236
+ update_data = faq_update.model_dump(exclude_unset=True)
237
+ for key, value in update_data.items():
238
+ setattr(faq, key, value)
239
+
240
+ db.commit()
241
+ db.refresh(faq)
242
+ return FAQResponse.model_validate(faq, from_attributes=True)
243
+ except SQLAlchemyError as e:
244
+ db.rollback()
245
+ logger.error(f"Database error: {e}")
246
+ raise HTTPException(status_code=500, detail="Failed to update FAQ item")
247
+
248
+ @router.delete("/faq/{faq_id}", response_model=dict)
249
+ async def delete_faq(
250
+ faq_id: int = Path(..., gt=0),
251
+ db: Session = Depends(get_db)
252
+ ):
253
+ """
254
+ Delete a specific FAQ item.
255
+
256
+ - **faq_id**: ID of the FAQ item to delete
257
+ """
258
+ try:
259
+ faq = db.query(FAQItem).filter(FAQItem.id == faq_id).first()
260
+ if not faq:
261
+ raise HTTPException(status_code=404, detail="FAQ item not found")
262
+
263
+ db.delete(faq)
264
+ db.commit()
265
+ return {"status": "success", "message": f"FAQ item {faq_id} deleted"}
266
+ except SQLAlchemyError as e:
267
+ db.rollback()
268
+ logger.error(f"Database error: {e}")
269
+ raise HTTPException(status_code=500, detail="Failed to delete FAQ item")
270
+
271
+ # --- Emergency endpoints ---
272
+
273
+ @router.get("/emergency", response_model=List[EmergencyResponse])
274
+ async def get_emergency_contacts(
275
+ skip: int = 0,
276
+ limit: int = 100,
277
+ active_only: bool = False,
278
+ db: Session = Depends(get_db)
279
+ ):
280
+ """
281
+ Get all emergency contacts.
282
+
283
+ - **skip**: Number of items to skip
284
+ - **limit**: Maximum number of items to return
285
+ - **active_only**: If true, only return active items
286
+ """
287
+ try:
288
+ # Log detailed connection info
289
+ logger.info(f"Attempting to fetch emergency contacts with skip={skip}, limit={limit}, active_only={active_only}")
290
+
291
+ # Check if the EmergencyItem table exists
292
+ inspector = inspect(db.bind)
293
+ if not inspector.has_table("emergency_item"):
294
+ logger.error("The emergency_item table does not exist in the database")
295
+ raise HTTPException(status_code=500, detail="Table 'emergency_item' does not exist")
296
+
297
+ # Log table columns
298
+ columns = inspector.get_columns("emergency_item")
299
+ logger.info(f"emergency_item table columns: {[c['name'] for c in columns]}")
300
+
301
+ # Try direct SQL to debug
302
+ try:
303
+ test_result = db.execute(text("SELECT COUNT(*) FROM emergency_item")).scalar()
304
+ logger.info(f"SQL test query succeeded, found {test_result} emergency contacts")
305
+ except Exception as sql_error:
306
+ logger.error(f"SQL test query failed: {sql_error}")
307
+
308
+ # Query the emergency contacts
309
+ query = db.query(EmergencyItem)
310
+ if active_only:
311
+ query = query.filter(EmergencyItem.is_active == True)
312
+
313
+ # Execute the ORM query
314
+ emergency_contacts = query.offset(skip).limit(limit).all()
315
+ logger.info(f"Successfully fetched {len(emergency_contacts)} emergency contacts")
316
+
317
+ # Check what we're returning
318
+ for i, contact in enumerate(emergency_contacts[:3]): # Log the first 3 items
319
+ logger.info(f"Emergency contact {i+1}: id={contact.id}, name={contact.name}")
320
+
321
+ return emergency_contacts
322
+ except SQLAlchemyError as e:
323
+ logger.error(f"Database error in get_emergency_contacts: {e}")
324
+ logger.error(traceback.format_exc())
325
+ raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
326
+ except Exception as e:
327
+ logger.error(f"Unexpected error in get_emergency_contacts: {e}")
328
+ logger.error(traceback.format_exc())
329
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
330
+
331
+ @router.post("/emergency", response_model=EmergencyResponse)
332
+ async def create_emergency_contact(
333
+ emergency: EmergencyCreate,
334
+ db: Session = Depends(get_db)
335
+ ):
336
+ """
337
+ Create a new emergency contact.
338
+
339
+ - **name**: Contact name
340
+ - **phone_number**: Phone number
341
+ - **description**: Description (optional)
342
+ - **address**: Address (optional)
343
+ - **location**: Location coordinates (optional)
344
+ - **priority**: Priority order (default: 0)
345
+ - **is_active**: Whether the contact is active (default: True)
346
+ """
347
+ try:
348
+ db_emergency = EmergencyItem(**emergency.model_dump())
349
+ db.add(db_emergency)
350
+ db.commit()
351
+ db.refresh(db_emergency)
352
+ return db_emergency
353
+ except SQLAlchemyError as e:
354
+ db.rollback()
355
+ logger.error(f"Database error: {e}")
356
+ raise HTTPException(status_code=500, detail="Failed to create emergency contact")
357
+
358
+ @router.get("/emergency/{emergency_id}", response_model=EmergencyResponse)
359
+ async def get_emergency_contact(
360
+ emergency_id: int = Path(..., gt=0),
361
+ db: Session = Depends(get_db)
362
+ ):
363
+ """
364
+ Get a specific emergency contact by ID.
365
+
366
+ - **emergency_id**: ID of the emergency contact
367
+ """
368
+ try:
369
+ emergency = db.query(EmergencyItem).filter(EmergencyItem.id == emergency_id).first()
370
+ if not emergency:
371
+ raise HTTPException(status_code=404, detail="Emergency contact not found")
372
+ return emergency
373
+ except SQLAlchemyError as e:
374
+ logger.error(f"Database error: {e}")
375
+ raise HTTPException(status_code=500, detail="Database error")
376
+
377
+ @router.put("/emergency/{emergency_id}", response_model=EmergencyResponse)
378
+ async def update_emergency_contact(
379
+ emergency_id: int = Path(..., gt=0),
380
+ emergency_update: EmergencyUpdate = Body(...),
381
+ db: Session = Depends(get_db)
382
+ ):
383
+ """
384
+ Update a specific emergency contact.
385
+
386
+ - **emergency_id**: ID of the emergency contact to update
387
+ - **name**: New name (optional)
388
+ - **phone_number**: New phone number (optional)
389
+ - **description**: New description (optional)
390
+ - **address**: New address (optional)
391
+ - **location**: New location coordinates (optional)
392
+ - **priority**: New priority order (optional)
393
+ - **is_active**: New active status (optional)
394
+ """
395
+ try:
396
+ emergency = db.query(EmergencyItem).filter(EmergencyItem.id == emergency_id).first()
397
+ if not emergency:
398
+ raise HTTPException(status_code=404, detail="Emergency contact not found")
399
+
400
+ # Update fields if provided
401
+ update_data = emergency_update.model_dump(exclude_unset=True)
402
+ for key, value in update_data.items():
403
+ setattr(emergency, key, value)
404
+
405
+ db.commit()
406
+ db.refresh(emergency)
407
+ return emergency
408
+ except SQLAlchemyError as e:
409
+ db.rollback()
410
+ logger.error(f"Database error: {e}")
411
+ raise HTTPException(status_code=500, detail="Failed to update emergency contact")
412
+
413
+ @router.delete("/emergency/{emergency_id}", response_model=dict)
414
+ async def delete_emergency_contact(
415
+ emergency_id: int = Path(..., gt=0),
416
+ db: Session = Depends(get_db)
417
+ ):
418
+ """
419
+ Delete a specific emergency contact.
420
+
421
+ - **emergency_id**: ID of the emergency contact to delete
422
+ """
423
+ try:
424
+ emergency = db.query(EmergencyItem).filter(EmergencyItem.id == emergency_id).first()
425
+ if not emergency:
426
+ raise HTTPException(status_code=404, detail="Emergency contact not found")
427
+
428
+ db.delete(emergency)
429
+ db.commit()
430
+ return {"status": "success", "message": f"Emergency contact {emergency_id} deleted"}
431
+ except SQLAlchemyError as e:
432
+ db.rollback()
433
+ logger.error(f"Database error: {e}")
434
+ raise HTTPException(status_code=500, detail="Failed to delete emergency contact")
435
+
436
+ # --- Event endpoints ---
437
+
438
+ @router.get("/events", response_model=List[EventResponse])
439
+ async def get_events(
440
+ skip: int = 0,
441
+ limit: int = 100,
442
+ active_only: bool = False,
443
+ featured_only: bool = False,
444
+ db: Session = Depends(get_db)
445
+ ):
446
+ """
447
+ Get all events.
448
+
449
+ - **skip**: Number of items to skip
450
+ - **limit**: Maximum number of items to return
451
+ - **active_only**: If true, only return active items
452
+ - **featured_only**: If true, only return featured items
453
+ """
454
+ try:
455
+ # Log detailed connection info
456
+ logger.info(f"Attempting to fetch events with skip={skip}, limit={limit}, active_only={active_only}, featured_only={featured_only}")
457
+
458
+ # Check if the EventItem table exists
459
+ inspector = inspect(db.bind)
460
+ if not inspector.has_table("event_item"):
461
+ logger.error("The event_item table does not exist in the database")
462
+ raise HTTPException(status_code=500, detail="Table 'event_item' does not exist")
463
+
464
+ # Log table columns
465
+ columns = inspector.get_columns("event_item")
466
+ logger.info(f"event_item table columns: {[c['name'] for c in columns]}")
467
+
468
+ # Try direct SQL to debug
469
+ try:
470
+ test_result = db.execute(text("SELECT COUNT(*) FROM event_item")).scalar()
471
+ logger.info(f"SQL test query succeeded, found {test_result} events")
472
+ except Exception as sql_error:
473
+ logger.error(f"SQL test query failed: {sql_error}")
474
+
475
+ # Query the events
476
+ query = db.query(EventItem)
477
+ if active_only:
478
+ query = query.filter(EventItem.is_active == True)
479
+ if featured_only:
480
+ query = query.filter(EventItem.featured == True)
481
+
482
+ # Execute the ORM query
483
+ events = query.offset(skip).limit(limit).all()
484
+ logger.info(f"Successfully fetched {len(events)} events")
485
+
486
+ # Debug price field of first event
487
+ if events and len(events) > 0:
488
+ logger.info(f"First event price type: {type(events[0].price)}, value: {events[0].price}")
489
+
490
+ # Check what we're returning
491
+ for i, event in enumerate(events[:3]): # Log the first 3 items
492
+ logger.info(f"Event {i+1}: id={event.id}, name={event.name}, price={type(event.price)}")
493
+
494
+ return events
495
+ except SQLAlchemyError as e:
496
+ logger.error(f"Database error in get_events: {e}")
497
+ logger.error(traceback.format_exc())
498
+ raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
499
+ except Exception as e:
500
+ logger.error(f"Unexpected error in get_events: {e}")
501
+ logger.error(traceback.format_exc())
502
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
503
+
504
+ @router.post("/events", response_model=EventResponse)
505
+ async def create_event(
506
+ event: EventCreate,
507
+ db: Session = Depends(get_db)
508
+ ):
509
+ """
510
+ Create a new event.
511
+
512
+ - **name**: Event name
513
+ - **description**: Event description
514
+ - **address**: Event address
515
+ - **location**: Location coordinates (optional)
516
+ - **date_start**: Start date and time
517
+ - **date_end**: End date and time (optional)
518
+ - **price**: Price information (optional JSON object)
519
+ - **is_active**: Whether the event is active (default: True)
520
+ - **featured**: Whether the event is featured (default: False)
521
+ """
522
+ try:
523
+ db_event = EventItem(**event.model_dump())
524
+ db.add(db_event)
525
+ db.commit()
526
+ db.refresh(db_event)
527
+ return db_event
528
+ except SQLAlchemyError as e:
529
+ db.rollback()
530
+ logger.error(f"Database error: {e}")
531
+ raise HTTPException(status_code=500, detail="Failed to create event")
532
+
533
+ @router.get("/events/{event_id}", response_model=EventResponse)
534
+ async def get_event(
535
+ event_id: int = Path(..., gt=0),
536
+ db: Session = Depends(get_db)
537
+ ):
538
+ """
539
+ Get a specific event by ID.
540
+
541
+ - **event_id**: ID of the event
542
+ """
543
+ try:
544
+ event = db.query(EventItem).filter(EventItem.id == event_id).first()
545
+ if not event:
546
+ raise HTTPException(status_code=404, detail="Event not found")
547
+ return event
548
+ except SQLAlchemyError as e:
549
+ logger.error(f"Database error: {e}")
550
+ raise HTTPException(status_code=500, detail="Database error")
551
+
552
+ @router.put("/events/{event_id}", response_model=EventResponse)
553
+ async def update_event(
554
+ event_id: int = Path(..., gt=0),
555
+ event_update: EventUpdate = Body(...),
556
+ db: Session = Depends(get_db)
557
+ ):
558
+ """
559
+ Update a specific event.
560
+
561
+ - **event_id**: ID of the event to update
562
+ - **name**: New name (optional)
563
+ - **description**: New description (optional)
564
+ - **address**: New address (optional)
565
+ - **location**: New location coordinates (optional)
566
+ - **date_start**: New start date and time (optional)
567
+ - **date_end**: New end date and time (optional)
568
+ - **price**: New price information (optional JSON object)
569
+ - **is_active**: New active status (optional)
570
+ - **featured**: New featured status (optional)
571
+ """
572
+ try:
573
+ event = db.query(EventItem).filter(EventItem.id == event_id).first()
574
+ if not event:
575
+ raise HTTPException(status_code=404, detail="Event not found")
576
+
577
+ # Update fields if provided
578
+ update_data = event_update.model_dump(exclude_unset=True)
579
+ for key, value in update_data.items():
580
+ setattr(event, key, value)
581
+
582
+ db.commit()
583
+ db.refresh(event)
584
+ return event
585
+ except SQLAlchemyError as e:
586
+ db.rollback()
587
+ logger.error(f"Database error: {e}")
588
+ raise HTTPException(status_code=500, detail="Failed to update event")
589
+
590
+ @router.delete("/events/{event_id}", response_model=dict)
591
+ async def delete_event(
592
+ event_id: int = Path(..., gt=0),
593
+ db: Session = Depends(get_db)
594
+ ):
595
+ """
596
+ Delete a specific event.
597
+
598
+ - **event_id**: ID of the event to delete
599
+ """
600
+ try:
601
+ event = db.query(EventItem).filter(EventItem.id == event_id).first()
602
+ if not event:
603
+ raise HTTPException(status_code=404, detail="Event not found")
604
+
605
+ db.delete(event)
606
+ db.commit()
607
+ return {"status": "success", "message": f"Event {event_id} deleted"}
608
+ except SQLAlchemyError as e:
609
+ db.rollback()
610
+ logger.error(f"Database error: {e}")
611
+ raise HTTPException(status_code=500, detail="Failed to delete event")
612
+
613
+ # Health check endpoint
614
+ @router.get("/health")
615
+ async def health_check(db: Session = Depends(get_db)):
616
+ """
617
+ Check health of PostgreSQL connection.
618
+ """
619
+ try:
620
+ # Perform a simple database query to check health
621
+ # Use text() to wrap the SQL query for SQLAlchemy 2.0 compatibility
622
+ db.execute(text("SELECT 1")).first()
623
+ return {"status": "healthy", "message": "PostgreSQL connection is working", "timestamp": datetime.now().isoformat()}
624
+ except Exception as e:
625
+ logger.error(f"PostgreSQL health check failed: {e}")
626
+ raise HTTPException(status_code=503, detail=f"PostgreSQL connection failed: {str(e)}")
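A quick usage sketch for the CRUD endpoints in this file. This is a hedged example, not part of the commit: it assumes the router is mounted at the application root and that the app listens on localhost:7860 (the default port used elsewhere in this commit); if the router carries a prefix, prepend it to each path. The question/answer strings are hypothetical.

import requests  # assumed client library, not a dependency declared by this commit

BASE = "http://localhost:7860"  # assumed host and port

# Create an FAQ item, list the active ones, then deactivate and delete it
created = requests.post(f"{BASE}/faq", json={
    "question": "Where is the Dragon Bridge?",  # hypothetical content
    "answer": "It crosses the Han River in central Da Nang.",
    "is_active": True,
}).json()

active = requests.get(f"{BASE}/faq", params={"active_only": True}).json()
requests.put(f"{BASE}/faq/{created['id']}", json={"is_active": False})
requests.delete(f"{BASE}/faq/{created['id']}")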
app/api/rag_routes.py ADDED
@@ -0,0 +1,538 @@
1
+ from fastapi import APIRouter, HTTPException, Depends, Query, BackgroundTasks, Request
2
+ from typing import List, Optional, Dict, Any
3
+ import logging
4
+ import time
5
+ import os
6
+ import json
7
+ import hashlib
8
+ import asyncio
9
+ import traceback
10
+ import google.generativeai as genai
11
+ from datetime import datetime
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
14
+ from app.utils.utils import cache, timer_decorator
15
+
16
+ from app.database.mongodb import get_user_history, get_chat_history, get_request_history, save_session, session_collection
17
+ from app.database.pinecone import (
18
+ search_vectors,
19
+ get_chain,
20
+ DEFAULT_TOP_K,
21
+ DEFAULT_LIMIT_K,
22
+ DEFAULT_SIMILARITY_METRIC,
23
+ DEFAULT_SIMILARITY_THRESHOLD,
24
+ ALLOWED_METRICS
25
+ )
26
+ from app.models.rag_models import (
27
+ ChatRequest,
28
+ ChatResponse,
29
+ ChatResponseInternal,
30
+ SourceDocument,
31
+ EmbeddingRequest,
32
+ EmbeddingResponse,
33
+ UserMessageModel
34
+ )
35
+
36
+ # Use a simple in-memory cache instead of Redis
37
+ class SimpleCache:
38
+ def __init__(self):
39
+ self.cache = {}
40
+ self.expiration = {}
41
+
42
+ async def get(self, key):
43
+ if key in self.cache:
44
+ # Check whether the cached entry has expired
45
+ if key in self.expiration and self.expiration[key] > time.time():
46
+ return self.cache[key]
47
+ else:
48
+ # Remove the expired entry
49
+ if key in self.cache:
50
+ del self.cache[key]
51
+ if key in self.expiration:
52
+ del self.expiration[key]
53
+ return None
54
+
55
+ async def set(self, key, value, ex=300): # Default TTL: 5 minutes
56
+ self.cache[key] = value
57
+ self.expiration[key] = time.time() + ex
58
+
59
+ # Initialize SimpleCache (named redis_client so a real Redis client can be swapped in later)
60
+ redis_client = SimpleCache()
61
+
62
+ # Configure logging
63
+ logger = logging.getLogger(__name__)
64
+
65
+ # Configure Google Gemini API
66
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
67
+ genai.configure(api_key=GOOGLE_API_KEY)
68
+
69
+ # Create router
70
+ router = APIRouter(
71
+ prefix="/rag",
72
+ tags=["RAG"],
73
+ )
74
+
75
+ # Create a prompt template with conversation history
76
+ prompt = PromptTemplate(
77
+ template = """Goal:
78
+ You are a professional tour guide assistant that helps users find information about places in Da Nang, Vietnam.
79
+ You can provide details on restaurants, cafes, hotels, attractions, and other local venues.
80
+ Use the core knowledge and conversation history to chat with users, who are tourists visiting Da Nang.
81
+
82
+ Return Format:
83
+ Respond in a friendly, natural, and concise way, in English only, like a real tour guide.
84
+ Always use HTML tags (e.g. <b> for bold) so that Telegram can render the special formatting correctly.
85
+
86
+ Warning:
87
+ Support users like a real tour guide, not a bot. Treat the information in the core knowledge as your own knowledge.
88
+ Your knowledge is provided in the Core Knowledge below; all of it is about Da Nang, Vietnam.
89
+ Only consider the current time the user mentions when they ask about Solana events.
90
+ If you do not have enough information to answer the user's question, reply with "I don't know. I don't have information about that".
91
+
92
+ Core knowledge:
93
+ {context}
94
+
95
+ Conversation History:
96
+ {chat_history}
97
+
98
+ User message:
99
+ {question}
100
+
101
+ Your message:
102
+ """,
103
+ input_variables = ["context", "question", "chat_history"],
104
+ )
105
+
106
+ # Helper for embeddings
107
+ async def get_embedding(text: str):
108
+ """Get embedding from Google Gemini API"""
109
+ try:
110
+ # Initialize embedding model
111
+ embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
112
+
113
+ # Generate embedding
114
+ result = await embedding_model.aembed_query(text)
115
+
116
+ # Return embedding
117
+ return {
118
+ "embedding": result,
119
+ "text": text,
120
+ "model": "embedding-001"
121
+ }
122
+ except Exception as e:
123
+ logger.error(f"Error generating embedding: {e}")
124
+ raise HTTPException(status_code=500, detail=f"Failed to generate embedding: {str(e)}")
125
+
126
+ # Endpoint for generating embeddings
127
+ @router.post("/embedding", response_model=EmbeddingResponse)
128
+ async def create_embedding(request: EmbeddingRequest):
129
+ """
130
+ Generate embedding for text.
131
+
132
+ - **text**: Text to generate embedding for
133
+ """
134
+ try:
135
+ # Get embedding
136
+ embedding_data = await get_embedding(request.text)
137
+
138
+ # Return embedding
139
+ return EmbeddingResponse(**embedding_data)
140
+ except Exception as e:
141
+ logger.error(f"Error generating embedding: {e}")
142
+ raise HTTPException(status_code=500, detail=f"Failed to generate embedding: {str(e)}")
143
+
144
+ @router.post("/chat", response_model=ChatResponse)
145
+ @timer_decorator # placed below router.post so the timed wrapper is what gets registered (assumes timer_decorator preserves the signature via functools.wraps)
146
+ async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
147
+ """
148
+ Get answer for a question using RAG.
149
+
150
+ - **user_id**: User's ID from Telegram
151
+ - **question**: User's question
152
+ - **include_history**: Whether to include user history in prompt (default: True)
153
+ - **use_rag**: Whether to use RAG (default: True)
154
+ - **similarity_top_k**: Number of top similar documents to return after filtering (default: 6)
155
+ - **limit_k**: Maximum number of documents to retrieve from vector store (default: 10)
156
+ - **similarity_metric**: Similarity metric to use - cosine, dotproduct, euclidean (default: cosine)
157
+ - **similarity_threshold**: Threshold for vector similarity (default: 0.75)
158
+ - **session_id**: Optional session ID for tracking conversations
159
+ - **first_name**: User's first name
160
+ - **last_name**: User's last name
161
+ - **username**: User's username
162
+ """
163
+ start_time = time.time()
164
+ try:
165
+ # Create cache key for the request; it encodes every retrieval parameter, so changing any knob bypasses stale entries
166
+ cache_key = f"rag_chat:{request.user_id}:{request.question}:{request.include_history}:{request.use_rag}:{request.similarity_top_k}:{request.limit_k}:{request.similarity_metric}:{request.similarity_threshold}"
167
+
168
+ # Check cache using redis_client instead of cache
169
+ cached_response = await redis_client.get(cache_key)
170
+ if cached_response is not None:
171
+ logger.info(f"Cache hit for RAG chat request from user {request.user_id}")
172
+ try:
173
+ # If cached_response is string (JSON), parse it
174
+ if isinstance(cached_response, str):
175
+ cached_data = json.loads(cached_response)
176
+ return ChatResponse(
177
+ answer=cached_data.get("answer", ""),
178
+ processing_time=cached_data.get("processing_time", 0.0)
179
+ )
180
+ # If cached_response is object with sources, extract answer and processing_time
181
+ elif hasattr(cached_response, 'sources'):
182
+ return ChatResponse(
183
+ answer=cached_response.answer,
184
+ processing_time=cached_response.processing_time
185
+ )
186
+ # Otherwise, return cached response as is
187
+ return cached_response
188
+ except Exception as e:
189
+ logger.error(f"Error parsing cached response: {e}")
190
+ # Continue processing if cache parsing fails
191
+
192
+ # Save user message first (so it's available for user history)
193
+ session_id = request.session_id or f"{request.user_id}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
194
+ logger.info(f"Processing chat request for user {request.user_id}, session {session_id}")
195
+
196
+ # First, save the user's message so it's available for history lookups
197
+ try:
198
+ # Save user's question
199
+ save_session(
200
+ session_id=session_id,
201
+ factor="user",
202
+ action="asking_freely",
203
+ first_name=getattr(request, 'first_name', "User"),
204
+ last_name=getattr(request, 'last_name', ""),
205
+ message=request.question,
206
+ user_id=request.user_id,
207
+ username=getattr(request, 'username', ""),
208
+ response=None # No response yet
209
+ )
210
+ logger.info(f"User message saved for session {session_id}")
211
+ except Exception as e:
212
+ logger.error(f"Error saving user message to session: {e}")
213
+ # Continue processing even if saving fails
214
+
215
+ # Use the RAG pipeline
216
+ if request.use_rag:
217
+ # Get the retriever with custom parameters
218
+ retriever = get_chain(
219
+ top_k=request.similarity_top_k,
220
+ limit_k=request.limit_k,
221
+ similarity_metric=request.similarity_metric,
222
+ similarity_threshold=request.similarity_threshold
223
+ )
224
+ if not retriever:
225
+ raise HTTPException(status_code=500, detail="Failed to initialize retriever")
226
+
227
+ # Get request history for context
228
+ context_query = get_request_history(request.user_id) if request.include_history else request.question
229
+ logger.info(f"Using context query for retrieval: {context_query[:100]}...")
230
+
231
+ # Retrieve relevant documents
232
+ retrieved_docs = retriever.invoke(context_query)
233
+ context = "\n".join([doc.page_content for doc in retrieved_docs])
234
+
235
+ # Prepare sources
236
+ sources = []
237
+ for doc in retrieved_docs:
238
+ source = None
239
+ score = None
+ normalized_score = None
+ metadata = {}
240
+
241
+ if hasattr(doc, 'metadata'):
242
+ source = doc.metadata.get('source', None)
243
+ # Extract score information
244
+ score = doc.metadata.get('score', None)
245
+ normalized_score = doc.metadata.get('normalized_score', None)
246
+ # Remove score info from metadata to avoid duplication
247
+ metadata = {k: v for k, v in doc.metadata.items()
248
+ if k not in ['text', 'source', 'score', 'normalized_score']}
249
+
250
+ sources.append(SourceDocument(
251
+ text=doc.page_content,
252
+ source=source,
253
+ score=score,
254
+ normalized_score=normalized_score,
255
+ metadata=metadata
256
+ ))
257
+ else:
258
+ # No RAG
259
+ context = ""
260
+ sources = None
261
+
262
+ # Get chat history
263
+ chat_history = get_chat_history(request.user_id) if request.include_history else ""
264
+ logger.info(f"Using chat history: {chat_history[:100]}...")
265
+
266
+ # Initialize Gemini model
267
+ generation_config = {
268
+ "temperature": 0.9,
269
+ "top_p": 1,
270
+ "top_k": 1,
271
+ "max_output_tokens": 2048,
272
+ }
273
+
274
+ safety_settings = [
275
+ {
276
+ "category": "HARM_CATEGORY_HARASSMENT",
277
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
278
+ },
279
+ {
280
+ "category": "HARM_CATEGORY_HATE_SPEECH",
281
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
282
+ },
283
+ {
284
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
285
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
286
+ },
287
+ {
288
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
289
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
290
+ },
291
+ ]
292
+
293
+ model = genai.GenerativeModel(
294
+ model_name='models/gemini-2.0-flash',
295
+ generation_config=generation_config,
296
+ safety_settings=safety_settings
297
+ )
298
+
299
+ # Generate the prompt using template
300
+ prompt_text = prompt.format(
301
+ context=context,
302
+ question=request.question,
303
+ chat_history=chat_history
304
+ )
305
+ logger.info(f"Full prompt with history and context: {prompt_text}")
306
+
307
+ # Generate response
308
+ response = model.generate_content(prompt_text)
309
+ answer = response.text
310
+
311
+ # Save the RAG response
312
+ try:
313
+ # Now save the RAG response with the same session_id
314
+ save_session(
315
+ session_id=session_id,
316
+ factor="rag",
317
+ action="RAG_response",
318
+ first_name=getattr(request, 'first_name', "User"),
319
+ last_name=getattr(request, 'last_name', ""),
320
+ message=request.question,
321
+ user_id=request.user_id,
322
+ username=getattr(request, 'username', ""),
323
+ response=answer
324
+ )
325
+ logger.info(f"RAG response saved for session {session_id}")
326
+
327
+ # Check if the response starts with "I don't know" and trigger notification
328
+ if answer.strip().lower().startswith("i don't know"):
329
+ from app.api.websocket_routes import send_notification
330
+ notification_data = {
331
+ "session_id": session_id,
332
+ "factor": "rag",
333
+ "action": "RAG_response",
334
+ "message": request.question,
335
+ "user_id": request.user_id,
336
+ "username": getattr(request, 'username', ""),
337
+ "first_name": getattr(request, 'first_name', "User"),
338
+ "last_name": getattr(request, 'last_name', ""),
339
+ "response": answer,
340
+ "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
341
+ }
342
+ background_tasks.add_task(send_notification, notification_data)
343
+ logger.info(f"Notification queued for session {session_id} - response starts with 'I don't know'")
344
+ except Exception as e:
345
+ logger.error(f"Error saving RAG response to session: {e}")
346
+ # Continue processing even if saving fails
347
+
348
+ # Calculate processing time
349
+ processing_time = time.time() - start_time
350
+
351
+ # Create internal response object with sources for logging
352
+ internal_response = ChatResponseInternal(
353
+ answer=answer,
354
+ sources=sources,
355
+ processing_time=processing_time
356
+ )
357
+
358
+ # Log full response with sources
359
+ logger.info(f"Generated response for user {request.user_id}: {answer}")
360
+ if sources:
361
+ logger.info(f"Sources used: {len(sources)} documents")
362
+ for i, source in enumerate(sources):
363
+ logger.info(f"Source {i+1}: {source.source or 'Unknown'} (score: {source.score})")
364
+
365
+ # Create response object for API (without sources)
366
+ chat_response = ChatResponse(
367
+ answer=answer,
368
+ processing_time=processing_time
369
+ )
370
+
371
+ # Cache result using redis_client instead of cache
372
+ try:
373
+ # Convert to JSON to ensure it can be cached
374
+ cache_data = {
375
+ "answer": answer,
376
+ "processing_time": processing_time
377
+ }
378
+ await redis_client.set(cache_key, json.dumps(cache_data), ex=300)
379
+ except Exception as e:
380
+ logger.error(f"Error caching response: {e}")
381
+ # Continue even if caching fails
382
+
383
+ # Return response
384
+ return chat_response
385
+ except Exception as e:
386
+ logger.error(f"Error processing chat request: {e}")
387
388
+ logger.error(traceback.format_exc())
389
+ raise HTTPException(status_code=500, detail=f"Failed to process chat request: {str(e)}")
390
+
391
+ # Health check endpoint
392
+ @router.get("/health")
393
+ async def health_check():
394
+ """
395
+ Check health of RAG services and retrieval system.
396
+
397
+ Returns:
398
+ - status: "healthy" if all services are working, "degraded" otherwise
399
+ - services: Status of each service (gemini, pinecone)
400
+ - retrieval_config: Current retrieval configuration
401
+ - timestamp: Current time
402
+ """
403
+ services = {
404
+ "gemini": False,
405
+ "pinecone": False
406
+ }
407
+
408
+ # Check Gemini
409
+ try:
410
+ # Initialize simple model
411
+ model = genai.GenerativeModel("gemini-2.0-flash")
412
+ # Test generation
413
+ response = model.generate_content("Hello")
414
+ services["gemini"] = True
415
+ except Exception as e:
416
+ logger.error(f"Gemini health check failed: {e}")
417
+
418
+ # Check Pinecone
419
+ try:
420
+ # Import pinecone function
421
+ from app.database.pinecone import get_pinecone_index
422
+ # Get index
423
+ index = get_pinecone_index()
424
+ # Check if index exists
425
+ if index:
426
+ services["pinecone"] = True
427
+ except Exception as e:
428
+ logger.error(f"Pinecone health check failed: {e}")
429
+
430
+ # Get retrieval configuration
431
+ retrieval_config = {
432
+ "default_top_k": DEFAULT_TOP_K,
433
+ "default_limit_k": DEFAULT_LIMIT_K,
434
+ "default_similarity_metric": DEFAULT_SIMILARITY_METRIC,
435
+ "default_similarity_threshold": DEFAULT_SIMILARITY_THRESHOLD,
436
+ "allowed_metrics": ALLOWED_METRICS
437
+ }
438
+
439
+ # Return health status
440
+ status = "healthy" if all(services.values()) else "degraded"
441
+ return {
442
+ "status": status,
443
+ "services": services,
444
+ "retrieval_config": retrieval_config,
445
+ "timestamp": datetime.now().isoformat()
446
+ }
447
+
448
+ @router.post("/rag")
449
+ async def process_rag(request: Request, user_data: UserMessageModel, background_tasks: BackgroundTasks):
450
+ """
451
+ Process a user message through the RAG pipeline and return a response.
452
+
453
+ Parameters:
454
+ - **user_id**: User ID from the client application
455
+ - **session_id**: Session ID for tracking the conversation
456
+ - **message**: User's message/question
457
+ - **similarity_top_k**: (Optional) Number of top similar documents to return after filtering
458
+ - **limit_k**: (Optional) Maximum number of documents to retrieve from vector store
459
+ - **similarity_metric**: (Optional) Similarity metric to use (cosine, dotproduct, euclidean)
460
+ - **similarity_threshold**: (Optional) Threshold for vector similarity (0-1)
461
+ """
462
+ try:
463
+ # Extract request data
464
+ user_id = user_data.user_id
465
+ session_id = user_data.session_id
466
+ message = user_data.message
467
+
468
+ # Extract retrieval parameters (use defaults if not provided)
469
+ top_k = user_data.similarity_top_k or DEFAULT_TOP_K
470
+ limit_k = user_data.limit_k or DEFAULT_LIMIT_K
471
+ similarity_metric = user_data.similarity_metric or DEFAULT_SIMILARITY_METRIC
472
+ similarity_threshold = user_data.similarity_threshold or DEFAULT_SIMILARITY_THRESHOLD
473
+
474
+ logger.info(f"RAG request received for user_id={user_id}, session_id={session_id}")
475
+ logger.info(f"Message: {message[:100]}..." if len(message) > 100 else f"Message: {message}")
476
+ logger.info(f"Retrieval parameters: top_k={top_k}, limit_k={limit_k}, metric={similarity_metric}, threshold={similarity_threshold}")
477
+
478
+ # Create a cache key for this request to avoid reprocessing identical questions
479
+ cache_key = f"rag_{user_id}_{session_id}_{hashlib.md5(message.encode()).hexdigest()}_{top_k}_{limit_k}_{similarity_metric}_{similarity_threshold}"
480
+
481
+ # Check if we have this response cached
482
+ cached_result = await redis_client.get(cache_key)
483
+ if cached_result:
484
+ logger.info(f"Cache hit for key: {cache_key}")
485
+ if isinstance(cached_result, str): # If stored as JSON string
486
+ return json.loads(cached_result)
487
+ return cached_result
488
+
489
+ # Save user message to MongoDB
490
+ try:
491
+ # Save user's question
492
+ save_session(
493
+ session_id=session_id,
494
+ factor="user",
495
+ action="asking_freely",
496
+ first_name="User", # You can update this with actual data if available
497
+ last_name="",
498
+ message=message,
499
+ user_id=user_id,
500
+ username="",
501
+ response=None # No response yet
502
+ )
503
+ logger.info(f"User message saved to MongoDB with session_id: {session_id}")
504
+ except Exception as e:
505
+ logger.error(f"Error saving user message: {e}")
506
+ # Continue anyway to try to get a response
507
+
508
+ # Create a ChatRequest object to reuse the existing chat endpoint
509
+ chat_request = ChatRequest(
510
+ user_id=user_id,
511
+ question=message,
512
+ include_history=True,
513
+ use_rag=True,
514
+ similarity_top_k=top_k,
515
+ limit_k=limit_k,
516
+ similarity_metric=similarity_metric,
517
+ similarity_threshold=similarity_threshold,
518
+ session_id=session_id
519
+ )
520
+
521
+ # Process through the chat endpoint
522
+ response = await chat(chat_request, background_tasks)
523
+
524
+ # Cache the response
525
+ try:
526
+ await redis_client.set(cache_key, json.dumps({
527
+ "answer": response.answer,
528
+ "processing_time": response.processing_time
529
+ }))
530
+ logger.info(f"Cached response for key: {cache_key}")
531
+ except Exception as e:
532
+ logger.error(f"Failed to cache response: {e}")
533
+
534
+ return response
535
+ except Exception as e:
536
+ logger.error(f"Error processing RAG request: {e}")
537
+ logger.error(traceback.format_exc())
538
+ raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
app/api/websocket_routes.py ADDED
@@ -0,0 +1,306 @@
1
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, status
2
+ from typing import List, Dict
3
+ import logging
4
+ from datetime import datetime
5
+ import asyncio
6
+ import json
7
+ import os
8
+ from dotenv import load_dotenv
9
+ from app.database.mongodb import session_collection
10
+ from app.utils.utils import get_local_time
11
+
12
+ # Load environment variables
13
+ load_dotenv()
14
+
15
+ # Get WebSocket configuration from environment variables
16
+ WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
17
+ WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
18
+ WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")
19
+
20
+ # Configure logging
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Create router
24
+ router = APIRouter(
25
+ tags=["WebSocket"],
26
+ )
27
+
28
+ # Store active WebSocket connections
29
+ class ConnectionManager:
30
+ def __init__(self):
31
+ self.active_connections: List[WebSocket] = []
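+ # A plain list suffices for a handful of admin connections; broadcast() prunes dead sockets as it goes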
32
+
33
+ async def connect(self, websocket: WebSocket):
34
+ await websocket.accept()
35
+ self.active_connections.append(websocket)
36
+ client_info = f"{websocket.client.host}:{websocket.client.port}" if hasattr(websocket, 'client') else "Unknown"
37
+ logger.info(f"New WebSocket connection from {client_info}. Total connections: {len(self.active_connections)}")
38
+
39
+ def disconnect(self, websocket: WebSocket):
40
+ self.active_connections.remove(websocket)
41
+ logger.info(f"WebSocket connection removed. Total connections: {len(self.active_connections)}")
42
+
43
+ async def broadcast(self, message: Dict):
44
+ if not self.active_connections:
45
+ logger.warning("No active WebSocket connections to broadcast to")
46
+ return
47
+
48
+ disconnected = []
49
+ for connection in self.active_connections:
50
+ try:
51
+ await connection.send_json(message)
52
+ logger.info(f"Message sent to WebSocket connection")
53
+ except Exception as e:
54
+ logger.error(f"Error sending message to WebSocket: {e}")
55
+ disconnected.append(connection)
56
+
57
+ # Remove disconnected connections
58
+ for conn in disconnected:
59
+ if conn in self.active_connections:
60
+ self.active_connections.remove(conn)
61
+ logger.info(f"Removed disconnected WebSocket. Remaining: {len(self.active_connections)}")
62
+
63
+ # Initialize connection manager
64
+ manager = ConnectionManager()
65
+
66
+ # Create full URL of WebSocket server from environment variables
67
+ def get_full_websocket_url(server_side=False):
68
+ if server_side:
69
+ # Relative URL (for server side)
70
+ return WEBSOCKET_PATH
71
+ else:
72
+ # Full URL (for client)
73
+ # Use wss:// when the connection is assumed to be HTTPS (port 443)
74
+ is_https = int(WEBSOCKET_PORT) == 443
75
+ protocol = "wss" if is_https else "ws"
76
+
77
+ # If using default port for protocol, don't include in URL
78
+ if int(WEBSOCKET_PORT) in (80, 443):
79
+ return f"{protocol}://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"
80
+ else:
81
+ return f"{protocol}://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
82
+
83
+ # Add GET endpoint to display WebSocket information in Swagger
84
+ @router.get("/notify",
85
+ summary="WebSocket notifications for Admin Bot",
86
+ description=f"""
87
+ This is documentation for the WebSocket endpoint.
88
+
89
+ To connect to WebSocket:
90
+ 1. Use the path `{get_full_websocket_url()}`
91
+ 2. Connect using a WebSocket client library
92
+ 3. When there are new sessions requiring attention, you will receive notifications through this connection
93
+
94
+ Notifications are sent when:
95
+ - Session response starts with "I don't know"
96
+ - The system cannot answer the user's question
97
+
98
+ Make sure to send a "keepalive" message every 5 minutes to maintain the connection.
99
+ """,
100
+ status_code=status.HTTP_200_OK
101
+ )
102
+ async def websocket_documentation():
103
+ """
104
+ Provides information about how to use the WebSocket endpoint /notify.
105
+ This endpoint is for documentation purposes only. To use WebSocket, please connect to the WebSocket URL.
106
+ """
107
+ ws_url = get_full_websocket_url()
108
+ return {
109
+ "websocket_endpoint": WEBSOCKET_PATH,
110
+ "connection_type": "WebSocket",
111
+ "protocol": "ws://",
112
+ "server": WEBSOCKET_SERVER,
113
+ "port": WEBSOCKET_PORT,
114
+ "full_url": ws_url,
115
+ "description": "Endpoint to receive notifications about new sessions requiring attention",
116
+ "notification_format": {
117
+ "type": "new_session",
118
+ "timestamp": "YYYY-MM-DD HH:MM:SS",
119
+ "data": {
120
+ "session_id": "session id",
121
+ "factor": "user",
122
+ "action": "action type",
123
+ "message": "User question",
124
+ "response": "I don't know...",
125
+ "user_id": "user id",
126
+ "first_name": "user's first name",
127
+ "last_name": "user's last name",
128
+ "username": "username",
129
+ "created_at": "creation time"
130
+ }
131
+ },
132
+ "client_example": """
133
+ import websocket
134
+ import json
135
+ import os
136
+ import time
137
+ import threading
138
+ from dotenv import load_dotenv
139
+
140
+ # Load environment variables
141
+ load_dotenv()
142
+
143
+ # Get WebSocket configuration from environment variables
144
+ WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
145
+ WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
146
+ WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")
147
+
148
+ # Create full URL
149
+ ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
150
+
151
+ # If using HTTPS, replace ws:// with wss://
152
+ # ws_url = f"wss://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"
153
+
154
+ # Send keepalive periodically
155
+ def send_keepalive(ws):
156
+ while True:
157
+ try:
158
+ if ws.sock and ws.sock.connected:
159
+ ws.send("keepalive")
160
+ print("Sent keepalive message")
161
+ time.sleep(300) # 5 minutes
162
+ except Exception as e:
163
+ print(f"Error sending keepalive: {e}")
164
+ time.sleep(60)
165
+
166
+ def on_message(ws, message):
167
+ try:
168
+ data = json.loads(message)
169
+ print(f"Received notification: {data}")
170
+ # Process notification, e.g.: send to Telegram Admin
171
+ if data.get("type") == "new_session":
172
+ session_data = data.get("data", {})
173
+ user_question = session_data.get("message", "")
174
+ user_name = session_data.get("first_name", "Unknown User")
175
+ print(f"User {user_name} asked: {user_question}")
176
+ # Code to send message to Telegram Admin
177
+ except json.JSONDecodeError:
178
+ print(f"Received non-JSON message: {message}")
179
+ except Exception as e:
180
+ print(f"Error processing message: {e}")
181
+
182
+ def on_error(ws, error):
183
+ print(f"WebSocket error: {error}")
184
+
185
+ def on_close(ws, close_status_code, close_msg):
186
+ print(f"WebSocket connection closed: code={close_status_code}, message={close_msg}")
187
+
188
+ def on_open(ws):
189
+ print(f"WebSocket connection opened to {ws_url}")
190
+ # Send keepalive messages periodically in a separate thread
191
+ keepalive_thread = threading.Thread(target=send_keepalive, args=(ws,), daemon=True)
192
+ keepalive_thread.start()
193
+
194
+ def run_forever_with_reconnect():
195
+ while True:
196
+ try:
197
+ # Connect WebSocket with ping to maintain connection
198
+ ws = websocket.WebSocketApp(
199
+ ws_url,
200
+ on_open=on_open,
201
+ on_message=on_message,
202
+ on_error=on_error,
203
+ on_close=on_close
204
+ )
205
+ ws.run_forever(ping_interval=60, ping_timeout=30)
206
+ print("WebSocket connection lost, reconnecting in 5 seconds...")
207
+ time.sleep(5)
208
+ except Exception as e:
209
+ print(f"WebSocket connection error: {e}")
210
+ time.sleep(5)
211
+
212
+ # Start WebSocket client in a separate thread
213
+ websocket_thread = threading.Thread(target=run_forever_with_reconnect, daemon=True)
214
+ websocket_thread.start()
215
+
216
+ # Keep the program running
217
+ try:
218
+ while True:
219
+ time.sleep(1)
220
+ except KeyboardInterrupt:
221
+ print("Stopping WebSocket client...")
222
+ """
223
+ }
224
+
225
+ @router.websocket("/notify")
226
+ async def websocket_endpoint(websocket: WebSocket):
227
+ """
228
+ WebSocket endpoint to receive notifications about new sessions.
229
+ Admin Bot will connect to this endpoint to receive notifications when there are new sessions requiring attention.
230
+ """
231
+ await manager.connect(websocket)
232
+ try:
233
+ while True:
234
+ # Maintain WebSocket connection
235
+ data = await websocket.receive_text()
236
+ # Echo back to keep connection active
237
+ await websocket.send_json({"status": "connected", "echo": data, "timestamp": datetime.now().isoformat()})
238
+ logger.info(f"Received message from WebSocket: {data}")
239
+ except WebSocketDisconnect:
240
+ logger.info("WebSocket client disconnected")
241
+ manager.disconnect(websocket)
242
+ except Exception as e:
243
+ logger.error(f"WebSocket error: {e}")
244
+ manager.disconnect(websocket)
245
+
246
+ # Function to send notifications over WebSocket
247
+ async def send_notification(data: dict):
248
+ """
249
+ Send notification to all active WebSocket connections.
250
+
251
+ This function is used to notify admin bots about new issues or questions that need attention.
252
+ It's triggered when the system cannot answer a user's question (response starts with "I don't know").
253
+
254
+ Args:
255
+ data: The data to send as notification
256
+ """
257
+ try:
258
+ # Log number of active connections and notification attempt
259
+ logger.info(f"Attempting to send notification. Active connections: {len(manager.active_connections)}")
260
+ logger.info(f"Notification data: session_id={data.get('session_id')}, user_id={data.get('user_id')}")
261
+ logger.info(f"Response: {data.get('response', '')[:50]}...")
262
+
263
+ # Check if the response starts with "I don't know"
264
+ response = data.get('response', '')
265
+ if not response or not isinstance(response, str):
266
+ logger.warning(f"Invalid response format in notification data: {response}")
267
+ return
268
+
269
+ if not response.strip().lower().startswith("i don't know"):
270
+ logger.info(f"Response doesn't start with 'I don't know', notification not needed: {response[:50]}...")
271
+ return
272
+
273
+ logger.info(f"Response starts with 'I don't know', sending notification")
274
+
275
+ # Format the notification data for admin
276
+ notification_data = {
277
+ "type": "new_session",
278
+ "timestamp": get_local_time(),
279
+ "data": {
280
+ "session_id": data.get('session_id', 'unknown'),
281
+ "user_id": data.get('user_id', 'unknown'),
282
+ "message": data.get('message', ''),
283
+ "response": response,
284
+ "first_name": data.get('first_name', 'User'),
285
+ "last_name": data.get('last_name', ''),
286
+ "username": data.get('username', ''),
287
+ "created_at": data.get('created_at', get_local_time()),
288
+ "action": data.get('action', 'unknown'),
289
+ "factor": "user" # Always show as user for better readability
290
+ }
291
+ }
292
+
293
+ # Check if there are active connections
294
+ if not manager.active_connections:
295
+ logger.warning("No active WebSocket connections for notification broadcast")
296
+ return
297
+
298
+ # Broadcast notification to all active connections
299
+ logger.info(f"Broadcasting notification to {len(manager.active_connections)} connections")
300
+ await manager.broadcast(notification_data)
301
+ logger.info("Notification broadcast completed successfully")
302
+
303
+ except Exception as e:
304
+ logger.error(f"Error sending notification: {e}")
305
+ import traceback
306
+ logger.error(traceback.format_exc())
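The documentation endpoint above already embeds a synchronous websocket-client example; as a complement, here is a minimal asyncio sketch of the same listener using the websockets package (an assumption, it is not a dependency declared by this commit). The URL and the notification shape follow the ConnectionManager and send_notification code above.

import asyncio
import json
import websockets  # assumed third-party package

async def listen():
    # ws:// and port 7860 match the defaults in this file
    async with websockets.connect("ws://localhost:7860/notify") as ws:
        while True:
            raw = await ws.recv()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue  # ignore any non-JSON frames
            if data.get("type") == "new_session":
                session = data.get("data", {})
                print(f"Unanswered question from {session.get('first_name')}: {session.get('message')}")

asyncio.run(listen())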
app/database/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Database connections package
app/database/models.py ADDED
@@ -0,0 +1,165 @@
1
+ from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Float, Text, LargeBinary, JSON
2
+ from sqlalchemy.sql import func
3
+ from sqlalchemy.orm import relationship
4
+ from .postgresql import Base
5
+ import datetime
6
+
7
+ class FAQItem(Base):
8
+ __tablename__ = "faq_item"
9
+
10
+ id = Column(Integer, primary_key=True, index=True)
11
+ question = Column(String, nullable=False)
12
+ answer = Column(String, nullable=False)
13
+ is_active = Column(Boolean, default=True)
14
+ created_at = Column(DateTime, server_default=func.now())
15
+ updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
16
+
17
+ class EmergencyItem(Base):
18
+ __tablename__ = "emergency_item"
19
+
20
+ id = Column(Integer, primary_key=True, index=True)
21
+ name = Column(String, nullable=False)
22
+ phone_number = Column(String, nullable=False)
23
+ description = Column(String, nullable=True)
24
+ address = Column(String, nullable=True)
25
+ location = Column(String, nullable=True) # Will be converted to/from PostGIS POINT type
26
+ priority = Column(Integer, default=0)
27
+ is_active = Column(Boolean, default=True)
28
+ created_at = Column(DateTime, server_default=func.now())
29
+ updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
30
+
31
+ class EventItem(Base):
32
+ __tablename__ = "event_item"
33
+
34
+ id = Column(Integer, primary_key=True, index=True)
35
+ name = Column(String, nullable=False)
36
+ description = Column(Text, nullable=False)
37
+ address = Column(String, nullable=False)
38
+ location = Column(String, nullable=True) # Will be converted to/from PostGIS POINT type
39
+ date_start = Column(DateTime, nullable=False)
40
+ date_end = Column(DateTime, nullable=True)
41
+ price = Column(JSON, nullable=True)
42
+ is_active = Column(Boolean, default=True)
43
+ featured = Column(Boolean, default=False)
44
+ created_at = Column(DateTime, server_default=func.now())
45
+ updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
46
+
47
+ class VectorDatabase(Base):
48
+ __tablename__ = "vector_database"
49
+
50
+ id = Column(Integer, primary_key=True, index=True)
51
+ name = Column(String, nullable=False, unique=True)
52
+ description = Column(String, nullable=True)
53
+ pinecone_index = Column(String, nullable=False)
54
+ api_key = Column(String, nullable=False)
55
+ status = Column(String, default="active")
56
+ created_at = Column(DateTime, server_default=func.now())
57
+ updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
58
+
59
+ # Relationships
60
+ documents = relationship("Document", back_populates="vector_database")
61
+ vector_statuses = relationship("VectorStatus", back_populates="vector_database")
62
+ engine_associations = relationship("EngineVectorDb", back_populates="vector_database")
63
+
64
+ class Document(Base):
65
+ __tablename__ = "document"
66
+
67
+ id = Column(Integer, primary_key=True, index=True)
68
+ name = Column(String, nullable=False)
69
+ file_content = Column(LargeBinary, nullable=True)
70
+ file_type = Column(String, nullable=True)
71
+ size = Column(Integer, nullable=True)
72
+ content_type = Column(String, nullable=True)
73
+ is_embedded = Column(Boolean, default=False)
74
+ file_metadata = Column(JSON, nullable=True)
75
+ vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
76
+ created_at = Column(DateTime, server_default=func.now())
77
+ updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
78
+
79
+ # Relationships
80
+ vector_database = relationship("VectorDatabase", back_populates="documents")
81
+ vector_statuses = relationship("VectorStatus", back_populates="document")
82
+
83
+ class VectorStatus(Base):
84
+ __tablename__ = "vector_status"
85
+
86
+ id = Column(Integer, primary_key=True, index=True)
87
+ document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
88
+ vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
89
+ vector_id = Column(String, nullable=True)
90
+ status = Column(String, default="pending")
91
+ error_message = Column(String, nullable=True)
92
+ embedded_at = Column(DateTime, nullable=True)
93
+
94
+ # Relationships
95
+ document = relationship("Document", back_populates="vector_statuses")
96
+ vector_database = relationship("VectorDatabase", back_populates="vector_statuses")
97
+
98
+ class TelegramBot(Base):
99
+ __tablename__ = "telegram_bot"
100
+
101
+ id = Column(Integer, primary_key=True, index=True)
102
+ name = Column(String, nullable=False)
103
+ username = Column(String, nullable=False, unique=True)
104
+ token = Column(String, nullable=False)
105
+ status = Column(String, default="inactive")
106
+ created_at = Column(DateTime, server_default=func.now())
107
+ updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
108
+
109
+ # Relationships
110
+ bot_engines = relationship("BotEngine", back_populates="bot")
111
+
112
+ class ChatEngine(Base):
113
+ __tablename__ = "chat_engine"
114
+
115
+ id = Column(Integer, primary_key=True, index=True)
116
+ name = Column(String, nullable=False)
117
+ answer_model = Column(String, nullable=False)
118
+ system_prompt = Column(Text, nullable=True)
119
+ empty_response = Column(String, nullable=True)
120
+ similarity_top_k = Column(Integer, default=3)
121
+ vector_distance_threshold = Column(Float, default=0.75)
122
+ grounding_threshold = Column(Float, default=0.2)
123
+ use_public_information = Column(Boolean, default=False)
124
+ status = Column(String, default="active")
125
+ created_at = Column(DateTime, server_default=func.now())
126
+ last_modified = Column(DateTime, server_default=func.now(), onupdate=func.now())
127
+
128
+ # Relationships
129
+ bot_engines = relationship("BotEngine", back_populates="engine")
130
+ engine_vector_dbs = relationship("EngineVectorDb", back_populates="engine")
131
+
132
+ class BotEngine(Base):
133
+ __tablename__ = "bot_engine"
134
+
135
+ id = Column(Integer, primary_key=True, index=True)
136
+ bot_id = Column(Integer, ForeignKey("telegram_bot.id"), nullable=False)
137
+ engine_id = Column(Integer, ForeignKey("chat_engine.id"), nullable=False)
138
+ created_at = Column(DateTime, server_default=func.now())
139
+
140
+ # Relationships
141
+ bot = relationship("TelegramBot", back_populates="bot_engines")
142
+ engine = relationship("ChatEngine", back_populates="bot_engines")
143
+
144
+ class EngineVectorDb(Base):
145
+ __tablename__ = "engine_vector_db"
146
+
147
+ id = Column(Integer, primary_key=True, index=True)
148
+ engine_id = Column(Integer, ForeignKey("chat_engine.id"), nullable=False)
149
+ vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
150
+ priority = Column(Integer, default=0)
151
+
152
+ # Relationships
153
+ engine = relationship("ChatEngine", back_populates="engine_vector_dbs")
154
+ vector_database = relationship("VectorDatabase", back_populates="engine_associations")
155
+
156
+ class ApiKey(Base):
157
+ __tablename__ = "api_key"
158
+
159
+ id = Column(Integer, primary_key=True, index=True)
160
+ key = Column(String, nullable=False, unique=True)
161
+ name = Column(String, nullable=False)
162
+ description = Column(String, nullable=True)
163
+ is_active = Column(Boolean, default=True)
164
+ created_at = Column(DateTime, server_default=func.now())
165
+ last_used = Column(DateTime, nullable=True)
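A minimal bootstrap sketch for the models above. It assumes app.database.postgresql exposes a SQLAlchemy engine alongside Base (only Base is imported in this file, so the engine name is an assumption); create_all then issues CREATE TABLE for every model registered on Base.metadata and skips tables that already exist.

from app.database.postgresql import Base, engine  # `engine` name is assumed
import app.database.models  # importing the module registers all tables on Base.metadata

# Creates faq_item, emergency_item, event_item and the other tables if missing
Base.metadata.create_all(bind=engine)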
app/database/mongodb.py ADDED
@@ -0,0 +1,200 @@
1
+ import os
2
+ from pymongo import MongoClient
3
+ from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
4
+ from dotenv import load_dotenv
5
+ from datetime import datetime, timedelta
6
+ import pytz
7
+ import logging
8
+
9
+ # Configure logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Load environment variables
13
+ load_dotenv()
14
+
15
+ # MongoDB connection string from .env
16
+ MONGODB_URL = os.getenv("MONGODB_URL")
17
+ DB_NAME = os.getenv("DB_NAME", "Telegram")
18
+ COLLECTION_NAME = os.getenv("COLLECTION_NAME", "session_chat")
19
+
20
+ # Set timeout for MongoDB connection
21
+ MONGODB_TIMEOUT = int(os.getenv("MONGODB_TIMEOUT", "5000")) # 5 seconds by default
22
+
23
+ # Create MongoDB connection with timeout
24
+ try:
25
+ client = MongoClient(MONGODB_URL, serverSelectionTimeoutMS=MONGODB_TIMEOUT)
26
+ db = client[DB_NAME]
27
+
28
+ # Collections
29
+ session_collection = db[COLLECTION_NAME]
30
+ logger.info(f"MongoDB connection initialized to {DB_NAME}.{COLLECTION_NAME}")
31
+
32
+ except Exception as e:
33
+ logger.error(f"Failed to initialize MongoDB connection: {e}")
34
+ # Don't raise exception to avoid crash during startup, error handling will be done in functions
35
+
36
+ # Check MongoDB connection
37
+ def check_db_connection():
38
+ """Check MongoDB connection"""
39
+ try:
40
+ # Issue a ping to confirm a successful connection
41
+ client.admin.command('ping')
42
+ logger.info("MongoDB connection is working")
43
+ return True
44
+ except (ConnectionFailure, ServerSelectionTimeoutError) as e:
45
+ logger.error(f"MongoDB connection failed: {e}")
46
+ return False
47
+ except Exception as e:
48
+ logger.error(f"Unknown error when checking MongoDB connection: {e}")
49
+ return False
50
+
51
+ # Timezone for Asia/Ho_Chi_Minh
52
+ asia_tz = pytz.timezone('Asia/Ho_Chi_Minh')
53
+
54
+ def get_local_time():
55
+ """Get current time in Asia/Ho_Chi_Minh timezone"""
56
+ return datetime.now(asia_tz).strftime("%Y-%m-%d %H:%M:%S")
57
+
58
+ def get_local_datetime():
59
+ """Get current datetime object in Asia/Ho_Chi_Minh timezone"""
60
+ return datetime.now(asia_tz)
61
+
62
+ # For backward compatibility
63
+ get_vietnam_time = get_local_time
64
+ get_vietnam_datetime = get_local_datetime
65
+
66
+ # Utility functions
67
+ def save_session(session_id, factor, action, first_name, last_name, message, user_id, username, response=None):
68
+ """Save user session to MongoDB"""
69
+ try:
70
+ session_data = {
71
+ "session_id": session_id,
72
+ "factor": factor,
73
+ "action": action,
74
+ "created_at": get_local_time(),
75
+ "created_at_datetime": get_local_datetime(),
76
+ "first_name": first_name,
77
+ "last_name": last_name,
78
+ "message": message,
79
+ "user_id": user_id,
80
+ "username": username,
81
+ "response": response
82
+ }
83
+ result = session_collection.insert_one(session_data)
84
+ logger.info(f"Session saved with ID: {result.inserted_id}")
85
+ return {
86
+ "acknowledged": result.acknowledged,
87
+ "inserted_id": str(result.inserted_id),
88
+ "session_data": session_data
89
+ }
90
+ except Exception as e:
91
+ logger.error(f"Error saving session: {e}")
92
+ raise
93
+
94
+ def update_session_response(session_id, response):
95
+ """Update a session with response"""
96
+ try:
97
+ result = session_collection.update_one(
98
+ {"session_id": session_id},
99
+ {"$set": {"response": response}}
100
+ )
101
+
102
+ if result.matched_count == 0:
103
+ logger.warning(f"No session found with ID: {session_id}")
104
+ return False
105
+
106
+ logger.info(f"Session {session_id} updated with response")
107
+ return True
108
+ except Exception as e:
109
+ logger.error(f"Error updating session response: {e}")
110
+ raise
111
+
112
+ def get_recent_sessions(user_id, action, n=3):
113
+ """Get n most recent sessions for a specific user and action"""
114
+ try:
115
+ return list(
116
+ session_collection.find(
117
+ {"user_id": user_id, "action": action},
118
+ {"_id": 0, "message": 1, "response": 1}
119
+ ).sort("created_at_datetime", -1).limit(n)
120
+ )
121
+ except Exception as e:
122
+ logger.error(f"Error getting recent sessions: {e}")
123
+ return []
124
+
125
+ def get_user_history(user_id, n=3):
126
+ """Get user history for a specific user"""
127
+ try:
128
+ # Find all messages of this user
129
+ user_messages = list(
130
+ session_collection.find(
131
+ {
132
+ "user_id": user_id,
133
+ "message": {"$exists": True, "$ne": None},
134
+ # Include all user messages regardless of action type
135
+ }
136
+ ).sort("created_at_datetime", -1).limit(n * 2) # Get more to ensure we have enough pairs
137
+ )
138
+
139
+ # Group messages by session_id to find pairs
140
+ session_dict = {}
141
+ for msg in user_messages:
142
+ session_id = msg.get("session_id")
143
+ if session_id not in session_dict:
144
+ session_dict[session_id] = {}
145
+
146
+ if msg.get("factor", "").lower() == "user":
147
+ session_dict[session_id]["question"] = msg.get("message", "")
148
+ session_dict[session_id]["timestamp"] = msg.get("created_at_datetime")
149
+ elif msg.get("factor", "").lower() == "rag":
150
+ session_dict[session_id]["answer"] = msg.get("response", "")
151
+
152
+ # Build history from complete pairs only (with both question and answer)
153
+ history = []
154
+ for session_id, data in session_dict.items():
155
+ if "question" in data and "answer" in data and data.get("answer"):
156
+ history.append({
157
+ "question": data["question"],
158
+ "answer": data["answer"]
159
+ })
160
+
161
+ # Sort by timestamp and limit to n
162
+ history = sorted(history, key=lambda x: x.get("timestamp", 0), reverse=True)[:n]
163
+
164
+ logger.info(f"Retrieved {len(history)} history items for user {user_id}")
165
+ return history
166
+ except Exception as e:
167
+ logger.error(f"Error getting user history: {e}")
168
+ return []
169
+
170
+ # Functions from chatbot.py
171
+ def get_chat_history(user_id, n=5):
172
+ """Get conversation history for a specific user from MongoDB in format suitable for LLM prompt"""
173
+ try:
174
+ history = get_user_history(user_id, n)
175
+
176
+ # Format history for prompt context
177
+ formatted_history = ""
178
+ for item in history:
179
+ formatted_history += f"User: {item['question']}\nAssistant: {item['answer']}\n\n"
180
+
181
+ return formatted_history
182
+ except Exception as e:
183
+ logger.error(f"Error getting chat history for prompt: {e}")
184
+ return ""
185
+
186
+ def get_request_history(user_id, n=3):
187
+ """Get the most recent user requests to use as context for retrieval"""
188
+ try:
189
+ history = get_user_history(user_id, n)
190
+
191
+ # Just extract the questions for context
192
+ requests = []
193
+ for item in history:
194
+ requests.append(item['question'])
195
+
196
+ # Join all recent requests into a single string for context
197
+ return " ".join(requests)
198
+ except Exception as e:
199
+ logger.error(f"Error getting request history: {e}")
200
+ return ""
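
A minimal usage sketch for the session helpers above; the IDs and messages are placeholders, and a live MongoDB connection is assumed. Note that get_user_history pairs up two documents sharing a session_id, one with factor "user" and one with factor "rag":

    from app.database.mongodb import save_session, get_chat_history

    # Record the user turn and the RAG answer as two documents in the same session.
    save_session(session_id="demo-session", factor="user", action="asking_freely",
                 first_name="John", last_name="Doe",
                 message="How do I find emergency contacts?",
                 user_id="12345678", username="johndoe")
    save_session(session_id="demo-session", factor="rag", action="asking_freely",
                 first_name="John", last_name="Doe",
                 message="How do I find emergency contacts?",
                 user_id="12345678", username="johndoe",
                 response="See the Emergency section of the app.")

    # Render the stored question/answer pairs as prompt context.
    print(get_chat_history(user_id="12345678", n=3))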
app/database/pinecone.py ADDED
@@ -0,0 +1,581 @@
+ import os
+ from pinecone import Pinecone
+ from dotenv import load_dotenv
+ import logging
+ from typing import Optional, List, Dict, Any, Union, Tuple
+ import time
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
+ import google.generativeai as genai
+ from app.utils.utils import cache
+ from langchain_core.retrievers import BaseRetriever
+ from langchain.callbacks.manager import Callbacks
+ from langchain_core.documents import Document
+ from langchain_core.pydantic_v1 import Field
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ load_dotenv()
+
+ # Pinecone API key and index name
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
+ PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
+
+ # Pinecone retrieval configuration
+ DEFAULT_LIMIT_K = int(os.getenv("PINECONE_DEFAULT_LIMIT_K", "10"))
+ DEFAULT_TOP_K = int(os.getenv("PINECONE_DEFAULT_TOP_K", "6"))
+ DEFAULT_SIMILARITY_METRIC = os.getenv("PINECONE_DEFAULT_SIMILARITY_METRIC", "cosine")
+ DEFAULT_SIMILARITY_THRESHOLD = float(os.getenv("PINECONE_DEFAULT_SIMILARITY_THRESHOLD", "0.75"))
+ ALLOWED_METRICS = os.getenv("PINECONE_ALLOWED_METRICS", "cosine,dotproduct,euclidean").split(",")
+
+ # Export constants for importing elsewhere
+ __all__ = [
+     'get_pinecone_index',
+     'check_db_connection',
+     'search_vectors',
+     'upsert_vectors',
+     'delete_vectors',
+     'fetch_metadata',
+     'get_chain',
+     'DEFAULT_TOP_K',
+     'DEFAULT_LIMIT_K',
+     'DEFAULT_SIMILARITY_METRIC',
+     'DEFAULT_SIMILARITY_THRESHOLD',
+     'ALLOWED_METRICS',
+     'ThresholdRetriever'
+ ]
+
+ # Configure Google API
+ if GOOGLE_API_KEY:
+     genai.configure(api_key=GOOGLE_API_KEY)
+
+ # Initialize global variables to store instances of Pinecone and index
+ pc = None
+ index = None
+ _retriever_instance = None
+
+ # Check environment variables
+ if not PINECONE_API_KEY:
+     logger.error("PINECONE_API_KEY is not set in environment variables")
+
+ if not PINECONE_INDEX_NAME:
+     logger.error("PINECONE_INDEX_NAME is not set in environment variables")
+
+ # Initialize Pinecone
+ def init_pinecone():
+     """Initialize pinecone connection using new API"""
+     global pc, index
+
+     try:
+         # Only initialize if not already initialized
+         if pc is None:
+             logger.info(f"Initializing Pinecone connection to index {PINECONE_INDEX_NAME}...")
+
+             # Initialize Pinecone client using the new API
+             pc = Pinecone(api_key=PINECONE_API_KEY)
+
+             # Check if index exists
+             index_list = pc.list_indexes()
+
+             if not hasattr(index_list, 'names') or PINECONE_INDEX_NAME not in index_list.names():
+                 logger.error(f"Index {PINECONE_INDEX_NAME} does not exist in Pinecone")
+                 return None
+
+             # Get existing index
+             index = pc.Index(PINECONE_INDEX_NAME)
+             logger.info(f"Pinecone connection established to index {PINECONE_INDEX_NAME}")
+
+         return index
+     except Exception as e:
+         logger.error(f"Error initializing Pinecone: {e}")
+         return None
+
+ # Get Pinecone index singleton
+ def get_pinecone_index():
+     """Get Pinecone index"""
+     global index
+     if index is None:
+         index = init_pinecone()
+     return index
+
+ # Check Pinecone connection
+ def check_db_connection():
+     """Check Pinecone connection"""
+     try:
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             return False
+
+         # Check index information to confirm connection is working
+         stats = pinecone_index.describe_index_stats()
+
+         # Get total vector count from the new result structure
+         total_vectors = stats.get('total_vector_count', 0)
+         if hasattr(stats, 'namespaces'):
+             # If there are namespaces, calculate total vector count from namespaces
+             total_vectors = sum(ns.get('vector_count', 0) for ns in stats.namespaces.values())
+
+         logger.info(f"Pinecone connection is working. Total vectors: {total_vectors}")
+         return True
+     except Exception as e:
+         logger.error(f"Error in Pinecone connection: {e}")
+         return False
+
+ # Convert similarity score based on the metric
+ def convert_score(score: float, metric: str) -> float:
+     """
+     Convert similarity score to a 0-1 scale based on the metric used.
+     For metrics like euclidean distance where lower is better, we invert the score.
+
+     Args:
+         score: The raw similarity score
+         metric: The similarity metric used
+
+     Returns:
+         A normalized score between 0-1 where higher means more similar
+     """
+     if metric.lower() in ["euclidean", "l2"]:
+         # For distance metrics (lower is better), we inverse and normalize
+         # Assuming max reasonable distance is 2.0 for normalized vectors
+         return max(0, 1 - (score / 2.0))
+     else:
+         # For cosine and dot product (higher is better), return as is
+         return score
+
+ # Filter results based on similarity threshold
+ def filter_by_threshold(results, threshold: float, metric: str) -> List[Dict]:
+     """
+     Filter query results based on similarity threshold.
+
+     Args:
+         results: The query results from Pinecone
+         threshold: The similarity threshold (0-1)
+         metric: The similarity metric used
+
+     Returns:
+         Filtered list of matches
+     """
+     filtered_matches = []
+
+     if not hasattr(results, 'matches'):
+         return filtered_matches
+
+     for match in results.matches:
+         # Get the score
+         score = getattr(match, 'score', 0)
+
+         # Convert score based on metric
+         normalized_score = convert_score(score, metric)
+
+         # Filter based on threshold
+         if normalized_score >= threshold:
+             # Add normalized score as an additional attribute
+             match.normalized_score = normalized_score
+             filtered_matches.append(match)
+
+     return filtered_matches
+
+ # Search vectors in Pinecone with advanced options
+ async def search_vectors(
+     query_vector,
+     top_k: int = DEFAULT_TOP_K,
+     limit_k: int = DEFAULT_LIMIT_K,
+     similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
+     similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
+     namespace: str = "",
+     filter: Optional[Dict] = None
+ ) -> Dict:
+     """
+     Search for most similar vectors in Pinecone with advanced filtering options.
+
+     Args:
+         query_vector: The query vector
+         top_k: Number of results to return (after threshold filtering)
+         limit_k: Maximum number of results to retrieve from Pinecone
+         similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
+         similarity_threshold: Threshold for similarity (0-1)
+         namespace: Namespace to search in
+         filter: Filter query
+
+     Returns:
+         Search results with matches filtered by threshold
+     """
+     try:
+         # Validate parameters
+         if similarity_metric not in ALLOWED_METRICS:
+             logger.warning(f"Invalid similarity metric: {similarity_metric}. Using default: {DEFAULT_SIMILARITY_METRIC}")
+             similarity_metric = DEFAULT_SIMILARITY_METRIC
+
+         if limit_k < top_k:
+             logger.warning(f"limit_k ({limit_k}) must be greater than or equal to top_k ({top_k}). Setting limit_k to {top_k}")
+             limit_k = top_k
+
+         # Create cache key from parameters
+         vector_hash = hash(str(query_vector))
+         cache_key = f"pinecone_search:{vector_hash}:{limit_k}:{similarity_metric}:{similarity_threshold}:{namespace}:{filter}"
+
+         # Check cache first
+         cached_result = cache.get(cache_key)
+         if cached_result is not None:
+             logger.info("Returning cached Pinecone search results")
+             return cached_result
+
+         # If not in cache, perform search
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             logger.error("Failed to get Pinecone index for search")
+             return None
+
+         # Query Pinecone with the provided metric and higher limit_k to allow for threshold filtering
+         results = pinecone_index.query(
+             vector=query_vector,
+             top_k=limit_k,  # Retrieve more results than needed to allow for threshold filtering
+             namespace=namespace,
+             filter=filter,
+             include_metadata=True,
+             include_values=False,  # No need to return vector values to save bandwidth
+             metric=similarity_metric  # Specify similarity metric
+         )
+
+         # Filter results by threshold
+         filtered_matches = filter_by_threshold(results, similarity_threshold, similarity_metric)
+
+         # Limit to top_k after filtering
+         filtered_matches = filtered_matches[:top_k]
+
+         # Create a new results object with filtered matches
+         results.matches = filtered_matches
+
+         # Log search result metrics
+         match_count = len(filtered_matches)
+         logger.info(f"Pinecone search returned {match_count} matches after threshold filtering (metric: {similarity_metric}, threshold: {similarity_threshold})")
+
+         # Store result in cache with 5 minute TTL
+         cache.set(cache_key, results, ttl=300)
+
+         return results
+     except Exception as e:
+         logger.error(f"Error searching vectors: {e}")
+         return None
+
+ # Upsert vectors to Pinecone
+ async def upsert_vectors(vectors, namespace=""):
+     """Upsert vectors to Pinecone index"""
+     try:
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             logger.error("Failed to get Pinecone index for upsert")
+             return None
+
+         response = pinecone_index.upsert(
+             vectors=vectors,
+             namespace=namespace
+         )
+
+         # Log upsert metrics
+         upserted_count = response.get('upserted_count', 0)
+         logger.info(f"Upserted {upserted_count} vectors to Pinecone")
+
+         return response
+     except Exception as e:
+         logger.error(f"Error upserting vectors: {e}")
+         return None
+
+ # Delete vectors from Pinecone
+ async def delete_vectors(ids, namespace=""):
+     """Delete vectors from Pinecone index"""
+     try:
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             logger.error("Failed to get Pinecone index for delete")
+             return False
+
+         response = pinecone_index.delete(
+             ids=ids,
+             namespace=namespace
+         )
+
+         logger.info(f"Deleted vectors with IDs {ids} from Pinecone")
+         return True
+     except Exception as e:
+         logger.error(f"Error deleting vectors: {e}")
+         return False
+
+ # Fetch vector metadata from Pinecone
+ async def fetch_metadata(ids, namespace=""):
+     """Fetch metadata for specific vector IDs"""
+     try:
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             logger.error("Failed to get Pinecone index for fetch")
+             return None
+
+         response = pinecone_index.fetch(
+             ids=ids,
+             namespace=namespace
+         )
+
+         return response
+     except Exception as e:
+         logger.error(f"Error fetching vector metadata: {e}")
+         return None
+
+ # Create a custom retriever class for Langchain integration
+ class ThresholdRetriever(BaseRetriever):
+     """
+     Custom retriever that supports threshold-based filtering and multiple similarity metrics.
+     This integrates with the Langchain ecosystem while using our advanced retrieval logic.
+     """
+
+     vectorstore: Any = Field(description="Vector store to use for retrieval")
+     embeddings: Any = Field(description="Embeddings model to use for retrieval")
+     search_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Search kwargs for the vectorstore")
+     top_k: int = Field(default=DEFAULT_TOP_K, description="Number of results to return after filtering")
+     limit_k: int = Field(default=DEFAULT_LIMIT_K, description="Maximum number of results to retrieve from Pinecone")
+     similarity_metric: str = Field(default=DEFAULT_SIMILARITY_METRIC, description="Similarity metric to use")
+     similarity_threshold: float = Field(default=DEFAULT_SIMILARITY_THRESHOLD, description="Threshold for similarity")
+
+     class Config:
+         """Configuration for this pydantic object."""
+         arbitrary_types_allowed = True
+
+     async def search_vectors_sync(
+         self, query_vector,
+         top_k: int = DEFAULT_TOP_K,
+         limit_k: int = DEFAULT_LIMIT_K,
+         similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
+         similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
+         namespace: str = "",
+         filter: Optional[Dict] = None
+     ) -> Dict:
+         """Synchronous wrapper for search_vectors"""
+         import asyncio
+         try:
+             # Get current event loop or create a new one
+             try:
+                 loop = asyncio.get_event_loop()
+             except RuntimeError:
+                 loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(loop)
+
+             # Use event loop to run async function
+             if loop.is_running():
+                 # If we're in an event loop, use asyncio.create_task
+                 task = asyncio.create_task(search_vectors(
+                     query_vector=query_vector,
+                     top_k=top_k,
+                     limit_k=limit_k,
+                     similarity_metric=similarity_metric,
+                     similarity_threshold=similarity_threshold,
+                     namespace=namespace,
+                     filter=filter
+                 ))
+                 return await task
+             else:
+                 # If not in an event loop, just await directly
+                 return await search_vectors(
+                     query_vector=query_vector,
+                     top_k=top_k,
+                     limit_k=limit_k,
+                     similarity_metric=similarity_metric,
+                     similarity_threshold=similarity_threshold,
+                     namespace=namespace,
+                     filter=filter
+                 )
+         except Exception as e:
+             logger.error(f"Error in search_vectors_sync: {e}")
+             return None
+
+     def _get_relevant_documents(
+         self, query: str, *, run_manager: Callbacks = None
+     ) -> List[Document]:
+         """
+         Get documents relevant to the query using threshold-based retrieval.
+
+         Args:
+             query: The query string
+             run_manager: The callbacks manager
+
+         Returns:
+             List of relevant documents
+         """
+         # Generate embedding for query using the embeddings model
+         try:
+             # Use the embeddings model we stored in the class
+             embedding = self.embeddings.embed_query(query)
+         except Exception as e:
+             logger.error(f"Error generating embedding: {e}")
+             # Fallback to creating a new embedding model if needed
+             embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+             embedding = embedding_model.embed_query(query)
+
+         # Perform search with advanced options - avoid asyncio.run()
+         import asyncio
+
+         # Get or create event loop
+         try:
+             loop = asyncio.get_event_loop()
+         except RuntimeError:
+             loop = asyncio.new_event_loop()
+             asyncio.set_event_loop(loop)
+
+         # Run asynchronous search in a safe way
+         if loop.is_running():
+             # We're inside an existing event loop (like in FastAPI)
+             # Use a different approach - convert it to a synchronous call
+             from concurrent.futures import ThreadPoolExecutor
+             import functools
+
+             # Define a wrapper function to run in a thread
+             def run_async_in_thread():
+                 # Create a new event loop for this thread
+                 thread_loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(thread_loop)
+                 # Run the coroutine and return the result
+                 return thread_loop.run_until_complete(search_vectors(
+                     query_vector=embedding,
+                     top_k=self.top_k,
+                     limit_k=self.limit_k,
+                     similarity_metric=self.similarity_metric,
+                     similarity_threshold=self.similarity_threshold,
+                     namespace=getattr(self.vectorstore, "namespace", ""),
+                     filter=self.search_kwargs.get("filter", None)
+                 ))
+
+             # Run the async function in a thread
+             with ThreadPoolExecutor() as executor:
+                 search_result = executor.submit(run_async_in_thread).result()
+         else:
+             # No event loop running, we can use run_until_complete
+             search_result = loop.run_until_complete(search_vectors(
+                 query_vector=embedding,
+                 top_k=self.top_k,
+                 limit_k=self.limit_k,
+                 similarity_metric=self.similarity_metric,
+                 similarity_threshold=self.similarity_threshold,
+                 namespace=getattr(self.vectorstore, "namespace", ""),
+                 filter=self.search_kwargs.get("filter", None)
+             ))
+
+         # Convert to documents
+         documents = []
+         if search_result and hasattr(search_result, 'matches'):
+             for match in search_result.matches:
+                 # Extract metadata
+                 metadata = {}
+                 if hasattr(match, 'metadata'):
+                     metadata = match.metadata
+
+                 # Add score to metadata
+                 score = getattr(match, 'score', 0)
+                 normalized_score = getattr(match, 'normalized_score', score)
+                 metadata['score'] = score
+                 metadata['normalized_score'] = normalized_score
+
+                 # Extract text
+                 text = metadata.get('text', '')
+                 if 'text' in metadata:
+                     del metadata['text']  # Remove from metadata since it's the content
+
+                 # Create Document
+                 doc = Document(
+                     page_content=text,
+                     metadata=metadata
+                 )
+                 documents.append(doc)
+
+         return documents
+
+ # Get the retrieval chain with Pinecone vector store
+ def get_chain(
+     index_name=PINECONE_INDEX_NAME,
+     namespace="Default",
+     top_k=DEFAULT_TOP_K,
+     limit_k=DEFAULT_LIMIT_K,
+     similarity_metric=DEFAULT_SIMILARITY_METRIC,
+     similarity_threshold=DEFAULT_SIMILARITY_THRESHOLD
+ ):
+     """
+     Get the retrieval chain with Pinecone vector store using threshold-based retrieval.
+
+     Args:
+         index_name: Pinecone index name
+         namespace: Pinecone namespace
+         top_k: Number of results to return after filtering
+         limit_k: Maximum number of results to retrieve from Pinecone
+         similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
+         similarity_threshold: Threshold for similarity (0-1)
+
+     Returns:
+         ThresholdRetriever instance
+     """
+     global _retriever_instance
+     try:
+         # If already initialized with same parameters, return cached instance
+         if _retriever_instance is not None:
+             return _retriever_instance
+
+         # Check if chain has been cached
+         cache_key = f"pinecone_retriever:{index_name}:{namespace}:{top_k}:{limit_k}:{similarity_metric}:{similarity_threshold}"
+         cached_retriever = cache.get(cache_key)
+         if cached_retriever is not None:
+             _retriever_instance = cached_retriever
+             logger.info("Retrieved cached Pinecone retriever")
+             return _retriever_instance
+
+         start_time = time.time()
+         logger.info("Initializing new retriever chain with threshold-based filtering")
+
+         # Initialize embeddings model
+         embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+
+         # Get index
+         pinecone_index = get_pinecone_index()
+         if not pinecone_index:
+             logger.error("Failed to get Pinecone index for retriever chain")
+             return None
+
+         # Get statistics for logging
+         try:
+             stats = pinecone_index.describe_index_stats()
+             total_vectors = stats.get('total_vector_count', 0)
+             logger.info(f"Pinecone index stats - Total vectors: {total_vectors}")
+         except Exception as e:
+             logger.error(f"Error getting index stats: {e}")
+
+         # Use Pinecone from langchain_community.vectorstores
+         from langchain_community.vectorstores import Pinecone as LangchainPinecone
+
+         logger.info(f"Creating Pinecone vectorstore with index: {index_name}, namespace: {namespace}")
+         vectorstore = LangchainPinecone.from_existing_index(
+             embedding=embeddings,
+             index_name=index_name,
+             namespace=namespace,
+             text_key="text"
+         )
+
+         # Create threshold-based retriever
+         logger.info(f"Creating ThresholdRetriever with top_k={top_k}, limit_k={limit_k}, " +
+                     f"metric={similarity_metric}, threshold={similarity_threshold}")
+
+         # Create ThresholdRetriever with both vectorstore and embeddings
+         _retriever_instance = ThresholdRetriever(
+             vectorstore=vectorstore,
+             embeddings=embeddings,  # Pass embeddings separately
+             top_k=top_k,
+             limit_k=limit_k,
+             similarity_metric=similarity_metric,
+             similarity_threshold=similarity_threshold
+         )
+
+         logger.info(f"Pinecone retriever initialized in {time.time() - start_time:.2f} seconds")
+
+         # Cache the retriever with longer TTL (1 hour) since it rarely changes
+         cache.set(cache_key, _retriever_instance, ttl=3600)
+
+         return _retriever_instance
+     except Exception as e:
+         logger.error(f"Error creating retrieval chain: {e}")
+         return None
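
For orientation, a sketch of how the retriever above might be consumed; it assumes valid PINECONE_* and GOOGLE_API_KEY settings, and an index whose vectors store the document body under the "text" metadata key:

    from app.database.pinecone import get_chain

    retriever = get_chain(namespace="Default", top_k=4, similarity_threshold=0.8)
    if retriever is not None:
        docs = retriever.get_relevant_documents("emergency contact numbers")
        for doc in docs:
            print(doc.metadata.get("normalized_score"), doc.page_content[:80])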
app/database/postgresql.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ from sqlalchemy import create_engine, text
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import sessionmaker
+ from sqlalchemy.exc import SQLAlchemyError, OperationalError
+ from dotenv import load_dotenv
+ import logging
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ load_dotenv()
+
+ # Get DB connection mode from environment
+ DB_CONNECTION_MODE = os.getenv("DB_CONNECTION_MODE", "aiven")
+
+ # Set connection string based on mode
+ if DB_CONNECTION_MODE == "aiven":
+     DATABASE_URL = os.getenv("AIVEN_DB_URL")
+ else:
+     # Default or other connection modes can be added here
+     DATABASE_URL = os.getenv("AIVEN_DB_URL")
+
+ if not DATABASE_URL:
+     logger.error("No database URL configured. Please set AIVEN_DB_URL environment variable.")
+     DATABASE_URL = "postgresql://localhost/test"  # Fallback so startup does not crash
+
+ # Create SQLAlchemy engine
+ try:
+     engine = create_engine(
+         DATABASE_URL,
+         pool_pre_ping=True,
+         pool_recycle=300,  # Recycle connections every 5 minutes
+         pool_size=10,  # Increase pool size from 5 to 10
+         max_overflow=20,  # Increase maximum connections from 10 to 20
+         connect_args={
+             "connect_timeout": 3,  # Reduce timeout from 5 to 3 seconds
+             "keepalives": 1,  # Enable keepalives
+             "keepalives_idle": 30,  # Idle time before sending a keepalive
+             "keepalives_interval": 10,  # Interval between keepalive packets
+             "keepalives_count": 5  # Number of retries before closing the connection
+         },
+         # Additional performance options
+         isolation_level="READ COMMITTED",  # Use the lower READ COMMITTED isolation level
+         echo=False,  # Disable SQL echo to reduce logging overhead
+         echo_pool=False  # Disable pool echo to reduce logging overhead
+     )
+     logger.info("PostgreSQL engine initialized")
+ except Exception as e:
+     logger.error(f"Failed to initialize PostgreSQL engine: {e}")
+     # Don't raise an exception to avoid crashing at startup; errors are handled in the functions below
+
+ # Create session factory with optimized settings
+ SessionLocal = sessionmaker(
+     autocommit=False,
+     autoflush=False,
+     bind=engine,
+     expire_on_commit=False  # Avoid re-querying the DB after commit
+ )
+
+ # Base class for declarative models - use sqlalchemy.orm for SQLAlchemy 2.0 compatibility
+ from sqlalchemy.orm import declarative_base
+ Base = declarative_base()
+
+ # Check PostgreSQL connection
+ def check_db_connection():
+     """Check PostgreSQL connection"""
+     try:
+         # Run a simple query to verify the connection
+         with engine.connect() as connection:
+             connection.execute(text("SELECT 1"))
+         logger.info("PostgreSQL connection is working")
+         return True
+     except OperationalError as e:
+         logger.error(f"PostgreSQL connection failed: {e}")
+         return False
+     except Exception as e:
+         logger.error(f"Unknown error when checking PostgreSQL connection: {e}")
+         return False
+
+ # Dependency to get DB session
+ def get_db():
+     """Get database session dependency for FastAPI endpoints"""
+     db = SessionLocal()
+     try:
+         yield db
+     except SQLAlchemyError as e:
+         logger.error(f"Database session error: {e}")
+         db.rollback()
+         raise
+     finally:
+         db.close()
+
+ # Create database tables if they don't exist
+ def create_tables():
+     """Create database tables"""
+     try:
+         Base.metadata.create_all(bind=engine)
+         logger.info("Database tables created or already exist")
+         return True
+     except SQLAlchemyError as e:
+         logger.error(f"Failed to create database tables: {e}")
+         return False
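
get_db is written as a FastAPI dependency; a minimal sketch of how a route could consume it (the route path and the models import path are illustrative):

    from fastapi import APIRouter, Depends
    from sqlalchemy.orm import Session

    from app.database.postgresql import get_db
    from app.database.models import TelegramBot  # assumed location of the ORM models

    router = APIRouter()

    @router.get("/bots")  # illustrative endpoint
    def list_bots(db: Session = Depends(get_db)):
        # The session is opened per request and closed by get_db's finally block.
        bots = db.query(TelegramBot).all()
        return [{"id": b.id, "name": b.name, "status": b.status} for b in bots]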
app/models/__init__.py ADDED
@@ -0,0 +1 @@
+ # Pydantic models package
app/models/mongodb_models.py ADDED
@@ -0,0 +1,55 @@
+ from pydantic import BaseModel, Field, ConfigDict
+ from typing import Optional, List, Dict, Any
+ from datetime import datetime
+ import uuid
+
+ class SessionBase(BaseModel):
+     """Base model for session data"""
+     session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+     factor: str
+     action: str
+     first_name: str
+     last_name: Optional[str] = None
+     message: Optional[str] = None
+     user_id: str
+     username: Optional[str] = None
+
+ class SessionCreate(SessionBase):
+     """Model for creating new session"""
+     response: Optional[str] = None
+
+ class SessionResponse(SessionBase):
+     """Response model for session data"""
+     created_at: str
+     response: Optional[str] = None
+
+     model_config = ConfigDict(
+         json_schema_extra={
+             "example": {
+                 "session_id": "123e4567-e89b-12d3-a456-426614174000",
+                 "factor": "user",
+                 "action": "asking_freely",
+                 "created_at": "2023-06-01 14:30:45",
+                 "first_name": "John",
+                 "last_name": "Doe",
+                 "message": "How can I find emergency contacts?",
+                 "user_id": "12345678",
+                 "username": "johndoe",
+                 "response": "You can find emergency contacts in the Emergency section..."
+             }
+         }
+     )
+
+ class HistoryRequest(BaseModel):
+     """Request model for history"""
+     user_id: str
+     n: int = 3
+
+ class QuestionAnswer(BaseModel):
+     """Model for question-answer pair"""
+     question: str
+     answer: str
+
+ class HistoryResponse(BaseModel):
+     """Response model for history"""
+     history: List[QuestionAnswer]
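
A quick sketch of the intended round trip for these schemas, using the pydantic v2 API that matches the pinned pydantic==2.4.2:

    session = SessionCreate(
        factor="user", action="asking_freely",
        first_name="John", user_id="12345678",
        message="How do I find emergency contacts?",
    )
    print(session.session_id)    # UUID generated by the default_factory
    print(session.model_dump())  # dict payload, e.g. for save_session(**payload)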
app/models/rag_models.py ADDED
@@ -0,0 +1,68 @@
+ from pydantic import BaseModel, Field
+ from typing import Optional, List, Dict, Any
+
+ class ChatRequest(BaseModel):
+     """Request model for chat endpoint"""
+     user_id: str = Field(..., description="User ID from Telegram")
+     question: str = Field(..., description="User's question")
+     include_history: bool = Field(True, description="Whether to include user history in prompt")
+     use_rag: bool = Field(True, description="Whether to use RAG")
+
+     # Advanced retrieval parameters
+     similarity_top_k: int = Field(6, description="Number of top similar documents to return (after filtering)")
+     limit_k: int = Field(10, description="Maximum number of documents to retrieve from vector store")
+     similarity_metric: str = Field("cosine", description="Similarity metric to use (cosine, dotproduct, euclidean)")
+     similarity_threshold: float = Field(0.75, description="Threshold for vector similarity (0-1)")
+
+     # User information
+     session_id: Optional[str] = Field(None, description="Session ID for tracking conversations")
+     first_name: Optional[str] = Field(None, description="User's first name")
+     last_name: Optional[str] = Field(None, description="User's last name")
+     username: Optional[str] = Field(None, description="User's username")
+
+ class SourceDocument(BaseModel):
+     """Model for source documents"""
+     text: str = Field(..., description="Text content of the document")
+     source: Optional[str] = Field(None, description="Source of the document")
+     score: Optional[float] = Field(None, description="Raw similarity score of the document")
+     normalized_score: Optional[float] = Field(None, description="Normalized similarity score (0-1)")
+     metadata: Optional[Dict[str, Any]] = Field(None, description="Metadata of the document")
+
+ class ChatResponse(BaseModel):
+     """Response model for chat endpoint"""
+     answer: str = Field(..., description="Generated answer")
+     processing_time: float = Field(..., description="Processing time in seconds")
+
+ class ChatResponseInternal(BaseModel):
+     """Internal model for chat response with sources - used only for logging"""
+     answer: str
+     sources: Optional[List[SourceDocument]] = Field(None, description="Source documents used for generating answer")
+     processing_time: Optional[float] = None
+
+ class EmbeddingRequest(BaseModel):
+     """Request model for embedding endpoint"""
+     text: str = Field(..., description="Text to generate embedding for")
+
+ class EmbeddingResponse(BaseModel):
+     """Response model for embedding endpoint"""
+     embedding: List[float] = Field(..., description="Generated embedding")
+     text: str = Field(..., description="Text that was embedded")
+     model: str = Field(..., description="Model used for embedding")
+
+ class HealthResponse(BaseModel):
+     """Response model for health endpoint"""
+     status: str
+     services: Dict[str, bool]
+     timestamp: str
+
+ class UserMessageModel(BaseModel):
+     """Model for user messages sent to the RAG API"""
+     user_id: str = Field(..., description="User ID from the client application")
+     session_id: str = Field(..., description="Session ID for tracking the conversation")
+     message: str = Field(..., description="User's message/question")
+
+     # Advanced retrieval parameters (optional)
+     similarity_top_k: Optional[int] = Field(None, description="Number of top similar documents to return (after filtering)")
+     limit_k: Optional[int] = Field(None, description="Maximum number of documents to retrieve from vector store")
+     similarity_metric: Optional[str] = Field(None, description="Similarity metric to use (cosine, dotproduct, euclidean)")
+     similarity_threshold: Optional[float] = Field(None, description="Threshold for vector similarity (0-1)")
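
These schemas validate incoming request bodies; a short sketch showing the defaults being applied:

    payload = {"user_id": "12345678", "question": "What are the emergency numbers?"}
    req = ChatRequest(**payload)
    assert req.similarity_top_k == 6          # default from the Field definition
    assert req.similarity_threshold == 0.75   # matches the Pinecone default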
app/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # Utility functions package
app/utils/debug_utils.py ADDED
@@ -0,0 +1,207 @@
+ import os
+ import sys
+ import logging
+ import traceback
+ import json
+ import time
+ from datetime import datetime
+ import platform
+
+ # Try to import psutil, provide fallback if not available
+ try:
+     import psutil
+     PSUTIL_AVAILABLE = True
+ except ImportError:
+     PSUTIL_AVAILABLE = False
+     logging.warning("psutil module not available. System monitoring features will be limited.")
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ class DebugInfo:
+     """Class containing debug information"""
+
+     @staticmethod
+     def get_system_info():
+         """Get system information"""
+         try:
+             info = {
+                 "os": platform.system(),
+                 "os_version": platform.version(),
+                 "python_version": platform.python_version(),
+                 "cpu_count": os.cpu_count(),
+                 "timestamp": datetime.now().isoformat()
+             }
+
+             # Add information from psutil if available
+             if PSUTIL_AVAILABLE:
+                 info.update({
+                     "total_memory": round(psutil.virtual_memory().total / (1024 * 1024 * 1024), 2),  # GB
+                     "available_memory": round(psutil.virtual_memory().available / (1024 * 1024 * 1024), 2),  # GB
+                     "cpu_usage": psutil.cpu_percent(interval=0.1),
+                     "memory_usage": psutil.virtual_memory().percent,
+                     "disk_usage": psutil.disk_usage('/').percent,
+                 })
+             else:
+                 info.update({
+                     "total_memory": "psutil not available",
+                     "available_memory": "psutil not available",
+                     "cpu_usage": "psutil not available",
+                     "memory_usage": "psutil not available",
+                     "disk_usage": "psutil not available",
+                 })
+
+             return info
+         except Exception as e:
+             logger.error(f"Error getting system info: {e}")
+             return {"error": str(e)}
+
+     @staticmethod
+     def get_env_info():
+         """Get environment variable information (masking sensitive information)"""
+         try:
+             # List of environment variables to mask values
+             sensitive_vars = [
+                 "API_KEY", "SECRET", "PASSWORD", "TOKEN", "AUTH", "MONGODB_URL",
+                 "AIVEN_DB_URL", "PINECONE_API_KEY", "GOOGLE_API_KEY"
+             ]
+
+             env_vars = {}
+             for key, value in os.environ.items():
+                 # Check if environment variable contains sensitive words
+                 is_sensitive = any(s in key.upper() for s in sensitive_vars)
+
+                 if is_sensitive and value:
+                     # Mask value displaying only the first 4 characters
+                     masked_value = value[:4] + "****" if len(value) > 4 else "****"
+                     env_vars[key] = masked_value
+                 else:
+                     env_vars[key] = value
+
+             return env_vars
+         except Exception as e:
+             logger.error(f"Error getting environment info: {e}")
+             return {"error": str(e)}
+
+     @staticmethod
+     def get_database_status():
+         """Get database connection status"""
+         try:
+             from app.database.postgresql import check_db_connection as check_postgresql
+             from app.database.mongodb import check_db_connection as check_mongodb
+             from app.database.pinecone import check_db_connection as check_pinecone
+
+             return {
+                 "postgresql": check_postgresql(),
+                 "mongodb": check_mongodb(),
+                 "pinecone": check_pinecone(),
+                 "timestamp": datetime.now().isoformat()
+             }
+         except Exception as e:
+             logger.error(f"Error getting database status: {e}")
+             return {"error": str(e)}
+
+ class PerformanceMonitor:
+     """Performance monitoring class"""
+
+     def __init__(self):
+         self.start_time = time.time()
+         self.checkpoints = []
+
+     def checkpoint(self, name):
+         """Mark a checkpoint and record the time"""
+         current_time = time.time()
+         elapsed = current_time - self.start_time
+         self.checkpoints.append({
+             "name": name,
+             "time": current_time,
+             "elapsed": elapsed
+         })
+         logger.debug(f"Checkpoint '{name}' at {elapsed:.4f}s")
+         return elapsed
+
+     def get_report(self):
+         """Generate performance report"""
+         if not self.checkpoints:
+             return {"error": "No checkpoints recorded"}
+
+         total_time = time.time() - self.start_time
+
+         # Calculate time between checkpoints
+         intervals = []
+         prev_time = self.start_time
+
+         for checkpoint in self.checkpoints:
+             interval = checkpoint["time"] - prev_time
+             intervals.append({
+                 "name": checkpoint["name"],
+                 "interval": interval,
+                 "elapsed": checkpoint["elapsed"]
+             })
+             prev_time = checkpoint["time"]
+
+         return {
+             "total_time": total_time,
+             "checkpoint_count": len(self.checkpoints),
+             "intervals": intervals
+         }
+
+ class ErrorTracker:
+     """Class to track and record errors"""
+
+     def __init__(self, max_errors=100):
+         self.errors = []
+         self.max_errors = max_errors
+
+     def track_error(self, error, context=None):
+         """Record error information"""
+         error_info = {
+             "error_type": type(error).__name__,
+             "error_message": str(error),
+             "traceback": traceback.format_exc(),
+             "timestamp": datetime.now().isoformat(),
+             "context": context or {}
+         }
+
+         # Add to error list
+         self.errors.append(error_info)
+
+         # Limit the number of stored errors
+         if len(self.errors) > self.max_errors:
+             self.errors.pop(0)  # Remove oldest error
+
+         return error_info
+
+     def get_errors(self, limit=None):
+         """Get list of recorded errors"""
+         if limit is None or limit >= len(self.errors):
+             return self.errors
+         return self.errors[-limit:]  # Return most recent errors
+
+ # Initialize global objects
+ error_tracker = ErrorTracker()
+ performance_monitor = PerformanceMonitor()
+
+ def debug_view(request=None):
+     """Create a full debug report"""
+     debug_data = {
+         "system_info": DebugInfo.get_system_info(),
+         "database_status": DebugInfo.get_database_status(),
+         "performance": performance_monitor.get_report(),
+         "recent_errors": error_tracker.get_errors(limit=10),
+         "timestamp": datetime.now().isoformat()
+     }
+
+     # Add request information if available
+     if request:
+         debug_data["request"] = {
+             "method": request.method,
+             "url": str(request.url),
+             "headers": dict(request.headers),
+             "client": {
+                 "host": request.client.host if request.client else "unknown",
+                 "port": request.client.port if request.client else "unknown"
+             }
+         }
+
+     return debug_data
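
A usage sketch for the monitoring helpers above; the checkpoint names and the error context are arbitrary:

    from app.utils.debug_utils import performance_monitor, error_tracker

    performance_monitor.checkpoint("embedding_done")
    performance_monitor.checkpoint("retrieval_done")
    print(performance_monitor.get_report()["total_time"])

    try:
        1 / 0  # simulate a failure somewhere in the pipeline
    except Exception as e:
        error_tracker.track_error(e, context={"stage": "demo"})
    print(len(error_tracker.get_errors(limit=5)))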
app/utils/middleware.py ADDED
@@ -0,0 +1,109 @@
+ from fastapi import Request, status
+ from fastapi.responses import JSONResponse
+ from starlette.middleware.base import BaseHTTPMiddleware
+ import logging
+ import time
+ import traceback
+ import uuid
+ from .utils import get_local_time
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ class RequestLoggingMiddleware(BaseHTTPMiddleware):
+     """Middleware to log requests and responses"""
+
+     async def dispatch(self, request: Request, call_next):
+         request_id = str(uuid.uuid4())
+         request.state.request_id = request_id
+
+         # Log request information
+         client_host = request.client.host if request.client else "unknown"
+         logger.info(f"Request [{request_id}]: {request.method} {request.url.path} from {client_host}")
+
+         # Measure processing time
+         start_time = time.time()
+
+         try:
+             # Process request
+             response = await call_next(request)
+
+             # Calculate processing time
+             process_time = time.time() - start_time
+             logger.info(f"Response [{request_id}]: {response.status_code} processed in {process_time:.4f}s")
+
+             # Add headers
+             response.headers["X-Request-ID"] = request_id
+             response.headers["X-Process-Time"] = str(process_time)
+
+             return response
+
+         except Exception as e:
+             # Log error
+             process_time = time.time() - start_time
+             logger.error(f"Error [{request_id}] after {process_time:.4f}s: {str(e)}")
+             logger.error(traceback.format_exc())
+
+             # Return error response
+             return JSONResponse(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 content={
+                     "detail": "Internal server error",
+                     "request_id": request_id,
+                     "timestamp": get_local_time()
+                 }
+             )
+
+ class ErrorHandlingMiddleware(BaseHTTPMiddleware):
+     """Middleware to handle uncaught exceptions in the application"""
+
+     async def dispatch(self, request: Request, call_next):
+         try:
+             return await call_next(request)
+         except Exception as e:
+             # Get request_id if available
+             request_id = getattr(request.state, "request_id", str(uuid.uuid4()))
+
+             # Log error
+             logger.error(f"Uncaught exception [{request_id}]: {str(e)}")
+             logger.error(traceback.format_exc())
+
+             # Return error response
+             return JSONResponse(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 content={
+                     "detail": "Internal server error",
+                     "request_id": request_id,
+                     "timestamp": get_local_time()
+                 }
+             )
+
+ class DatabaseCheckMiddleware(BaseHTTPMiddleware):
+     """Middleware to check database connections before each request"""
+
+     async def dispatch(self, request: Request, call_next):
+         # Skip paths that don't need database checks
+         skip_paths = ["/", "/health", "/docs", "/redoc", "/openapi.json"]
+         if request.url.path in skip_paths:
+             return await call_next(request)
+
+         # Check database connections
+         try:
+             # TODO: Add checks for MongoDB and Pinecone if needed
+             # PostgreSQL check is already done in route handler with get_db() method
+
+             # Process request normally
+             return await call_next(request)
+
+         except Exception as e:
+             # Log error
+             logger.error(f"Database connection check failed: {str(e)}")
+
+             # Return error response
+             return JSONResponse(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 content={
+                     "detail": "Database connection failed",
+                     "timestamp": get_local_time()
+                 }
+             )
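
A sketch of wiring these middlewares into the application; Starlette runs the most recently added middleware outermost, so ErrorHandlingMiddleware is added last to act as the catch-all:

    from fastapi import FastAPI

    from app.utils.middleware import (
        DatabaseCheckMiddleware,
        ErrorHandlingMiddleware,
        RequestLoggingMiddleware,
    )

    app = FastAPI()
    app.add_middleware(DatabaseCheckMiddleware)   # innermost: connection checks
    app.add_middleware(RequestLoggingMiddleware)  # request IDs and timing
    app.add_middleware(ErrorHandlingMiddleware)   # outermost: catch-all errors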
app/utils/utils.py ADDED
@@ -0,0 +1,126 @@
+ import logging
+ import time
+ import uuid
+ import threading
+ from functools import wraps
+ from datetime import datetime, timedelta
+ import pytz
+ from typing import Callable, Any, Dict, Optional
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ )
+ logger = logging.getLogger(__name__)
+
+ # Asia/Ho_Chi_Minh timezone
+ asia_tz = pytz.timezone('Asia/Ho_Chi_Minh')
+
+ def generate_uuid():
+     """Generate a unique identifier"""
+     return str(uuid.uuid4())
+
+ def get_current_time():
+     """Get current time in ISO format"""
+     return datetime.now().isoformat()
+
+ def get_local_time():
+     """Get current time in Asia/Ho_Chi_Minh timezone"""
+     return datetime.now(asia_tz).strftime("%Y-%m-%d %H:%M:%S")
+
+ def get_local_datetime():
+     """Get current datetime object in Asia/Ho_Chi_Minh timezone"""
+     return datetime.now(asia_tz)
+
+ # For backward compatibility
+ get_vietnam_time = get_local_time
+ get_vietnam_datetime = get_local_datetime
+
+ def timer_decorator(func: Callable) -> Callable:
+     """
+     Decorator to time function execution and log results.
+     """
+     @wraps(func)
+     async def wrapper(*args, **kwargs):
+         start_time = time.time()
+         try:
+             result = await func(*args, **kwargs)
+             elapsed_time = time.time() - start_time
+             logger.info(f"Function {func.__name__} executed in {elapsed_time:.4f} seconds")
+             return result
+         except Exception as e:
+             elapsed_time = time.time() - start_time
+             logger.error(f"Function {func.__name__} failed after {elapsed_time:.4f} seconds: {e}")
+             raise
+     return wrapper
+
+ def sanitize_input(text):
+     """Sanitize input text"""
+     if not text:
+         return ""
+     # Remove potential dangerous characters or patterns
+     return text.strip()
+
+ def truncate_text(text, max_length=100):
+     """
+     Truncate text to given max length and add ellipsis.
+     """
+     if not text or len(text) <= max_length:
+         return text
+     return text[:max_length] + "..."
+
+ # Simple in-memory cache implementation (replaces Redis dependency)
+ class SimpleCache:
+     def __init__(self):
+         self._cache = {}
+         self._expiry = {}
+
+     def get(self, key: str) -> Optional[Any]:
+         """Get value from cache if it exists and hasn't expired"""
+         if key in self._cache:
+             # Check if the key has expired
+             if key in self._expiry and self._expiry[key] > datetime.now():
+                 return self._cache[key]
+             else:
+                 # Clean up expired keys
+                 if key in self._cache:
+                     del self._cache[key]
+                 if key in self._expiry:
+                     del self._expiry[key]
+         return None
+
+     def set(self, key: str, value: Any, ttl: int = 300) -> None:
+         """Set a value in the cache with TTL in seconds"""
+         self._cache[key] = value
+         # Set expiry time
+         self._expiry[key] = datetime.now() + timedelta(seconds=ttl)
+
+     def delete(self, key: str) -> None:
+         """Delete a key from the cache"""
+         if key in self._cache:
+             del self._cache[key]
+         if key in self._expiry:
+             del self._expiry[key]
+
+     def clear(self) -> None:
+         """Clear the entire cache"""
+         self._cache.clear()
+         self._expiry.clear()
+
+ # Initialize cache
+ cache = SimpleCache()
+
+ def get_host_url(request) -> str:
+     """
+     Get the host URL from a request object.
+     """
+     host = request.headers.get("host", "localhost")
+     scheme = request.headers.get("x-forwarded-proto", "http")
+     return f"{scheme}://{host}"
+
+ def format_time(timestamp):
+     """
+     Format a timestamp into a human-readable string.
+     """
+     return timestamp.strftime("%Y-%m-%d %H:%M:%S")
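
A short sketch of the cache and timing helpers above; note that the wrapper returned by timer_decorator is async, so it suits coroutine functions:

    import asyncio
    from app.utils.utils import cache, timer_decorator

    cache.set("greeting", "hello", ttl=60)  # expires after 60 seconds
    print(cache.get("greeting"))            # "hello" (None once expired)

    @timer_decorator
    async def slow_op():
        await asyncio.sleep(0.1)
        return 42

    print(asyncio.run(slow_op()))  # logs the elapsed time, then prints 42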
docker-compose.yml ADDED
@@ -0,0 +1,21 @@
+ version: '3'
+
+ services:
+   backend:
+     build:
+       context: .
+       dockerfile: Dockerfile
+     ports:
+       - "7860:7860"
+     env_file:
+       - .env
+     restart: unless-stopped
+     healthcheck:
+       test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
+       interval: 30s
+       timeout: 10s
+       retries: 3
+       start_period: 40s
+     volumes:
+       - ./app:/app/app
+     command: uvicorn app:app --host 0.0.0.0 --port 7860 --reload
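
With a .env file in place next to the compose file, the stack can be built and started with: docker compose up --build. The healthcheck polls the /health endpoint on port 7860, and the ./app bind mount together with --reload gives live code reloading during development.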
pytest.ini ADDED
@@ -0,0 +1,12 @@
+ [pytest]
+ # Ignore warnings about the anyio module and internal operational warnings
+ filterwarnings =
+     ignore::pytest.PytestAssertRewriteWarning:.*anyio
+     ignore:.*general_plain_validator_function.* is deprecated.*:DeprecationWarning
+     ignore:.*with_info_plain_validator_function.*:DeprecationWarning
+
+ # Other basic configuration
+ testpaths = tests
+ python_files = test_*.py
+ python_classes = Test*
+ python_functions = test_*
requirements.txt ADDED
@@ -0,0 +1,43 @@
+ # FastAPI
+ fastapi==0.103.1
+ uvicorn[standard]==0.23.2
+ pydantic==2.4.2
+ python-dotenv==1.0.0
+ websockets==11.0.3
+
+ # MongoDB
+ pymongo==4.6.1
+ dnspython==2.4.2
+
+ # PostgreSQL
+ sqlalchemy==2.0.20
+ pydantic-settings==2.0.3
+ psycopg2-binary==2.9.7
+
+ # Pinecone & RAG
+ pinecone-client==3.0.0
+ langchain==0.1.4
+ langchain-core==0.1.19
+ langchain-community==0.0.14
+ langchain-google-genai==0.0.5
+ langchain-pinecone==0.0.1
+ faiss-cpu==1.7.4
+ google-generativeai==0.3.1
+
+ # Extras
+ pytz==2023.3
+ python-multipart==0.0.6
+ httpx==0.25.1
+ requests==2.31.0
+ beautifulsoup4==4.12.2
+ redis==5.0.1
+
+ # Testing
+ prometheus-client==0.17.1
+ pytest==7.4.0
+ pytest-cov==4.1.0
+ watchfiles==0.21.0
+
+ # Core dependencies
+ starlette==0.27.0
+ psutil==5.9.6