first commit
- .dockerignore +46 -0
- .env.example +26 -0
- .gitattributes +29 -35
- .gitignore +79 -0
- Dockerfile +31 -0
- README.md +355 -4
- api_documentation.txt +318 -0
- app.py +197 -0
- app/__init__.py +23 -0
- app/api/__init__.py +1 -0
- app/api/mongodb_routes.py +276 -0
- app/api/postgresql_routes.py +626 -0
- app/api/rag_routes.py +538 -0
- app/api/websocket_routes.py +306 -0
- app/database/__init__.py +1 -0
- app/database/models.py +165 -0
- app/database/mongodb.py +200 -0
- app/database/pinecone.py +581 -0
- app/database/postgresql.py +104 -0
- app/models/__init__.py +1 -0
- app/models/mongodb_models.py +55 -0
- app/models/rag_models.py +68 -0
- app/utils/__init__.py +1 -0
- app/utils/debug_utils.py +207 -0
- app/utils/middleware.py +109 -0
- app/utils/utils.py +126 -0
- docker-compose.yml +21 -0
- pytest.ini +12 -0
- requirements.txt +43 -0
.dockerignore
ADDED
@@ -0,0 +1,46 @@
+# Git
+.git
+.gitignore
+.gitattributes
+
+# Environment files
+.env
+.env.*
+!.env.example
+
+# Python cache files
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+.pytest_cache/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Logs
+*.log
+
+# Tests
+tests/
+
+# Docker related
+Dockerfile
+docker-compose.yml
+.dockerignore
+
+# Other files
+.vscode/
+.idea/
+*.swp
+*.swo
+.DS_Store
+.coverage
+htmlcov/
+.mypy_cache/
+.tox/
+.nox/
+instance/
+.webassets-cache
+main.py
.env.example
ADDED
@@ -0,0 +1,26 @@
+# PostgreSQL Configuration
+DB_CONNECTION_MODE=aiven
+AIVEN_DB_URL=postgresql://username:password@host:port/dbname?sslmode=require
+
+# MongoDB Configuration
+MONGODB_URL=mongodb+srv://username:[email protected]/?retryWrites=true&w=majority
+DB_NAME=Telegram
+COLLECTION_NAME=session_chat
+
+# Pinecone configuration
+PINECONE_API_KEY=your-pinecone-api-key
+PINECONE_INDEX_NAME=your-pinecone-index-name
+PINECONE_ENVIRONMENT=gcp-starter
+
+# Google Gemini API key
+GOOGLE_API_KEY=your-google-api-key
+
+# WebSocket configuration
+WEBSOCKET_SERVER=localhost
+WEBSOCKET_PORT=7860
+WEBSOCKET_PATH=/notify
+
+# Application settings
+ENVIRONMENT=production
+DEBUG=false
+PORT=7860
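For orientation, a minimal sketch of how a Python service might consume this file with python-dotenv; the variable names come from the file above, while the fallback defaults and the error message are illustrative assumptions:

```python
import os
from dotenv import load_dotenv  # pip install python-dotenv

load_dotenv()  # reads .env from the working directory

AIVEN_DB_URL = os.getenv("AIVEN_DB_URL")  # required, no sensible default
WEBSOCKET_PORT = int(os.getenv("WEBSOCKET_PORT", "7860"))
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1", "t")

if AIVEN_DB_URL is None:
    raise RuntimeError("AIVEN_DB_URL is not set")
```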
.gitattributes
CHANGED
@@ -1,35 +1,29 @@
-
-
-
-
-*.
-*.
-*.
-*.
-*.
-*.
-*.
-*.
-
-
-*.
-*.
-*.
-*.
-*.
-*.
-
-
-
-
-
-
-
-
-
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Auto detect text files and perform LF normalization
+* text=auto eol=lf
+
+# Documents
+*.md text
+*.txt text
+*.ini text
+*.yaml text
+*.yml text
+*.json text
+*.py text
+*.env.example text
+
+# Binary files
+*.png binary
+*.jpg binary
+*.jpeg binary
+*.gif binary
+*.ico binary
+*.db binary
+
+# Git related files
+.gitignore text
+.gitattributes text
+
+# Docker related files
+Dockerfile text
+docker-compose.yml text
+.dockerignore text
.gitignore
ADDED
@@ -0,0 +1,79 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+.pytest_cache/
+htmlcov/
+.coverage
+.coverage.*
+.cache/
+coverage.xml
+*.cover
+.mypy_cache/
+
+# Environment
+.env
+.venv
+env/
+venv/
+ENV/
+
+# VSCode
+.vscode/
+*.code-workspace
+.history/
+
+# PyCharm
+.idea/
+*.iml
+*.iws
+*.ipr
+*.iws
+out/
+.idea_modules/
+
+# Logs and databases
+*.log
+*.sql
+*.sqlite
+*.db
+
+# Tests
+tests/
+
+Admin_bot/
+
+# Hugging Face Spaces
+.gitattributes
+
+# OS specific
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+Icon?
+ehthumbs.db
+Thumbs.db
+
+# Project specific
+*.log
+.env
+main.py
Dockerfile
ADDED
@@ -0,0 +1,31 @@
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Install the required system packages
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    curl \
+    software-properties-common \
+    git \
+    gcc \
+    python3-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy the requirements file first to take advantage of Docker's layer cache
+COPY requirements.txt .
+
+# Install the Python packages
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Ensure langchain-core is installed
+RUN pip install --no-cache-dir langchain-core==0.1.19
+
+# Copy the whole codebase into the container
+COPY . .
+
+# Expose the port the application runs on
+EXPOSE 7860
+
+# Run the application with uvicorn
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,10 +1,361 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: PIX Project Backend
+emoji: 🤖
+colorFrom: blue
+colorTo: green
 sdk: docker
+sdk_version: "3.0.0"
+app_file: app.py
 pinned: false
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# PIX Project Backend
+
+[](https://fastapi.tiangolo.com/)
+[](https://www.python.org/)
+[](https://huggingface.co/spaces)
+
+Backend API for PIX Project with MongoDB, PostgreSQL and RAG integration. This project provides a comprehensive backend solution for managing FAQ items, emergency contacts, events, and a RAG-based question answering system.
+
+## Features
+
+- **MongoDB Integration**: Store user sessions and conversation history
+- **PostgreSQL Integration**: Manage FAQ items, emergency contacts, and events
+- **Pinecone Vector Database**: Store and retrieve vector embeddings for RAG
+- **RAG Question Answering**: Answer questions using relevant information from the vector database
+- **WebSocket Notifications**: Real-time notifications for Admin Bot
+- **API Documentation**: Automatic OpenAPI documentation via Swagger
+- **Docker Support**: Easy deployment using Docker
+- **Auto-Debugging**: Built-in debugging, error tracking, and performance monitoring
+
+## API Endpoints
+
+### MongoDB Endpoints
+
+- `POST /mongodb/session`: Create a new session record
+- `PUT /mongodb/session/{session_id}/response`: Update a session with a response
+- `GET /mongodb/history`: Get user conversation history
+- `GET /mongodb/health`: Check MongoDB connection health
+
+### PostgreSQL Endpoints
+
+- `GET /postgres/health`: Check PostgreSQL connection health
+- `GET /postgres/faq`: Get FAQ items
+- `POST /postgres/faq`: Create a new FAQ item
+- `GET /postgres/faq/{faq_id}`: Get a specific FAQ item
+- `PUT /postgres/faq/{faq_id}`: Update a specific FAQ item
+- `DELETE /postgres/faq/{faq_id}`: Delete a specific FAQ item
+- `GET /postgres/emergency`: Get emergency contact items
+- `POST /postgres/emergency`: Create a new emergency contact item
+- `GET /postgres/emergency/{emergency_id}`: Get a specific emergency contact
+- `GET /postgres/events`: Get event items
+
+### RAG Endpoints
+
+- `POST /rag/chat`: Get answer for a question using RAG
+- `POST /rag/embedding`: Generate embedding for text
+- `GET /rag/health`: Check RAG services health
+
+### WebSocket Endpoints
+
+- `WebSocket /notify`: Receive real-time notifications for new sessions
+
+### Debug Endpoints (Available in Debug Mode Only)
+
+- `GET /debug/config`: Get configuration information
+- `GET /debug/system`: Get system information (CPU, memory, disk usage)
+- `GET /debug/database`: Check all database connections
+- `GET /debug/errors`: View recent error logs
+- `GET /debug/performance`: Get performance metrics
+- `GET /debug/full`: Get comprehensive debug information
+
+## WebSocket API
+
+### Notifications for New Sessions
+
+The backend provides a WebSocket endpoint for receiving notifications about new sessions that match specific criteria.
+
+#### WebSocket Endpoint Configuration
+
+The WebSocket endpoint is configured using environment variables:
+
+```
+# WebSocket configuration
+WEBSOCKET_SERVER=localhost
+WEBSOCKET_PORT=7860
+WEBSOCKET_PATH=/notify
+```
+
+The full WebSocket URL will be:
+```
+ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}
+```
+
+For example: `ws://localhost:7860/notify`
+
+#### Notification Criteria
+
+A notification is sent when:
+1. A new session is created with `factor` set to "RAG"
+2. The message content starts with "I don't know"
+
+#### Notification Format
+
+```json
+{
+  "type": "new_session",
+  "timestamp": "2025-04-15 22:30:45",
+  "data": {
+    "session_id": "123e4567-e89b-12d3-a456-426614174000",
+    "factor": "rag",
+    "action": "asking_freely",
+    "created_at": "2025-04-15 22:30:45",
+    "first_name": "John",
+    "last_name": "Doe",
+    "message": "I don't know how to find emergency contacts",
+    "user_id": "12345678",
+    "username": "johndoe"
+  }
+}
+```
+
+#### Usage Example
+
+Admin Bot should establish a WebSocket connection to this endpoint using the configured URL:
+
+```python
+import websocket
+import json
+import os
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Get WebSocket configuration from environment variables
+WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
+WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
+WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")
+
+# Create full URL
+ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
+
+def on_message(ws, message):
+    data = json.loads(message)
+    print(f"Received notification: {data}")
+    # Forward to Telegram Admin
+
+def on_error(ws, error):
+    print(f"Error: {error}")
+
+def on_close(ws, close_status_code, close_msg):
+    print("Connection closed")
+
+def on_open(ws):
+    print("Connection opened")
+    # Send keepalive message periodically
+    ws.send("keepalive")
+
+# Connect to WebSocket
+ws = websocket.WebSocketApp(
+    ws_url,
+    on_open=on_open,
+    on_message=on_message,
+    on_error=on_error,
+    on_close=on_close
+)
+ws.run_forever()
+```
+
+When a notification is received, Admin Bot should forward the content to the Telegram Admin.
+
+## Environment Variables
+
+Create a `.env` file with the following variables:
+
+```
+# PostgreSQL Configuration
+DB_CONNECTION_MODE=aiven
+AIVEN_DB_URL=postgresql://username:password@host:port/dbname?sslmode=require
+
+# MongoDB Configuration
+MONGODB_URL=mongodb+srv://username:[email protected]/?retryWrites=true&w=majority
+DB_NAME=Telegram
+COLLECTION_NAME=session_chat
+
+# Pinecone configuration
+PINECONE_API_KEY=your-pinecone-api-key
+PINECONE_INDEX_NAME=your-pinecone-index-name
+PINECONE_ENVIRONMENT=gcp-starter
+
+# Google Gemini API key
+GOOGLE_API_KEY=your-google-api-key
+
+# WebSocket configuration
+WEBSOCKET_SERVER=localhost
+WEBSOCKET_PORT=7860
+WEBSOCKET_PATH=/notify
+
+# Application settings
+ENVIRONMENT=production
+DEBUG=false
+PORT=7860
+```
+
+## Installation and Setup
+
+### Local Development
+
+1. Clone the repository:
+```bash
+git clone https://github.com/ManTT-Data/PixAgent.git
+cd PixAgent
+```
+
+2. Create a virtual environment and install dependencies:
+```bash
+python -m venv venv
+source venv/bin/activate  # On Windows: venv\Scripts\activate
+pip install -r requirements.txt
+```
+
+3. Create a `.env` file with your configuration (see above)
+
+4. Run the application:
+```bash
+uvicorn app:app --reload --port 7860
+```
+
+5. Open your browser and navigate to [http://localhost:7860/docs](http://localhost:7860/docs) to see the API documentation
+
+### Docker Deployment
+
+1. Build the Docker image:
+```bash
+docker build -t pix-project-backend .
+```
+
+2. Run the Docker container:
+```bash
+docker run -p 7860:7860 --env-file .env pix-project-backend
+```
+
+## Deployment to HuggingFace Spaces
+
+1. Create a new Space on HuggingFace (Dockerfile type)
+2. Link your GitHub repository or push directly to the HuggingFace repo
+3. Add your environment variables in the Space settings
+4. The deployment will use `app.py` as the entry point, which is the standard for HuggingFace Spaces
+
+### Important Notes for HuggingFace Deployment
+
+- The application uses `app.py` with the FastAPI instance named `app` to avoid the "Error loading ASGI app. Attribute 'app' not found in module 'app'" error
+- Make sure all environment variables are set in the Space settings
+- The Dockerfile is configured to expose port 7860, which is the default port for HuggingFace Spaces
+
+## Project Structure
+
+```
+.
+├── app                     # Main application package
+│   ├── api                 # API endpoints
+│   │   ├── mongodb_routes.py
+│   │   ├── postgresql_routes.py
+│   │   ├── rag_routes.py
+│   │   └── websocket_routes.py
+│   ├── database            # Database connections
+│   │   ├── mongodb.py
+│   │   ├── pinecone.py
+│   │   └── postgresql.py
+│   ├── models              # Pydantic models
+│   │   ├── mongodb_models.py
+│   │   ├── postgresql_models.py
+│   │   └── rag_models.py
+│   └── utils               # Utility functions
+│       ├── debug_utils.py
+│       └── middleware.py
+├── tests                   # Test directory
+│   └── test_api_endpoints.py
+├── .dockerignore           # Docker ignore file
+├── .env.example            # Example environment file
+├── .gitattributes          # Git attributes
+├── .gitignore              # Git ignore file
+├── app.py                  # Application entry point
+├── docker-compose.yml      # Docker compose configuration
+├── Dockerfile              # Docker configuration
+├── pytest.ini              # Pytest configuration
+├── README.md               # Project documentation
+├── requirements.txt        # Project dependencies
+└── api_documentation.txt   # API documentation for frontend engineers
+```
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
+
+# Advanced Retrieval System
+
+This project now features an enhanced vector retrieval system that improves the quality and relevance of information retrieved from Pinecone using threshold-based filtering and multiple similarity metrics.
+
+## Features
+
+### 1. Threshold-Based Retrieval
+
+The system implements a threshold-based approach to vector retrieval, which:
+- Retrieves a larger candidate set from the vector database
+- Applies a similarity threshold to filter out less relevant results
+- Returns only the most relevant documents that exceed the threshold
+
+### 2. Multiple Similarity Metrics
+
+The system supports multiple similarity metrics:
+- **Cosine Similarity** (default): Measures the cosine of the angle between vectors
+- **Dot Product**: Calculates the dot product between vectors
+- **Euclidean Distance**: Measures the straight-line distance between vectors
+
+Each metric has different characteristics and may perform better for different types of data and queries.
+
+### 3. Score Normalization
+
+For metrics like Euclidean distance where lower values indicate higher similarity, the system automatically normalizes scores to a 0-1 scale where higher values always indicate higher similarity. This makes it easier to compare results across different metrics.
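To make the normalization concrete, a minimal sketch follows; the `1 / (1 + distance)` mapping is a common convention and an assumption here, since this commit view does not show the project's exact formula:

```python
def normalize_score(score: float, metric: str) -> float:
    """Map a raw similarity score to [0, 1], where higher means more similar."""
    if metric == "euclidean":
        # A distance of 0 maps to 1.0; larger distances approach 0.
        return 1.0 / (1.0 + score)
    # Cosine and dot-product scores already rank higher-is-better.
    return score

print(normalize_score(0.0, "euclidean"))  # 1.0
print(normalize_score(3.0, "euclidean"))  # 0.25
```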
+
+## Configuration
+
+The retrieval system can be configured through environment variables:
+
+```
+# Pinecone retrieval configuration
+PINECONE_DEFAULT_LIMIT_K=10                  # Maximum number of candidates to retrieve
+PINECONE_DEFAULT_TOP_K=6                     # Number of results to return after filtering
+PINECONE_DEFAULT_SIMILARITY_METRIC=cosine    # Default similarity metric
+PINECONE_DEFAULT_SIMILARITY_THRESHOLD=0.75   # Similarity threshold (0-1)
+PINECONE_ALLOWED_METRICS=cosine,dotproduct,euclidean   # Available metrics
+```
+
+## API Usage
+
+You can customize the retrieval parameters when making API requests:
+
+```json
+{
+  "user_id": "user123",
+  "question": "What are the best restaurants in Da Nang?",
+  "similarity_top_k": 5,
+  "limit_k": 15,
+  "similarity_metric": "cosine",
+  "similarity_threshold": 0.8
+}
+```
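A minimal sketch of sending this payload from Python, assuming a local server on port 7860 and that it targets the `POST /rag/chat` endpoint listed earlier; the response schema lives in `app/models/rag_models.py` and is not reproduced here:

```python
import requests

payload = {
    "user_id": "user123",
    "question": "What are the best restaurants in Da Nang?",
    "similarity_top_k": 5,
    "limit_k": 15,
    "similarity_metric": "cosine",
    "similarity_threshold": 0.8,
}

resp = requests.post("http://localhost:7860/rag/chat", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())
```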
+
+## Benefits
+
+1. **Quality Improvement**: Retrieves only the most relevant documents above a certain quality threshold
+2. **Flexibility**: Different similarity metrics can be used for different types of queries
+3. **Efficiency**: Avoids processing irrelevant documents, improving response time
+4. **Configurability**: All parameters can be adjusted via environment variables or at request time
+
+## Implementation Details
+
+The system is implemented as a custom retriever class `ThresholdRetriever` that integrates with LangChain's retrieval infrastructure while providing enhanced functionality.
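The `ThresholdRetriever` class itself lives in `app/database/pinecone.py` and is not expanded in this commit view. As an illustration only, a minimal sketch of the threshold-then-top-k idea, assuming a Pinecone index handle with the standard `query(vector=..., top_k=..., include_metadata=...)` API rather than the LangChain wrapper:

```python
def retrieve_with_threshold(index, vector, limit_k=10, top_k=6, threshold=0.75):
    # 1. Pull a larger candidate pool (limit_k) from the vector database.
    result = index.query(vector=vector, top_k=limit_k, include_metadata=True)
    matches = result["matches"]  # dict-style access; newer clients also expose result.matches
    # 2. Drop candidates whose similarity does not clear the threshold.
    kept = [m for m in matches if m["score"] >= threshold]
    # 3. Hand back at most top_k survivors, best first.
    kept.sort(key=lambda m: m["score"], reverse=True)
    return kept[:top_k]
```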
api_documentation.txt
ADDED
@@ -0,0 +1,318 @@
+# Frontend Integration Guide for PixAgent API
+
+This guide provides instructions for integrating with the optimized PostgreSQL-based API endpoints for Event, FAQ, and Emergency data.
+
+## API Endpoints
+
+### Events
+
+| Endpoint | Method | Description |
+|----------|--------|-------------|
+| /postgres/events/ | GET | Fetch all events (with optional filtering) |
+| /postgres/events/{event_id} | GET | Fetch a specific event by ID |
+| /postgres/events/featured | GET | Fetch featured events |
+| /postgres/events/ | POST | Create a new event |
+| /postgres/events/{event_id} | PUT | Update an existing event |
+| /postgres/events/{event_id} | DELETE | Delete an event |
+
+### FAQs
+
+| Endpoint | Method | Description |
+|----------|--------|-------------|
+| /postgres/faqs/ | GET | Fetch all FAQs |
+| /postgres/faqs/{faq_id} | GET | Fetch a specific FAQ by ID |
+| /postgres/faqs/ | POST | Create a new FAQ |
+| /postgres/faqs/{faq_id} | PUT | Update an existing FAQ |
+| /postgres/faqs/{faq_id} | DELETE | Delete a FAQ |
+
+### Emergency Contacts
+
+| Endpoint | Method | Description |
+|----------|--------|-------------|
+| /postgres/emergencies/ | GET | Fetch all emergency contacts |
+| /postgres/emergencies/{emergency_id} | GET | Fetch a specific emergency contact by ID |
+| /postgres/emergencies/ | POST | Create a new emergency contact |
+| /postgres/emergencies/{emergency_id} | PUT | Update an existing emergency contact |
+| /postgres/emergencies/{emergency_id} | DELETE | Delete an emergency contact |
+
+## Response Models
+
+### Event Response Model
+
+interface EventResponse {
+  id: number;
+  name: string;
+  description: string;
+  date_start: string; // ISO format date
+  date_end: string; // ISO format date
+  location: string;
+  image_url: string;
+  price: {
+    currency: string;
+    amount: string;
+  };
+  featured: boolean;
+  is_active: boolean;
+  created_at: string; // ISO format date
+  updated_at: string; // ISO format date
+}
+
+### FAQ Response Model
+
+interface FaqResponse {
+  id: number;
+  question: string;
+  answer: string;
+  is_active: boolean;
+  created_at: string; // ISO format date
+  updated_at: string; // ISO format date
+}
+
+### Emergency Response Model
+
+interface EmergencyResponse {
+  id: number;
+  name: string;
+  phone_number: string;
+  description: string;
+  address: string;
+  priority: number;
+  is_active: boolean;
+  created_at: string; // ISO format date
+  updated_at: string; // ISO format date
+}
+
+## Example Usage (React)
+
+### Fetching Events
+
+import { useState, useEffect } from 'react';
+import axios from 'axios';
+
+const API_BASE_URL = 'http://localhost:8000';
+
+function EventList() {
+  const [events, setEvents] = useState([]);
+  const [loading, setLoading] = useState(true);
+  const [error, setError] = useState(null);
+
+  useEffect(() => {
+    const fetchEvents = async () => {
+      try {
+        setLoading(true);
+        const response = await axios.get(`${API_BASE_URL}/postgres/events/`);
+        setEvents(response.data);
+        setLoading(false);
+      } catch (err) {
+        setError('Failed to fetch events');
+        setLoading(false);
+        console.error('Error fetching events:', err);
+      }
+    };
+
+    fetchEvents();
+  }, []);
+
+  if (loading) return <p>Loading events...</p>;
+  if (error) return <p>{error}</p>;
+
+  return (
+    <div>
+      <h1>Events</h1>
+      <div className="event-list">
+        {events.map(event => (
+          <div key={event.id} className="event-card">
+            <h2>{event.name}</h2>
+            <p>{event.description}</p>
+            <p>
+              <strong>When:</strong> {new Date(event.date_start).toLocaleDateString()} - {new Date(event.date_end).toLocaleDateString()}
+            </p>
+            <p><strong>Where:</strong> {event.location}</p>
+            <p><strong>Price:</strong> {event.price.amount} {event.price.currency}</p>
+            {event.featured && <span className="featured-badge">Featured</span>}
+          </div>
+        ))}
+      </div>
+    </div>
+  );
+}
+
+### Creating an Event
+
+import { useState } from 'react';
+import axios from 'axios';
+
+function CreateEvent() {
+  const [eventData, setEventData] = useState({
+    name: '',
+    description: '',
+    date_start: '',
+    date_end: '',
+    location: '',
+    image_url: '',
+    price: {
+      currency: 'USD',
+      amount: '0'
+    },
+    featured: false,
+    is_active: true
+  });
+  const [loading, setLoading] = useState(false);
+  const [error, setError] = useState(null);
+  const [success, setSuccess] = useState(false);
+
+  const handleChange = (e) => {
+    const { name, value, type, checked } = e.target;
+
+    if (name === 'price_amount') {
+      setEventData(prev => ({
+        ...prev,
+        price: {
+          ...prev.price,
+          amount: value
+        }
+      }));
+    } else if (name === 'price_currency') {
+      setEventData(prev => ({
+        ...prev,
+        price: {
+          ...prev.price,
+          currency: value
+        }
+      }));
+    } else {
+      setEventData(prev => ({
+        ...prev,
+        [name]: type === 'checkbox' ? checked : value
+      }));
+    }
+  };
+
+  const handleSubmit = async (e) => {
+    e.preventDefault();
+    try {
+      setLoading(true);
+      setError(null);
+      setSuccess(false);
+
+      const response = await axios.post(`${API_BASE_URL}/postgres/events/`, eventData);
+      setSuccess(true);
+      setEventData({
+        name: '',
+        description: '',
+        date_start: '',
+        date_end: '',
+        location: '',
+        image_url: '',
+        price: {
+          currency: 'USD',
+          amount: '0'
+        },
+        featured: false,
+        is_active: true
+      });
+      setLoading(false);
+    } catch (err) {
+      setError('Failed to create event');
+      setLoading(false);
+      console.error('Error creating event:', err);
+    }
+  };
+
+  return (
+    <div>
+      <h1>Create New Event</h1>
+      {success && <div className="success-message">Event created successfully!</div>}
+      {error && <div className="error-message">{error}</div>}
+      <form onSubmit={handleSubmit}>
+        {/* Form fields would go here */}
+        <button type="submit" disabled={loading}>
+          {loading ? 'Creating...' : 'Create Event'}
+        </button>
+      </form>
+    </div>
+  );
+}
+
+## Performance Optimizations
+
+The API now includes several performance optimizations:
+
+### Caching
+
+The server implements caching for read operations, which significantly improves response times for repeated requests. The average cache improvement is over 70%.
+
+Frontend considerations:
+- No need to implement client-side caching for data that doesn't change frequently
+- For real-time data, consider adding a refresh button in the UI
+- If data might be updated by other users, consider adding a polling mechanism or websocket for updates
+
+### Error Handling
+
+The API returns standardized error responses. Example:
+
+async function fetchData(url) {
+  try {
+    const response = await fetch(url);
+    if (!response.ok) {
+      const errorData = await response.json();
+      throw new Error(errorData.detail || 'An error occurred');
+    }
+    return await response.json();
+  } catch (error) {
+    console.error('API request failed:', error);
+    // Handle error in UI
+    return null;
+  }
+}
+
+### Price Field Handling
+
+The price field of events is a JSON object with currency and amount properties. When creating or updating events, ensure this is properly formatted:
+
+// Correct format for price field
+const eventData = {
+  // other fields...
+  price: {
+    currency: 'USD',
+    amount: '10.99'
+  }
+};
+
+// When displaying price
+function formatPrice(price) {
+  if (!price) return 'Free';
+  if (typeof price === 'string') {
+    try {
+      price = JSON.parse(price);
+    } catch {
+      return price;
+    }
+  }
+  return `${price.amount} ${price.currency}`;
+}
+
+## CORS Configuration
+
+The API has CORS enabled for frontend applications. If you're experiencing CORS issues, ensure your frontend domain is allowed in the server configuration.
+
+For local development, the following origins are typically allowed:
+- http://localhost:3000
+- http://localhost:5000
+- http://localhost:8080
+
+## Status Codes
+
+| Status Code | Description |
+|-------------|-------------|
+| 200 | Success - The request was successful |
+| 201 | Created - A new resource was successfully created |
+| 400 | Bad Request - The request could not be understood or was missing required parameters |
+| 404 | Not Found - Resource not found |
+| 422 | Validation Error - Request data failed validation |
+| 500 | Internal Server Error - An error occurred on the server |
+
+## Questions?
+
+For further inquiries about the API, please contact the development team.
app.py
ADDED
@@ -0,0 +1,197 @@
+from fastapi import FastAPI, Depends, Request, HTTPException, status
+from fastapi.middleware.cors import CORSMiddleware
+from contextlib import asynccontextmanager
+import uvicorn
+import os
+import sys
+import logging
+from dotenv import load_dotenv
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(sys.stdout),
+    ]
+)
+logger = logging.getLogger(__name__)
+
+# Load environment variables
+load_dotenv()
+DEBUG = os.getenv("DEBUG", "False").lower() in ("true", "1", "t")
+
+# Check the required environment variables
+required_env_vars = [
+    "AIVEN_DB_URL",
+    "MONGODB_URL",
+    "PINECONE_API_KEY",
+    "PINECONE_INDEX_NAME",
+    "GOOGLE_API_KEY"
+]
+
+missing_vars = [var for var in required_env_vars if not os.getenv(var)]
+if missing_vars:
+    logger.error(f"Missing required environment variables: {', '.join(missing_vars)}")
+    if not DEBUG:  # Only exit when not in debug mode
+        sys.exit(1)
+
+# Database health checks
+def check_database_connections():
+    """Check the database connections at startup"""
+    from app.database.postgresql import check_db_connection as check_postgresql
+    from app.database.mongodb import check_db_connection as check_mongodb
+    from app.database.pinecone import check_db_connection as check_pinecone
+
+    db_status = {
+        "postgresql": check_postgresql(),
+        "mongodb": check_mongodb(),
+        "pinecone": check_pinecone()
+    }
+
+    all_ok = all(db_status.values())
+    if not all_ok:
+        failed_dbs = [name for name, status in db_status.items() if not status]
+        logger.error(f"Failed to connect to databases: {', '.join(failed_dbs)}")
+        if not DEBUG:  # Only exit when not in debug mode
+            sys.exit(1)
+
+    return db_status
+
+# Set up a lifespan handler to check database connections at startup
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup: check the database connections
+    logger.info("Starting application...")
+    db_status = check_database_connections()
+    if all(db_status.values()):
+        logger.info("All database connections are working")
+
+    # Create database tables (if they do not exist yet)
+    if DEBUG:  # Only create tables in debug mode
+        from app.database.postgresql import create_tables
+        if create_tables():
+            logger.info("Database tables created or already exist")
+
+    yield
+
+    # Shutdown
+    logger.info("Shutting down application...")
+
+# Import routers
+try:
+    from app.api.mongodb_routes import router as mongodb_router
+    from app.api.postgresql_routes import router as postgresql_router
+    from app.api.rag_routes import router as rag_router
+    from app.api.websocket_routes import router as websocket_router
+
+    # Import middlewares
+    from app.utils.middleware import RequestLoggingMiddleware, ErrorHandlingMiddleware, DatabaseCheckMiddleware
+
+    # Import debug utilities
+    from app.utils.debug_utils import debug_view, DebugInfo, error_tracker, performance_monitor
+
+except ImportError as e:
+    logger.error(f"Error importing routes or middlewares: {e}")
+    raise
+
+# Create FastAPI app
+app = FastAPI(
+    title="PIX Project Backend API",
+    description="Backend API for PIX Project with MongoDB, PostgreSQL and RAG integration",
+    version="1.0.0",
+    docs_url="/docs",
+    redoc_url="/redoc",
+    debug=DEBUG,
+    lifespan=lifespan,
+)
+
+# Configure CORS
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Add middlewares
+app.add_middleware(ErrorHandlingMiddleware)
+app.add_middleware(RequestLoggingMiddleware)
+if not DEBUG:  # Only add the database check middleware in production
+    app.add_middleware(DatabaseCheckMiddleware)
+
+# Include routers
+app.include_router(mongodb_router)
+app.include_router(postgresql_router)
+app.include_router(rag_router)
+app.include_router(websocket_router)
+
+# Root endpoint
+@app.get("/")
+def read_root():
+    return {
+        "message": "Welcome to PIX Project Backend API",
+        "documentation": "/docs",
+    }
+
+# Health check endpoint
+@app.get("/health")
+def health_check():
+    # Check the database connections
+    db_status = check_database_connections()
+    all_db_ok = all(db_status.values())
+
+    return {
+        "status": "healthy" if all_db_ok else "degraded",
+        "version": "1.0.0",
+        "environment": os.environ.get("ENVIRONMENT", "production"),
+        "databases": db_status
+    }
+
+# Debug endpoints (only available in debug mode)
+if DEBUG:
+    @app.get("/debug/config")
+    def debug_config():
+        """Show configuration information (debug mode only)"""
+        config = {
+            "environment": os.environ.get("ENVIRONMENT", "production"),
+            "debug": DEBUG,
+            "db_connection_mode": os.environ.get("DB_CONNECTION_MODE", "aiven"),
+            "databases": {
+                "postgresql": os.environ.get("AIVEN_DB_URL", "").split("@")[1].split("/")[0] if "@" in os.environ.get("AIVEN_DB_URL", "") else "N/A",
+                "mongodb": os.environ.get("MONGODB_URL", "").split("@")[1].split("/?")[0] if "@" in os.environ.get("MONGODB_URL", "") else "N/A",
+                "pinecone": os.environ.get("PINECONE_INDEX_NAME", "N/A"),
+            }
+        }
+        return config
+
+    @app.get("/debug/system")
+    def debug_system():
+        """Show system information (debug mode only)"""
+        return DebugInfo.get_system_info()
+
+    @app.get("/debug/database")
+    def debug_database():
+        """Show database status (debug mode only)"""
+        return DebugInfo.get_database_status()
+
+    @app.get("/debug/errors")
+    def debug_errors(limit: int = 10):
+        """Show recent errors (debug mode only)"""
+        return error_tracker.get_errors(limit=limit)
+
+    @app.get("/debug/performance")
+    def debug_performance():
+        """Show performance information (debug mode only)"""
+        return performance_monitor.get_report()
+
+    @app.get("/debug/full")
+    def debug_full_report(request: Request):
+        """Show a full debug report (debug mode only)"""
+        return debug_view(request)
+
+# Run the app with uvicorn when executed directly
+if __name__ == "__main__":
+    port = int(os.environ.get("PORT", 8000))
+    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=DEBUG)
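A minimal sketch of probing the health endpoint defined above, assuming the app is running locally on port 7860; the response fields come from `health_check()`:

```python
import requests

health = requests.get("http://localhost:7860/health", timeout=10).json()
print(health["status"])     # "healthy" or "degraded"
print(health["databases"])  # per-database boolean status
```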
app/__init__.py
ADDED
@@ -0,0 +1,23 @@
+# PIX Project Backend
+# Version: 1.0.0
+
+__version__ = "1.0.0"
+
+# Import app from app.py so that tests can find it
+import sys
+import os
+
+# Add the project root directory to sys.path
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+try:
+    from app.py import app
+except ImportError:
+    # Try another way if the direct import does not work
+    import importlib.util
+    spec = importlib.util.spec_from_file_location("app_module",
+        os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        "app.py"))
+    app_module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(app_module)
+    app = app_module.app
app/api/__init__.py
ADDED
@@ -0,0 +1 @@
+# API routes package
|
app/api/mongodb_routes.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from fastapi import APIRouter, HTTPException, Depends, Query, status, Response
+from typing import List, Optional, Dict
+from pymongo.errors import PyMongoError
+import logging
+from datetime import datetime
+import traceback
+import asyncio
+
+from app.database.mongodb import (
+    save_session,
+    get_user_history,
+    update_session_response,
+    check_db_connection,
+    session_collection
+)
+from app.models.mongodb_models import (
+    SessionCreate,
+    SessionResponse,
+    HistoryRequest,
+    HistoryResponse,
+    QuestionAnswer
+)
+from app.api.websocket_routes import send_notification
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+# Create router
+router = APIRouter(
+    prefix="/mongodb",
+    tags=["MongoDB"],
+)
+
+@router.post("/session", response_model=SessionResponse, status_code=status.HTTP_201_CREATED)
+async def create_session(session: SessionCreate, response: Response):
+    """
+    Create a new session record in MongoDB.
+
+    - **session_id**: Unique identifier for the session (auto-generated if not provided)
+    - **factor**: Factor type (user, rag, etc.)
+    - **action**: Action type (start, events, faq, emergency, help, asking_freely, etc.)
+    - **first_name**: User's first name
+    - **last_name**: User's last name (optional)
+    - **message**: User's message (optional)
+    - **user_id**: User's ID from Telegram
+    - **username**: User's username (optional)
+    - **response**: Response from RAG (optional)
+    """
+    try:
+        # Check the MongoDB connection
+        if not check_db_connection():
+            logger.error("MongoDB connection failed when trying to create session")
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="MongoDB connection failed"
+            )
+
+        # Create new session in MongoDB
+        result = save_session(
+            session_id=session.session_id,
+            factor=session.factor,
+            action=session.action,
+            first_name=session.first_name,
+            last_name=session.last_name,
+            message=session.message,
+            user_id=session.user_id,
+            username=session.username,
+            response=session.response
+        )
+
+        # Prepare the response object
+        session_response = SessionResponse(
+            **session.model_dump(),
+            created_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        )
+
+        # Check whether the session needs a notification (response starts with "I don't know")
+        if session.response and session.response.strip().lower().startswith("i don't know"):
+            # Send a notification over WebSocket
+            try:
+                notification_data = {
+                    "session_id": session.session_id,
+                    "factor": session.factor,
+                    "action": session.action,
+                    "message": session.message,
+                    "user_id": session.user_id,
+                    "username": session.username,
+                    "first_name": session.first_name,
+                    "last_name": session.last_name,
+                    "response": session.response,
+                    "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                }
+
+                # Schedule the notification with asyncio.create_task so the main flow is not blocked
+                asyncio.create_task(send_notification(notification_data))
+                logger.info(f"Notification queued for session {session.session_id} - response starts with 'I don't know'")
+            except Exception as e:
+                logger.error(f"Error queueing notification: {e}")
+                # Do not stop the main flow when sending the notification fails
+
+        # Return response
+        return session_response
+    except PyMongoError as e:
+        logger.error(f"MongoDB error creating session: {e}")
+        logger.error(traceback.format_exc())
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"MongoDB error: {str(e)}"
+        )
+    except Exception as e:
+        logger.error(f"Unexpected error creating session: {e}")
+        logger.error(traceback.format_exc())
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to create session: {str(e)}"
+        )
+
+@router.put("/session/{session_id}/response", status_code=status.HTTP_200_OK)
+async def update_session_with_response(session_id: str, response_text: str):
+    """
+    Update a session with the response.
+
+    - **session_id**: ID of the session to update
+    - **response_text**: Response to add to the session
+    """
+    try:
+        # Check the MongoDB connection
+        if not check_db_connection():
+            logger.error("MongoDB connection failed when trying to update session response")
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="MongoDB connection failed"
+            )
+
+        # Update session in MongoDB
+        result = update_session_response(session_id, response_text)
+
+        if not result:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail=f"Session with ID {session_id} not found"
+            )
+
+        return {"status": "success", "message": "Response added to session"}
+    except PyMongoError as e:
+        logger.error(f"MongoDB error updating session response: {e}")
+        logger.error(traceback.format_exc())
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"MongoDB error: {str(e)}"
+        )
+    except HTTPException:
+        # Re-throw HTTP exceptions
+        raise
+    except Exception as e:
+        logger.error(f"Unexpected error updating session response: {e}")
+        logger.error(traceback.format_exc())
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to update session: {str(e)}"
+        )
+
+@router.get("/history", response_model=HistoryResponse)
+async def get_history(user_id: str, n: int = Query(3, ge=1, le=10)):
+    """
+    Get user history for a specific user.
+
+    - **user_id**: User's ID from Telegram
+    - **n**: Number of most recent interactions to return (default: 3, min: 1, max: 10)
+    """
+    try:
+        # Check the MongoDB connection
+        if not check_db_connection():
+            logger.error("MongoDB connection failed when trying to get user history")
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="MongoDB connection failed"
+            )
+
+        # Get user history from MongoDB
+        history_data = get_user_history(user_id=user_id, n=n)
+
+        # Convert to response model
+        return HistoryResponse(history=history_data)
+    except PyMongoError as e:
+        logger.error(f"MongoDB error getting user history: {e}")
+        logger.error(traceback.format_exc())
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"MongoDB error: {str(e)}"
+        )
+    except Exception as e:
+        logger.error(f"Unexpected error getting user history: {e}")
+        logger.error(traceback.format_exc())
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to get user history: {str(e)}"
+        )
+
+@router.get("/health")
+async def health_check():
+    """
+    Check health of MongoDB connection.
+    """
+    try:
+        # Check the MongoDB connection
+        is_connected = check_db_connection()
+
+        if not is_connected:
+            return {
+                "status": "unhealthy",
+                "message": "MongoDB connection failed",
+                "timestamp": datetime.now().isoformat()
+            }
+
+        return {
+            "status": "healthy",
+            "message": "MongoDB connection is working",
+            "timestamp": datetime.now().isoformat()
+        }
+    except Exception as e:
+        logger.error(f"MongoDB health check failed: {e}")
+        logger.error(traceback.format_exc())
+        return {
+            "status": "error",
+            "message": f"MongoDB health check error: {str(e)}",
+            "timestamp": datetime.now().isoformat()
+        }
+
+@router.get("/session/{session_id}")
+async def get_session(session_id: str):
+    """
+    Fetch a session from MongoDB by session_id.
+
+    - **session_id**: ID of the session to fetch
+    """
+    try:
+        # Check the MongoDB connection
+        if not check_db_connection():
+            logger.error("MongoDB connection failed when trying to get session")
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="MongoDB connection failed"
+            )
+
+        # Fetch the record from MongoDB
+        session_data = session_collection.find_one({"session_id": session_id})
+
+        if not session_data:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail=f"Session with ID {session_id} not found"
+            )
+
+        # Convert _id to a string so the document is JSON serializable
+        if "_id" in session_data:
+            session_data["_id"] = str(session_data["_id"])
+
+        return session_data
+    except PyMongoError as e:
+        logger.error(f"MongoDB error getting session: {e}")
+        logger.error(traceback.format_exc())
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"MongoDB error: {str(e)}"
+        )
+    except HTTPException:
+        # Re-throw HTTP exceptions
+        raise
+    except Exception as e:
+        logger.error(f"Unexpected error getting session: {e}")
+        logger.error(traceback.format_exc())
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to get session: {str(e)}"
+        )
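A minimal sketch of exercising the `POST /mongodb/session` endpoint above, assuming a local server on port 7860; the field names follow the `SessionCreate` docstring, and the `session_id` and message values are illustrative placeholders:

```python
import requests

payload = {
    "session_id": "123e4567-e89b-12d3-a456-426614174000",
    "factor": "rag",
    "action": "asking_freely",
    "first_name": "John",
    "user_id": "12345678",
    "message": "Where can I find emergency contacts?",
}

resp = requests.post("http://localhost:7860/mongodb/session", json=payload, timeout=30)
print(resp.status_code)  # 201 on success
```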
app/api/postgresql_routes.py
ADDED
@@ -0,0 +1,626 @@
from fastapi import APIRouter, HTTPException, Depends, Query, Path, Body
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from typing import List, Optional, Dict, Any
import logging
import traceback
from datetime import datetime
from sqlalchemy import text, inspect

from app.database.postgresql import get_db
from app.database.models import FAQItem, EmergencyItem, EventItem
from pydantic import BaseModel, Field, ConfigDict

# Configure logging
logger = logging.getLogger(__name__)

# Create router
router = APIRouter(
    prefix="/postgres",
    tags=["PostgreSQL"],
)

# --- Pydantic models for request/response ---

# FAQ models
class FAQBase(BaseModel):
    question: str
    answer: str
    is_active: bool = True

class FAQCreate(FAQBase):
    pass

class FAQUpdate(BaseModel):
    question: Optional[str] = None
    answer: Optional[str] = None
    is_active: Optional[bool] = None

class FAQResponse(FAQBase):
    id: int
    created_at: datetime
    updated_at: datetime

    # Use ConfigDict instead of class Config for Pydantic V2
    model_config = ConfigDict(from_attributes=True)

# Emergency contact models
class EmergencyBase(BaseModel):
    name: str
    phone_number: str
    description: Optional[str] = None
    address: Optional[str] = None
    location: Optional[str] = None
    priority: int = 0
    is_active: bool = True

class EmergencyCreate(EmergencyBase):
    pass

class EmergencyUpdate(BaseModel):
    name: Optional[str] = None
    phone_number: Optional[str] = None
    description: Optional[str] = None
    address: Optional[str] = None
    location: Optional[str] = None
    priority: Optional[int] = None
    is_active: Optional[bool] = None

class EmergencyResponse(EmergencyBase):
    id: int
    created_at: datetime
    updated_at: datetime

    # Use ConfigDict instead of class Config for Pydantic V2
    model_config = ConfigDict(from_attributes=True)

# Event models
class EventBase(BaseModel):
    name: str
    description: str
    address: str
    location: Optional[str] = None
    date_start: datetime
    date_end: Optional[datetime] = None
    price: Optional[List[dict]] = None
    is_active: bool = True
    featured: bool = False

class EventCreate(EventBase):
    pass

class EventUpdate(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    address: Optional[str] = None
    location: Optional[str] = None
    date_start: Optional[datetime] = None
    date_end: Optional[datetime] = None
    price: Optional[List[dict]] = None
    is_active: Optional[bool] = None
    featured: Optional[bool] = None

class EventResponse(EventBase):
    id: int
    created_at: datetime
    updated_at: datetime

    # Use ConfigDict instead of class Config for Pydantic V2
    model_config = ConfigDict(from_attributes=True)

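A quick illustration of what from_attributes=True buys these response models: any object exposing matching attributes (such as a SQLAlchemy row) can be validated directly. A minimal self-contained sketch, using nothing beyond the models above:

from datetime import datetime
from types import SimpleNamespace

# Stand-in for a SQLAlchemy FAQItem row; only the attribute names matter.
row = SimpleNamespace(
    id=1,
    question="Where is My Khe beach?",
    answer="About 3 km east of the city centre.",
    is_active=True,
    created_at=datetime.now(),
    updated_at=datetime.now(),
)

faq = FAQResponse.model_validate(row, from_attributes=True)
print(faq.model_dump())  # plain dict, ready for JSON serialization
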
# --- FAQ endpoints ---

@router.get("/faq", response_model=List[FAQResponse])
async def get_faqs(
    skip: int = 0,
    limit: int = 100,
    active_only: bool = False,
    db: Session = Depends(get_db)
):
    """
    Get all FAQ items.

    - **skip**: Number of items to skip
    - **limit**: Maximum number of items to return
    - **active_only**: If true, only return active items
    """
    try:
        # Log detailed connection info
        logger.info(f"Attempting to fetch FAQs with skip={skip}, limit={limit}, active_only={active_only}")

        # Check if the FAQItem table exists
        inspector = inspect(db.bind)
        if not inspector.has_table("faq_item"):
            logger.error("The faq_item table does not exist in the database")
            raise HTTPException(status_code=500, detail="Table 'faq_item' does not exist")

        # Log table columns
        columns = inspector.get_columns("faq_item")
        logger.info(f"faq_item table columns: {[c['name'] for c in columns]}")

        # Query the FAQs with detailed logging
        query = db.query(FAQItem)
        if active_only:
            query = query.filter(FAQItem.is_active == True)

        # Try direct SQL to debug
        try:
            test_result = db.execute(text("SELECT COUNT(*) FROM faq_item")).scalar()
            logger.info(f"SQL test query succeeded, found {test_result} FAQ items")
        except Exception as sql_error:
            logger.error(f"SQL test query failed: {sql_error}")

        # Execute the ORM query
        faqs = query.offset(skip).limit(limit).all()
        logger.info(f"Successfully fetched {len(faqs)} FAQ items")

        # Check what we're returning
        for i, faq in enumerate(faqs[:3]):  # Log the first 3 items
            logger.info(f"FAQ item {i+1}: id={faq.id}, question={faq.question[:30]}...")

        # Convert SQLAlchemy models to Pydantic models - updated for Pydantic v2
        result = [FAQResponse.model_validate(faq, from_attributes=True) for faq in faqs]
        return result
    except SQLAlchemyError as e:
        logger.error(f"Database error in get_faqs: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in get_faqs: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

@router.post("/faq", response_model=FAQResponse)
async def create_faq(
    faq: FAQCreate,
    db: Session = Depends(get_db)
):
    """
    Create a new FAQ item.

    - **question**: Question text
    - **answer**: Answer text
    - **is_active**: Whether the FAQ is active (default: True)
    """
    try:
        # Use model_dump instead of dict (Pydantic V2)
        db_faq = FAQItem(**faq.model_dump())
        db.add(db_faq)
        db.commit()
        db.refresh(db_faq)
        return FAQResponse.model_validate(db_faq, from_attributes=True)
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to create FAQ item")

@router.get("/faq/{faq_id}", response_model=FAQResponse)
async def get_faq(
    faq_id: int = Path(..., gt=0),
    db: Session = Depends(get_db)
):
    """
    Get a specific FAQ item by ID.

    - **faq_id**: ID of the FAQ item
    """
    try:
        faq = db.query(FAQItem).filter(FAQItem.id == faq_id).first()
        if not faq:
            raise HTTPException(status_code=404, detail="FAQ item not found")
        return FAQResponse.model_validate(faq, from_attributes=True)
    except SQLAlchemyError as e:
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Database error")

@router.put("/faq/{faq_id}", response_model=FAQResponse)
async def update_faq(
    faq_id: int = Path(..., gt=0),
    faq_update: FAQUpdate = Body(...),
    db: Session = Depends(get_db)
):
    """
    Update a specific FAQ item.

    - **faq_id**: ID of the FAQ item to update
    - **question**: New question text (optional)
    - **answer**: New answer text (optional)
    - **is_active**: New active status (optional)
    """
    try:
        faq = db.query(FAQItem).filter(FAQItem.id == faq_id).first()
        if not faq:
            raise HTTPException(status_code=404, detail="FAQ item not found")

        # Use model_dump instead of dict (Pydantic V2)
        update_data = faq_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(faq, key, value)

        db.commit()
        db.refresh(faq)
        return FAQResponse.model_validate(faq, from_attributes=True)
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to update FAQ item")

@router.delete("/faq/{faq_id}", response_model=dict)
async def delete_faq(
    faq_id: int = Path(..., gt=0),
    db: Session = Depends(get_db)
):
    """
    Delete a specific FAQ item.

    - **faq_id**: ID of the FAQ item to delete
    """
    try:
        faq = db.query(FAQItem).filter(FAQItem.id == faq_id).first()
        if not faq:
            raise HTTPException(status_code=404, detail="FAQ item not found")

        db.delete(faq)
        db.commit()
        return {"status": "success", "message": f"FAQ item {faq_id} deleted"}
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete FAQ item")

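A minimal end-to-end exercise of the FAQ routes above. The "/postgres" prefix comes from the router definition; the host and port are assumptions based on the .env defaults.

import httpx

BASE = "http://localhost:7860/postgres"  # assumed deployment address

created = httpx.post(f"{BASE}/faq", json={
    "question": "What is the best time to visit Da Nang?",
    "answer": "February to May, before the rainy season.",
}).json()

faq_id = created["id"]
httpx.put(f"{BASE}/faq/{faq_id}", json={"is_active": False})   # partial update
print(httpx.get(f"{BASE}/faq/{faq_id}").json()["is_active"])   # False
httpx.delete(f"{BASE}/faq/{faq_id}")
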
# --- Emergency endpoints ---

@router.get("/emergency", response_model=List[EmergencyResponse])
async def get_emergency_contacts(
    skip: int = 0,
    limit: int = 100,
    active_only: bool = False,
    db: Session = Depends(get_db)
):
    """
    Get all emergency contacts.

    - **skip**: Number of items to skip
    - **limit**: Maximum number of items to return
    - **active_only**: If true, only return active items
    """
    try:
        # Log detailed connection info
        logger.info(f"Attempting to fetch emergency contacts with skip={skip}, limit={limit}, active_only={active_only}")

        # Check if the EmergencyItem table exists
        inspector = inspect(db.bind)
        if not inspector.has_table("emergency_item"):
            logger.error("The emergency_item table does not exist in the database")
            raise HTTPException(status_code=500, detail="Table 'emergency_item' does not exist")

        # Log table columns
        columns = inspector.get_columns("emergency_item")
        logger.info(f"emergency_item table columns: {[c['name'] for c in columns]}")

        # Try direct SQL to debug
        try:
            test_result = db.execute(text("SELECT COUNT(*) FROM emergency_item")).scalar()
            logger.info(f"SQL test query succeeded, found {test_result} emergency contacts")
        except Exception as sql_error:
            logger.error(f"SQL test query failed: {sql_error}")

        # Query the emergency contacts
        query = db.query(EmergencyItem)
        if active_only:
            query = query.filter(EmergencyItem.is_active == True)

        # Execute the ORM query
        emergency_contacts = query.offset(skip).limit(limit).all()
        logger.info(f"Successfully fetched {len(emergency_contacts)} emergency contacts")

        # Check what we're returning
        for i, contact in enumerate(emergency_contacts[:3]):  # Log the first 3 items
            logger.info(f"Emergency contact {i+1}: id={contact.id}, name={contact.name}")

        return emergency_contacts
    except SQLAlchemyError as e:
        logger.error(f"Database error in get_emergency_contacts: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in get_emergency_contacts: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

@router.post("/emergency", response_model=EmergencyResponse)
async def create_emergency_contact(
    emergency: EmergencyCreate,
    db: Session = Depends(get_db)
):
    """
    Create a new emergency contact.

    - **name**: Contact name
    - **phone_number**: Phone number
    - **description**: Description (optional)
    - **address**: Address (optional)
    - **location**: Location coordinates (optional)
    - **priority**: Priority order (default: 0)
    - **is_active**: Whether the contact is active (default: True)
    """
    try:
        db_emergency = EmergencyItem(**emergency.model_dump())
        db.add(db_emergency)
        db.commit()
        db.refresh(db_emergency)
        return db_emergency
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to create emergency contact")

@router.get("/emergency/{emergency_id}", response_model=EmergencyResponse)
async def get_emergency_contact(
    emergency_id: int = Path(..., gt=0),
    db: Session = Depends(get_db)
):
    """
    Get a specific emergency contact by ID.

    - **emergency_id**: ID of the emergency contact
    """
    try:
        emergency = db.query(EmergencyItem).filter(EmergencyItem.id == emergency_id).first()
        if not emergency:
            raise HTTPException(status_code=404, detail="Emergency contact not found")
        return emergency
    except SQLAlchemyError as e:
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Database error")

@router.put("/emergency/{emergency_id}", response_model=EmergencyResponse)
async def update_emergency_contact(
    emergency_id: int = Path(..., gt=0),
    emergency_update: EmergencyUpdate = Body(...),
    db: Session = Depends(get_db)
):
    """
    Update a specific emergency contact.

    - **emergency_id**: ID of the emergency contact to update
    - **name**: New name (optional)
    - **phone_number**: New phone number (optional)
    - **description**: New description (optional)
    - **address**: New address (optional)
    - **location**: New location coordinates (optional)
    - **priority**: New priority order (optional)
    - **is_active**: New active status (optional)
    """
    try:
        emergency = db.query(EmergencyItem).filter(EmergencyItem.id == emergency_id).first()
        if not emergency:
            raise HTTPException(status_code=404, detail="Emergency contact not found")

        # Update fields if provided
        update_data = emergency_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(emergency, key, value)

        db.commit()
        db.refresh(emergency)
        return emergency
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to update emergency contact")

@router.delete("/emergency/{emergency_id}", response_model=dict)
async def delete_emergency_contact(
    emergency_id: int = Path(..., gt=0),
    db: Session = Depends(get_db)
):
    """
    Delete a specific emergency contact.

    - **emergency_id**: ID of the emergency contact to delete
    """
    try:
        emergency = db.query(EmergencyItem).filter(EmergencyItem.id == emergency_id).first()
        if not emergency:
            raise HTTPException(status_code=404, detail="Emergency contact not found")

        db.delete(emergency)
        db.commit()
        return {"status": "success", "message": f"Emergency contact {emergency_id} deleted"}
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete emergency contact")

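The update routes rely on model_dump(exclude_unset=True) for PATCH-like semantics: only fields the client actually sent are applied, so omitted fields keep their stored values. A small self-contained sketch of the distinction:

from typing import Optional
from pydantic import BaseModel

class Patch(BaseModel):
    name: Optional[str] = None
    priority: Optional[int] = None

p = Patch(priority=5)  # client sent only "priority"
print(p.model_dump(exclude_unset=True))  # {'priority': 5} - safe to setattr
print(p.model_dump())                    # {'name': None, 'priority': 5} - would null out name
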
# --- Event endpoints ---

@router.get("/events", response_model=List[EventResponse])
async def get_events(
    skip: int = 0,
    limit: int = 100,
    active_only: bool = False,
    featured_only: bool = False,
    db: Session = Depends(get_db)
):
    """
    Get all events.

    - **skip**: Number of items to skip
    - **limit**: Maximum number of items to return
    - **active_only**: If true, only return active items
    - **featured_only**: If true, only return featured items
    """
    try:
        # Log detailed connection info
        logger.info(f"Attempting to fetch events with skip={skip}, limit={limit}, active_only={active_only}, featured_only={featured_only}")

        # Check if the EventItem table exists
        inspector = inspect(db.bind)
        if not inspector.has_table("event_item"):
            logger.error("The event_item table does not exist in the database")
            raise HTTPException(status_code=500, detail="Table 'event_item' does not exist")

        # Log table columns
        columns = inspector.get_columns("event_item")
        logger.info(f"event_item table columns: {[c['name'] for c in columns]}")

        # Try direct SQL to debug
        try:
            test_result = db.execute(text("SELECT COUNT(*) FROM event_item")).scalar()
            logger.info(f"SQL test query succeeded, found {test_result} events")
        except Exception as sql_error:
            logger.error(f"SQL test query failed: {sql_error}")

        # Query the events
        query = db.query(EventItem)
        if active_only:
            query = query.filter(EventItem.is_active == True)
        if featured_only:
            query = query.filter(EventItem.featured == True)

        # Execute the ORM query
        events = query.offset(skip).limit(limit).all()
        logger.info(f"Successfully fetched {len(events)} events")

        # Debug the price field of the first event
        if events and len(events) > 0:
            logger.info(f"First event price type: {type(events[0].price)}, value: {events[0].price}")

        # Check what we're returning
        for i, event in enumerate(events[:3]):  # Log the first 3 items
            logger.info(f"Event {i+1}: id={event.id}, name={event.name}, price={type(event.price)}")

        return events
    except SQLAlchemyError as e:
        logger.error(f"Database error in get_events: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in get_events: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

@router.post("/events", response_model=EventResponse)
async def create_event(
    event: EventCreate,
    db: Session = Depends(get_db)
):
    """
    Create a new event.

    - **name**: Event name
    - **description**: Event description
    - **address**: Event address
    - **location**: Location coordinates (optional)
    - **date_start**: Start date and time
    - **date_end**: End date and time (optional)
    - **price**: Price information (optional JSON object)
    - **is_active**: Whether the event is active (default: True)
    - **featured**: Whether the event is featured (default: False)
    """
    try:
        db_event = EventItem(**event.model_dump())
        db.add(db_event)
        db.commit()
        db.refresh(db_event)
        return db_event
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to create event")

@router.get("/events/{event_id}", response_model=EventResponse)
async def get_event(
    event_id: int = Path(..., gt=0),
    db: Session = Depends(get_db)
):
    """
    Get a specific event by ID.

    - **event_id**: ID of the event
    """
    try:
        event = db.query(EventItem).filter(EventItem.id == event_id).first()
        if not event:
            raise HTTPException(status_code=404, detail="Event not found")
        return event
    except SQLAlchemyError as e:
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Database error")

@router.put("/events/{event_id}", response_model=EventResponse)
async def update_event(
    event_id: int = Path(..., gt=0),
    event_update: EventUpdate = Body(...),
    db: Session = Depends(get_db)
):
    """
    Update a specific event.

    - **event_id**: ID of the event to update
    - **name**: New name (optional)
    - **description**: New description (optional)
    - **address**: New address (optional)
    - **location**: New location coordinates (optional)
    - **date_start**: New start date and time (optional)
    - **date_end**: New end date and time (optional)
    - **price**: New price information (optional JSON object)
    - **is_active**: New active status (optional)
    - **featured**: New featured status (optional)
    """
    try:
        event = db.query(EventItem).filter(EventItem.id == event_id).first()
        if not event:
            raise HTTPException(status_code=404, detail="Event not found")

        # Update fields if provided
        update_data = event_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(event, key, value)

        db.commit()
        db.refresh(event)
        return event
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to update event")

@router.delete("/events/{event_id}", response_model=dict)
async def delete_event(
    event_id: int = Path(..., gt=0),
    db: Session = Depends(get_db)
):
    """
    Delete a specific event.

    - **event_id**: ID of the event to delete
    """
    try:
        event = db.query(EventItem).filter(EventItem.id == event_id).first()
        if not event:
            raise HTTPException(status_code=404, detail="Event not found")

        db.delete(event)
        db.commit()
        return {"status": "success", "message": f"Event {event_id} deleted"}
    except SQLAlchemyError as e:
        db.rollback()
        logger.error(f"Database error: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete event")

# Health check endpoint
@router.get("/health")
async def health_check(db: Session = Depends(get_db)):
    """
    Check health of PostgreSQL connection.
    """
    try:
        # Perform a simple database query to check health
        # Use text() to wrap the SQL query for SQLAlchemy 2.0 compatibility
        db.execute(text("SELECT 1")).first()
        return {"status": "healthy", "message": "PostgreSQL connection is working", "timestamp": datetime.now().isoformat()}
    except Exception as e:
        logger.error(f"PostgreSQL health check failed: {e}")
        raise HTTPException(status_code=503, detail=f"PostgreSQL connection failed: {str(e)}")
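
The storage backends expose a common health contract ("status" plus "timestamp"), which makes a combined liveness probe straightforward. A sketch, with the base URL assumed; the "/postgres" and "/rag" prefixes come from the router definitions in this commit:

import httpx

BASE = "http://localhost:7860"  # assumed deployment address

def overall_health() -> bool:
    # True only if every backend reports healthy; 503 responses count as failures.
    for path in ("/postgres/health", "/rag/health"):
        try:
            body = httpx.get(BASE + path, timeout=5.0).json()
        except httpx.HTTPError:
            return False
        if body.get("status") != "healthy":
            return False
    return True
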
app/api/rag_routes.py
ADDED
@@ -0,0 +1,538 @@
from fastapi import APIRouter, HTTPException, Depends, Query, BackgroundTasks, Request
from typing import List, Optional, Dict, Any
import logging
import time
import os
import json
import hashlib
import asyncio
import traceback
import google.generativeai as genai
from datetime import datetime
from langchain.prompts import PromptTemplate
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from app.utils.utils import cache, timer_decorator

from app.database.mongodb import get_user_history, get_chat_history, get_request_history, save_session, session_collection
from app.database.pinecone import (
    search_vectors,
    get_chain,
    DEFAULT_TOP_K,
    DEFAULT_LIMIT_K,
    DEFAULT_SIMILARITY_METRIC,
    DEFAULT_SIMILARITY_THRESHOLD,
    ALLOWED_METRICS
)
from app.models.rag_models import (
    ChatRequest,
    ChatResponse,
    ChatResponseInternal,
    SourceDocument,
    EmbeddingRequest,
    EmbeddingResponse,
    UserMessageModel
)

# Use an in-memory cache instead of Redis
class SimpleCache:
    def __init__(self):
        self.cache = {}
        self.expiration = {}

    async def get(self, key):
        if key in self.cache:
            # Check whether the cached entry has expired
            if key in self.expiration and self.expiration[key] > time.time():
                return self.cache[key]
            else:
                # Evict the expired entry
                del self.cache[key]
                self.expiration.pop(key, None)
        return None

    async def set(self, key, value, ex=300):  # Default TTL: 5 minutes
        self.cache[key] = value
        self.expiration[key] = time.time() + ex

# Initialize SimpleCache
redis_client = SimpleCache()

# Configure logging
logger = logging.getLogger(__name__)

# Configure Google Gemini API
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Create router
router = APIRouter(
    prefix="/rag",
    tags=["RAG"],
)

# Create a prompt template with conversation history
prompt = PromptTemplate(
    template = """Goal:
You are a professional tour guide assistant that helps users find information about places in Da Nang, Vietnam.
You can provide details on restaurants, cafes, hotels, attractions, and other local venues.
Use the core knowledge and the conversation history to chat with users, who are tourists visiting Da Nang.

Return Format:
Respond in friendly, natural, concise English, like a real tour guide.
Always use HTML tags (e.g. <b> for bold) so that Telegram can render the special formatting correctly.

Warning:
Support users like a real tour guide, not a bot. The information in the core knowledge is your own knowledge.
Your knowledge is provided in the Core Knowledge. All information in the Core Knowledge is about Da Nang, Vietnam.
Only consider the current time the user mentions when they ask about a Solana event.
If you do not have enough information to answer the user's question, reply with "I don't know. I don't have information about that".

Core knowledge:
{context}

Conversation History:
{chat_history}

User message:
{question}

Your message:
""",
    input_variables = ["context", "question", "chat_history"],
)

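A quick demonstration of the SimpleCache TTL behaviour above (illustrative only, not part of the module; the 1-second TTL is just to make the eviction visible):

import asyncio

async def demo():
    c = SimpleCache()
    await c.set("greeting", "hello", ex=1)  # expires after 1 second
    print(await c.get("greeting"))          # "hello"
    await asyncio.sleep(1.1)
    print(await c.get("greeting"))          # None - entry was evicted

asyncio.run(demo())
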
# Helper for embeddings
async def get_embedding(text: str):
    """Get embedding from Google Gemini API"""
    try:
        # Initialize embedding model
        embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

        # Generate embedding
        result = await embedding_model.aembed_query(text)

        # Return embedding
        return {
            "embedding": result,
            "text": text,
            "model": "embedding-001"
        }
    except Exception as e:
        logger.error(f"Error generating embedding: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate embedding: {str(e)}")

# Endpoint for generating embeddings
@router.post("/embedding", response_model=EmbeddingResponse)
async def create_embedding(request: EmbeddingRequest):
    """
    Generate embedding for text.

    - **text**: Text to generate embedding for
    """
    try:
        # Get embedding
        embedding_data = await get_embedding(request.text)

        # Return embedding
        return EmbeddingResponse(**embedding_data)
    except Exception as e:
        logger.error(f"Error generating embedding: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate embedding: {str(e)}")

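Exercising the embedding route above from a client could look like this (base URL assumed; the response shape follows EmbeddingResponse):

import httpx

resp = httpx.post(
    "http://localhost:7860/rag/embedding",  # assumed base URL
    json={"text": "Best seafood restaurants near My Khe beach"},
)
data = resp.json()
print(data["model"])           # "embedding-001"
print(len(data["embedding"]))  # vector dimension (768 for embedding-001)
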
@timer_decorator
@router.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
    """
    Get answer for a question using RAG.

    - **user_id**: User's ID from Telegram
    - **question**: User's question
    - **include_history**: Whether to include user history in prompt (default: True)
    - **use_rag**: Whether to use RAG (default: True)
    - **similarity_top_k**: Number of top similar documents to return after filtering (default: 6)
    - **limit_k**: Maximum number of documents to retrieve from vector store (default: 10)
    - **similarity_metric**: Similarity metric to use - cosine, dotproduct, euclidean (default: cosine)
    - **similarity_threshold**: Threshold for vector similarity (default: 0.75)
    - **session_id**: Optional session ID for tracking conversations
    - **first_name**: User's first name
    - **last_name**: User's last name
    - **username**: User's username
    """
    start_time = time.time()
    try:
        # Create cache key for request
        cache_key = f"rag_chat:{request.user_id}:{request.question}:{request.include_history}:{request.use_rag}:{request.similarity_top_k}:{request.limit_k}:{request.similarity_metric}:{request.similarity_threshold}"

        # Check cache using redis_client instead of cache
        cached_response = await redis_client.get(cache_key)
        if cached_response is not None:
            logger.info(f"Cache hit for RAG chat request from user {request.user_id}")
            try:
                # If cached_response is a string (JSON), parse it
                if isinstance(cached_response, str):
                    cached_data = json.loads(cached_response)
                    return ChatResponse(
                        answer=cached_data.get("answer", ""),
                        processing_time=cached_data.get("processing_time", 0.0)
                    )
                # If cached_response is an object with sources, extract answer and processing_time
                elif hasattr(cached_response, 'sources'):
                    return ChatResponse(
                        answer=cached_response.answer,
                        processing_time=cached_response.processing_time
                    )
                # Otherwise, return the cached response as is
                return cached_response
            except Exception as e:
                logger.error(f"Error parsing cached response: {e}")
                # Continue processing if cache parsing fails

        # Save the user message first (so it's available for user history)
        session_id = request.session_id or f"{request.user_id}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
        logger.info(f"Processing chat request for user {request.user_id}, session {session_id}")

        # First, save the user's message so it's available for history lookups
        try:
            # Save the user's question
            save_session(
                session_id=session_id,
                factor="user",
                action="asking_freely",
                first_name=getattr(request, 'first_name', "User"),
                last_name=getattr(request, 'last_name', ""),
                message=request.question,
                user_id=request.user_id,
                username=getattr(request, 'username', ""),
                response=None  # No response yet
            )
            logger.info(f"User message saved for session {session_id}")
        except Exception as e:
            logger.error(f"Error saving user message to session: {e}")
            # Continue processing even if saving fails

        # Use the RAG pipeline
        if request.use_rag:
            # Get the retriever with custom parameters
            retriever = get_chain(
                top_k=request.similarity_top_k,
                limit_k=request.limit_k,
                similarity_metric=request.similarity_metric,
                similarity_threshold=request.similarity_threshold
            )
            if not retriever:
                raise HTTPException(status_code=500, detail="Failed to initialize retriever")

            # Get request history for context
            context_query = get_request_history(request.user_id) if request.include_history else request.question
            logger.info(f"Using context query for retrieval: {context_query[:100]}...")

            # Retrieve relevant documents
            retrieved_docs = retriever.invoke(context_query)
            context = "\n".join([doc.page_content for doc in retrieved_docs])

            # Prepare sources
            sources = []
            for doc in retrieved_docs:
                # Initialize all fields so documents without metadata don't hit unbound names
                source = None
                score = None
                normalized_score = None
                metadata = {}

                if hasattr(doc, 'metadata'):
                    source = doc.metadata.get('source', None)
                    # Extract score information
                    score = doc.metadata.get('score', None)
                    normalized_score = doc.metadata.get('normalized_score', None)
                    # Remove score info from metadata to avoid duplication
                    metadata = {k: v for k, v in doc.metadata.items()
                                if k not in ['text', 'source', 'score', 'normalized_score']}

                sources.append(SourceDocument(
                    text=doc.page_content,
                    source=source,
                    score=score,
                    normalized_score=normalized_score,
                    metadata=metadata
                ))
        else:
            # No RAG
            context = ""
            sources = None

        # Get chat history
        chat_history = get_chat_history(request.user_id) if request.include_history else ""
        logger.info(f"Using chat history: {chat_history[:100]}...")

        # Initialize Gemini model
        generation_config = {
            "temperature": 0.9,
            "top_p": 1,
            "top_k": 1,
            "max_output_tokens": 2048,
        }

        safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
        ]

        model = genai.GenerativeModel(
            model_name='models/gemini-2.0-flash',
            generation_config=generation_config,
            safety_settings=safety_settings
        )

        # Generate the prompt using the template
        prompt_text = prompt.format(
            context=context,
            question=request.question,
            chat_history=chat_history
        )
        logger.info(f"Full prompt with history and context: {prompt_text}")

        # Generate response
        response = model.generate_content(prompt_text)
        answer = response.text

        # Save the RAG response
        try:
            # Now save the RAG response with the same session_id
            save_session(
                session_id=session_id,
                factor="rag",
                action="RAG_response",
                first_name=getattr(request, 'first_name', "User"),
                last_name=getattr(request, 'last_name', ""),
                message=request.question,
                user_id=request.user_id,
                username=getattr(request, 'username', ""),
                response=answer
            )
            logger.info(f"RAG response saved for session {session_id}")

            # Check if the response starts with "I don't know" and trigger a notification
            if answer.strip().lower().startswith("i don't know"):
                from app.api.websocket_routes import send_notification
                notification_data = {
                    "session_id": session_id,
                    "factor": "rag",
                    "action": "RAG_response",
                    "message": request.question,
                    "user_id": request.user_id,
                    "username": getattr(request, 'username', ""),
                    "first_name": getattr(request, 'first_name', "User"),
                    "last_name": getattr(request, 'last_name', ""),
                    "response": answer,
                    "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                }
                background_tasks.add_task(send_notification, notification_data)
                logger.info(f"Notification queued for session {session_id} - response starts with 'I don't know'")
        except Exception as e:
            logger.error(f"Error saving RAG response to session: {e}")
            # Continue processing even if saving fails

        # Calculate processing time
        processing_time = time.time() - start_time

        # Create internal response object with sources for logging
        internal_response = ChatResponseInternal(
            answer=answer,
            sources=sources,
            processing_time=processing_time
        )

        # Log the full response with sources
        logger.info(f"Generated response for user {request.user_id}: {answer}")
        if sources:
            logger.info(f"Sources used: {len(sources)} documents")
            for i, source in enumerate(sources):
                logger.info(f"Source {i+1}: {source.source or 'Unknown'} (score: {source.score})")

        # Create the response object for the API (without sources)
        chat_response = ChatResponse(
            answer=answer,
            processing_time=processing_time
        )

        # Cache the result using redis_client instead of cache
        try:
            # Convert to JSON to ensure it can be cached
            cache_data = {
                "answer": answer,
                "processing_time": processing_time
            }
            await redis_client.set(cache_key, json.dumps(cache_data), ex=300)
        except Exception as e:
            logger.error(f"Error caching response: {e}")
            # Continue even if caching fails

        # Return response
        return chat_response
    except Exception as e:
        logger.error(f"Error processing chat request: {e}")
        logger.error(traceback.format_exc())  # traceback is imported at module level
        raise HTTPException(status_code=500, detail=f"Failed to process chat request: {str(e)}")

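Calling the chat route above from a client might look like the sketch below. The base URL is an assumption, and the exact field types come from ChatRequest in app/models/rag_models.py, which is not shown here, so treat the payload shape as illustrative:

import httpx

payload = {
    "user_id": "12345",  # Telegram user ID; type per ChatRequest is assumed
    "question": "Is there a Solana event in Da Nang this weekend?",
    "include_history": True,
    "use_rag": True,
    "similarity_top_k": 6,
    "limit_k": 10,
    "similarity_metric": "cosine",
    "similarity_threshold": 0.75,
}
resp = httpx.post("http://localhost:7860/rag/chat", json=payload, timeout=60.0)
print(resp.json())  # {"answer": "...", "processing_time": ...}
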
# Health check endpoint
@router.get("/health")
async def health_check():
    """
    Check health of RAG services and retrieval system.

    Returns:
    - status: "healthy" if all services are working, "degraded" otherwise
    - services: Status of each service (gemini, pinecone)
    - retrieval_config: Current retrieval configuration
    - timestamp: Current time
    """
    services = {
        "gemini": False,
        "pinecone": False
    }

    # Check Gemini
    try:
        # Initialize a simple model
        model = genai.GenerativeModel("gemini-2.0-flash")
        # Test generation
        response = model.generate_content("Hello")
        services["gemini"] = True
    except Exception as e:
        logger.error(f"Gemini health check failed: {e}")

    # Check Pinecone
    try:
        # Import pinecone function
        from app.database.pinecone import get_pinecone_index
        # Get index
        index = get_pinecone_index()
        # Check if the index exists
        if index:
            services["pinecone"] = True
    except Exception as e:
        logger.error(f"Pinecone health check failed: {e}")

    # Get retrieval configuration
    retrieval_config = {
        "default_top_k": DEFAULT_TOP_K,
        "default_limit_k": DEFAULT_LIMIT_K,
        "default_similarity_metric": DEFAULT_SIMILARITY_METRIC,
        "default_similarity_threshold": DEFAULT_SIMILARITY_THRESHOLD,
        "allowed_metrics": ALLOWED_METRICS
    }

    # Return health status
    status = "healthy" if all(services.values()) else "degraded"
    return {
        "status": status,
        "services": services,
        "retrieval_config": retrieval_config,
        "timestamp": datetime.now().isoformat()
    }

@router.post("/rag")
async def process_rag(request: Request, user_data: UserMessageModel, background_tasks: BackgroundTasks):
    """
    Process a user message through the RAG pipeline and return a response.

    Parameters:
    - **user_id**: User ID from the client application
    - **session_id**: Session ID for tracking the conversation
    - **message**: User's message/question
    - **similarity_top_k**: (Optional) Number of top similar documents to return after filtering
    - **limit_k**: (Optional) Maximum number of documents to retrieve from vector store
    - **similarity_metric**: (Optional) Similarity metric to use (cosine, dotproduct, euclidean)
    - **similarity_threshold**: (Optional) Threshold for vector similarity (0-1)
    """
    try:
        # Extract request data
        user_id = user_data.user_id
        session_id = user_data.session_id
        message = user_data.message

        # Extract retrieval parameters (use defaults if not provided)
        top_k = user_data.similarity_top_k or DEFAULT_TOP_K
        limit_k = user_data.limit_k or DEFAULT_LIMIT_K
        similarity_metric = user_data.similarity_metric or DEFAULT_SIMILARITY_METRIC
        similarity_threshold = user_data.similarity_threshold or DEFAULT_SIMILARITY_THRESHOLD

        logger.info(f"RAG request received for user_id={user_id}, session_id={session_id}")
        logger.info(f"Message: {message[:100]}..." if len(message) > 100 else f"Message: {message}")
        logger.info(f"Retrieval parameters: top_k={top_k}, limit_k={limit_k}, metric={similarity_metric}, threshold={similarity_threshold}")

        # Create a cache key for this request to avoid reprocessing identical questions
        cache_key = f"rag_{user_id}_{session_id}_{hashlib.md5(message.encode()).hexdigest()}_{top_k}_{limit_k}_{similarity_metric}_{similarity_threshold}"

        # Check if we have this response cached
        cached_result = await redis_client.get(cache_key)
        if cached_result:
            logger.info(f"Cache hit for key: {cache_key}")
            if isinstance(cached_result, str):  # If stored as a JSON string
                return json.loads(cached_result)
            return cached_result

        # Save the user message to MongoDB
        try:
            # Save the user's question
            save_session(
                session_id=session_id,
                factor="user",
                action="asking_freely",
                first_name="User",  # You can update this with actual data if available
                last_name="",
                message=message,
                user_id=user_id,
                username="",
                response=None  # No response yet
            )
            logger.info(f"User message saved to MongoDB with session_id: {session_id}")
        except Exception as e:
            logger.error(f"Error saving user message: {e}")
            # Continue anyway to try to get a response

        # Create a ChatRequest object to reuse the existing chat endpoint
        chat_request = ChatRequest(
            user_id=user_id,
            question=message,
            include_history=True,
            use_rag=True,
            similarity_top_k=top_k,
            limit_k=limit_k,
            similarity_metric=similarity_metric,
            similarity_threshold=similarity_threshold,
            session_id=session_id
        )

        # Process through the chat endpoint
        response = await chat(chat_request, background_tasks)

        # Cache the response
        try:
            await redis_client.set(cache_key, json.dumps({
                "answer": response.answer,
                "processing_time": response.processing_time
            }))
            logger.info(f"Cached response for key: {cache_key}")
        except Exception as e:
            logger.error(f"Failed to cache response: {e}")

        return response
    except Exception as e:
        logger.error(f"Error processing RAG request: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
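
The /rag/rag route wraps the same pipeline for clients that track their own session IDs. An example payload, with field names taken from the handler above and the base URL assumed:

import httpx

resp = httpx.post("http://localhost:7860/rag/rag", json={
    "user_id": "12345",
    "session_id": "12345_2024-01-01_10:00:00",
    "message": "Recommend a rooftop cafe near the Dragon Bridge",
    # Retrieval knobs are optional; omitted fields fall back to the defaults.
    "similarity_threshold": 0.8,
}, timeout=60.0)
print(resp.json()["answer"])
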
app/api/websocket_routes.py
ADDED
@@ -0,0 +1,306 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, status
from typing import List, Dict
import logging
from datetime import datetime
import asyncio
import json
import os
from dotenv import load_dotenv
from app.database.mongodb import session_collection
from app.utils.utils import get_local_time

# Load environment variables
load_dotenv()

# Get WebSocket configuration from environment variables
WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")

# Configure logging
logger = logging.getLogger(__name__)

# Create router
router = APIRouter(
    tags=["WebSocket"],
)

# Store active WebSocket connections
class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)
        client_info = f"{websocket.client.host}:{websocket.client.port}" if hasattr(websocket, 'client') else "Unknown"
        logger.info(f"New WebSocket connection from {client_info}. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)
        logger.info(f"WebSocket connection removed. Total connections: {len(self.active_connections)}")

    async def broadcast(self, message: Dict):
        if not self.active_connections:
            logger.warning("No active WebSocket connections to broadcast to")
            return

        disconnected = []
        for connection in self.active_connections:
            try:
                await connection.send_json(message)
                logger.info("Message sent to WebSocket connection")
            except Exception as e:
                logger.error(f"Error sending message to WebSocket: {e}")
                disconnected.append(connection)

        # Remove disconnected connections
        for conn in disconnected:
            if conn in self.active_connections:
                self.active_connections.remove(conn)
                logger.info(f"Removed disconnected WebSocket. Remaining: {len(self.active_connections)}")

# Initialize connection manager
manager = ConnectionManager()

# Build the full URL of the WebSocket server from environment variables
def get_full_websocket_url(server_side=False):
    if server_side:
        # Relative URL (for the server side)
        return WEBSOCKET_PATH
    else:
        # Full URL (for clients)
        # Use wss:// when the port indicates HTTPS
        is_https = int(WEBSOCKET_PORT) == 443
        protocol = "wss" if is_https else "ws"

        # If using the default port for the protocol, don't include it in the URL
        if (is_https and int(WEBSOCKET_PORT) == 443) or (not is_https and int(WEBSOCKET_PORT) == 80):
            return f"{protocol}://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"
        else:
            return f"{protocol}://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"

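Concretely, the URL builder above yields the following for a few sample environment settings (server names illustrative):

# WEBSOCKET_SERVER=localhost,   WEBSOCKET_PORT=7860 -> ws://localhost:7860/notify
# WEBSOCKET_SERVER=example.com, WEBSOCKET_PORT=443  -> wss://example.com/notify
# WEBSOCKET_SERVER=example.com, WEBSOCKET_PORT=80   -> ws://example.com/notify
print(get_full_websocket_url())                  # client-facing URL
print(get_full_websocket_url(server_side=True))  # "/notify"
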
# Add a GET endpoint to document the WebSocket in Swagger
@router.get("/notify",
    summary="WebSocket notifications for Admin Bot",
    description=f"""
    This is documentation for the WebSocket endpoint.

    To connect to WebSocket:
    1. Use the path `{get_full_websocket_url()}`
    2. Connect using a WebSocket client library
    3. When there are new sessions requiring attention, you will receive notifications through this connection

    Notifications are sent when:
    - Session response starts with "I don't know"
    - The system cannot answer the user's question

    Make sure to send a "keepalive" message every 5 minutes to maintain the connection.
    """,
    status_code=status.HTTP_200_OK
)
async def websocket_documentation():
    """
    Provides information about how to use the WebSocket endpoint /notify.
    This endpoint is for documentation purposes only. To use WebSocket, please connect to the WebSocket URL.
    """
    ws_url = get_full_websocket_url()
    return {
        "websocket_endpoint": WEBSOCKET_PATH,
        "connection_type": "WebSocket",
        "protocol": "ws://",
        "server": WEBSOCKET_SERVER,
        "port": WEBSOCKET_PORT,
        "full_url": ws_url,
        "description": "Endpoint to receive notifications about new sessions requiring attention",
        "notification_format": {
            "type": "new_session",
            "timestamp": "YYYY-MM-DD HH:MM:SS",
            "data": {
                "session_id": "session id",
                "factor": "user",
                "action": "action type",
                "message": "User question",
                "response": "I don't know...",
                "user_id": "user id",
                "first_name": "user's first name",
                "last_name": "user's last name",
                "username": "username",
                "created_at": "creation time"
            }
        },
        "client_example": """
import websocket
import json
import os
import time
import threading
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Get WebSocket configuration from environment variables
WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")

# Create the full URL
ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"

# If using HTTPS, replace ws:// with wss://
# ws_url = f"wss://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"

# Send keepalive periodically
def send_keepalive(ws):
    while True:
        try:
            if ws.sock and ws.sock.connected:
                ws.send("keepalive")
                print("Sent keepalive message")
            time.sleep(300)  # 5 minutes
        except Exception as e:
            print(f"Error sending keepalive: {e}")
            time.sleep(60)

def on_message(ws, message):
    try:
        data = json.loads(message)
        print(f"Received notification: {data}")
        # Process notification, e.g.: send to Telegram Admin
        if data.get("type") == "new_session":
            session_data = data.get("data", {})
            user_question = session_data.get("message", "")
            user_name = session_data.get("first_name", "Unknown User")
            print(f"User {user_name} asked: {user_question}")
            # Code to send message to Telegram Admin
    except json.JSONDecodeError:
        print(f"Received non-JSON message: {message}")
    except Exception as e:
        print(f"Error processing message: {e}")

def on_error(ws, error):
    print(f"WebSocket error: {error}")

def on_close(ws, close_status_code, close_msg):
    print(f"WebSocket connection closed: code={close_status_code}, message={close_msg}")

def on_open(ws):
    print(f"WebSocket connection opened to {ws_url}")
    # Send keepalive messages periodically in a separate thread
    keepalive_thread = threading.Thread(target=send_keepalive, args=(ws,), daemon=True)
    keepalive_thread.start()

def run_forever_with_reconnect():
    while True:
        try:
            # Connect WebSocket with ping to maintain the connection
            ws = websocket.WebSocketApp(
|
199 |
+
ws_url,
|
200 |
+
on_open=on_open,
|
201 |
+
on_message=on_message,
|
202 |
+
on_error=on_error,
|
203 |
+
on_close=on_close
|
204 |
+
)
|
205 |
+
ws.run_forever(ping_interval=60, ping_timeout=30)
|
206 |
+
print("WebSocket connection lost, reconnecting in 5 seconds...")
|
207 |
+
time.sleep(5)
|
208 |
+
except Exception as e:
|
209 |
+
print(f"WebSocket connection error: {e}")
|
210 |
+
time.sleep(5)
|
211 |
+
|
212 |
+
# Start WebSocket client in a separate thread
|
213 |
+
websocket_thread = threading.Thread(target=run_forever_with_reconnect, daemon=True)
|
214 |
+
websocket_thread.start()
|
215 |
+
|
216 |
+
# Keep the program running
|
217 |
+
try:
|
218 |
+
while True:
|
219 |
+
time.sleep(1)
|
220 |
+
except KeyboardInterrupt:
|
221 |
+
print("Stopping WebSocket client...")
|
222 |
+
"""
|
223 |
+
}
|
224 |
+
|
225 |
+
@router.websocket("/notify")
|
226 |
+
async def websocket_endpoint(websocket: WebSocket):
|
227 |
+
"""
|
228 |
+
WebSocket endpoint to receive notifications about new sessions.
|
229 |
+
Admin Bot will connect to this endpoint to receive notifications when there are new sessions requiring attention.
|
230 |
+
"""
|
231 |
+
await manager.connect(websocket)
|
232 |
+
try:
|
233 |
+
while True:
|
234 |
+
# Maintain WebSocket connection
|
235 |
+
data = await websocket.receive_text()
|
236 |
+
# Echo back to keep connection active
|
237 |
+
await websocket.send_json({"status": "connected", "echo": data, "timestamp": datetime.now().isoformat()})
|
238 |
+
logger.info(f"Received message from WebSocket: {data}")
|
239 |
+
except WebSocketDisconnect:
|
240 |
+
logger.info("WebSocket client disconnected")
|
241 |
+
manager.disconnect(websocket)
|
242 |
+
except Exception as e:
|
243 |
+
logger.error(f"WebSocket error: {e}")
|
244 |
+
manager.disconnect(websocket)
|
245 |
+
|
246 |
+
# Function to send notifications over WebSocket
|
247 |
+
async def send_notification(data: dict):
|
248 |
+
"""
|
249 |
+
Send notification to all active WebSocket connections.
|
250 |
+
|
251 |
+
This function is used to notify admin bots about new issues or questions that need attention.
|
252 |
+
It's triggered when the system cannot answer a user's question (response starts with "I don't know").
|
253 |
+
|
254 |
+
Args:
|
255 |
+
data: The data to send as notification
|
256 |
+
"""
|
257 |
+
try:
|
258 |
+
# Log number of active connections and notification attempt
|
259 |
+
logger.info(f"Attempting to send notification. Active connections: {len(manager.active_connections)}")
|
260 |
+
logger.info(f"Notification data: session_id={data.get('session_id')}, user_id={data.get('user_id')}")
|
261 |
+
logger.info(f"Response: {data.get('response', '')[:50]}...")
|
262 |
+
|
263 |
+
# Check if the response starts with "I don't know"
|
264 |
+
response = data.get('response', '')
|
265 |
+
if not response or not isinstance(response, str):
|
266 |
+
logger.warning(f"Invalid response format in notification data: {response}")
|
267 |
+
return
|
268 |
+
|
269 |
+
if not response.strip().lower().startswith("i don't know"):
|
270 |
+
logger.info(f"Response doesn't start with 'I don't know', notification not needed: {response[:50]}...")
|
271 |
+
return
|
272 |
+
|
273 |
+
logger.info(f"Response starts with 'I don't know', sending notification")
|
274 |
+
|
275 |
+
# Format the notification data for admin
|
276 |
+
notification_data = {
|
277 |
+
"type": "new_session",
|
278 |
+
"timestamp": get_local_time(),
|
279 |
+
"data": {
|
280 |
+
"session_id": data.get('session_id', 'unknown'),
|
281 |
+
"user_id": data.get('user_id', 'unknown'),
|
282 |
+
"message": data.get('message', ''),
|
283 |
+
"response": response,
|
284 |
+
"first_name": data.get('first_name', 'User'),
|
285 |
+
"last_name": data.get('last_name', ''),
|
286 |
+
"username": data.get('username', ''),
|
287 |
+
"created_at": data.get('created_at', get_local_time()),
|
288 |
+
"action": data.get('action', 'unknown'),
|
289 |
+
"factor": "user" # Always show as user for better readability
|
290 |
+
}
|
291 |
+
}
|
292 |
+
|
293 |
+
# Check if there are active connections
|
294 |
+
if not manager.active_connections:
|
295 |
+
logger.warning("No active WebSocket connections for notification broadcast")
|
296 |
+
return
|
297 |
+
|
298 |
+
# Broadcast notification to all active connections
|
299 |
+
logger.info(f"Broadcasting notification to {len(manager.active_connections)} connections")
|
300 |
+
await manager.broadcast(notification_data)
|
301 |
+
logger.info("Notification broadcast completed successfully")
|
302 |
+
|
303 |
+
except Exception as e:
|
304 |
+
logger.error(f"Error sending notification: {e}")
|
305 |
+
import traceback
|
306 |
+
logger.error(traceback.format_exc())
|
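For context, a minimal sketch of a call site for send_notification; the wrapper function and payload values below are illustrative assumptions, not part of this commit:

# Hypothetical call site: forward a finished Q/A pair to the Admin Bot channel.
# send_notification() itself checks the "I don't know" prefix, so calling it
# unconditionally is safe.
from app.api.websocket_routes import send_notification

async def notify_admins_if_unanswered(session: dict, answer: str):
    await send_notification({
        "session_id": session.get("session_id"),
        "user_id": session.get("user_id"),
        "message": session.get("message"),
        "response": answer,
        "first_name": session.get("first_name", "User"),
        "action": session.get("action", "asking_freely"),
    })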
app/database/__init__.py
ADDED
@@ -0,0 +1 @@
# Database connections package
app/database/models.py
ADDED
@@ -0,0 +1,165 @@
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Float, Text, LargeBinary, JSON
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from .postgresql import Base
import datetime

class FAQItem(Base):
    __tablename__ = "faq_item"

    id = Column(Integer, primary_key=True, index=True)
    question = Column(String, nullable=False)
    answer = Column(String, nullable=False)
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

class EmergencyItem(Base):
    __tablename__ = "emergency_item"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    phone_number = Column(String, nullable=False)
    description = Column(String, nullable=True)
    address = Column(String, nullable=True)
    location = Column(String, nullable=True)  # Will be converted to/from PostGIS POINT type
    priority = Column(Integer, default=0)
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

class EventItem(Base):
    __tablename__ = "event_item"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    description = Column(Text, nullable=False)
    address = Column(String, nullable=False)
    location = Column(String, nullable=True)  # Will be converted to/from PostGIS POINT type
    date_start = Column(DateTime, nullable=False)
    date_end = Column(DateTime, nullable=True)
    price = Column(JSON, nullable=True)
    is_active = Column(Boolean, default=True)
    featured = Column(Boolean, default=False)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

class VectorDatabase(Base):
    __tablename__ = "vector_database"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False, unique=True)
    description = Column(String, nullable=True)
    pinecone_index = Column(String, nullable=False)
    api_key = Column(String, nullable=False)
    status = Column(String, default="active")
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

    # Relationships
    documents = relationship("Document", back_populates="vector_database")
    vector_statuses = relationship("VectorStatus", back_populates="vector_database")
    engine_associations = relationship("EngineVectorDb", back_populates="vector_database")

class Document(Base):
    __tablename__ = "document"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    file_content = Column(LargeBinary, nullable=True)
    file_type = Column(String, nullable=True)
    size = Column(Integer, nullable=True)
    content_type = Column(String, nullable=True)
    is_embedded = Column(Boolean, default=False)
    file_metadata = Column(JSON, nullable=True)
    vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

    # Relationships
    vector_database = relationship("VectorDatabase", back_populates="documents")
    vector_statuses = relationship("VectorStatus", back_populates="document")

class VectorStatus(Base):
    __tablename__ = "vector_status"

    id = Column(Integer, primary_key=True, index=True)
    document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
    vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
    vector_id = Column(String, nullable=True)
    status = Column(String, default="pending")
    error_message = Column(String, nullable=True)
    embedded_at = Column(DateTime, nullable=True)

    # Relationships
    document = relationship("Document", back_populates="vector_statuses")
    vector_database = relationship("VectorDatabase", back_populates="vector_statuses")

class TelegramBot(Base):
    __tablename__ = "telegram_bot"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    username = Column(String, nullable=False, unique=True)
    token = Column(String, nullable=False)
    status = Column(String, default="inactive")
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

    # Relationships
    bot_engines = relationship("BotEngine", back_populates="bot")

class ChatEngine(Base):
    __tablename__ = "chat_engine"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    answer_model = Column(String, nullable=False)
    system_prompt = Column(Text, nullable=True)
    empty_response = Column(String, nullable=True)
    similarity_top_k = Column(Integer, default=3)
    vector_distance_threshold = Column(Float, default=0.75)
    grounding_threshold = Column(Float, default=0.2)
    use_public_information = Column(Boolean, default=False)
    status = Column(String, default="active")
    created_at = Column(DateTime, server_default=func.now())
    last_modified = Column(DateTime, server_default=func.now(), onupdate=func.now())

    # Relationships
    bot_engines = relationship("BotEngine", back_populates="engine")
    engine_vector_dbs = relationship("EngineVectorDb", back_populates="engine")

class BotEngine(Base):
    __tablename__ = "bot_engine"

    id = Column(Integer, primary_key=True, index=True)
    bot_id = Column(Integer, ForeignKey("telegram_bot.id"), nullable=False)
    engine_id = Column(Integer, ForeignKey("chat_engine.id"), nullable=False)
    created_at = Column(DateTime, server_default=func.now())

    # Relationships
    bot = relationship("TelegramBot", back_populates="bot_engines")
    engine = relationship("ChatEngine", back_populates="bot_engines")

class EngineVectorDb(Base):
    __tablename__ = "engine_vector_db"

    id = Column(Integer, primary_key=True, index=True)
    engine_id = Column(Integer, ForeignKey("chat_engine.id"), nullable=False)
    vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
    priority = Column(Integer, default=0)

    # Relationships
    engine = relationship("ChatEngine", back_populates="engine_vector_dbs")
    vector_database = relationship("VectorDatabase", back_populates="engine_associations")

class ApiKey(Base):
    __tablename__ = "api_key"

    id = Column(Integer, primary_key=True, index=True)
    key = Column(String, nullable=False, unique=True)
    name = Column(String, nullable=False)
    description = Column(String, nullable=True)
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, server_default=func.now())
    last_used = Column(DateTime, nullable=True)
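As a quick orientation to how these models are wired together, here is a sketch of creating the schema and inserting a row; it leans on SessionLocal and create_tables from app/database/postgresql.py (added later in this commit), and the sample values are made up:

# Hypothetical usage sketch; sample data is made up.
from app.database.postgresql import SessionLocal, create_tables
from app.database.models import FAQItem

create_tables()  # CREATE TABLE for every Base subclass that doesn't exist yet

db = SessionLocal()
try:
    faq = FAQItem(question="Where is the help desk?", answer="Building A, floor 1.")
    db.add(faq)
    db.commit()
    db.refresh(faq)  # pulls server-side defaults such as created_at
finally:
    db.close()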
app/database/mongodb.py
ADDED
@@ -0,0 +1,200 @@
import os
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
from dotenv import load_dotenv
from datetime import datetime, timedelta
import pytz
import logging

# Configure logging
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# MongoDB connection string from .env
MONGODB_URL = os.getenv("MONGODB_URL")
DB_NAME = os.getenv("DB_NAME", "Telegram")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "session_chat")

# Set timeout for MongoDB connection
MONGODB_TIMEOUT = int(os.getenv("MONGODB_TIMEOUT", "5000"))  # 5 seconds by default

# Create MongoDB connection with timeout
try:
    client = MongoClient(MONGODB_URL, serverSelectionTimeoutMS=MONGODB_TIMEOUT)
    db = client[DB_NAME]

    # Collections
    session_collection = db[COLLECTION_NAME]
    logger.info(f"MongoDB connection initialized to {DB_NAME}.{COLLECTION_NAME}")

except Exception as e:
    logger.error(f"Failed to initialize MongoDB connection: {e}")
    # Don't raise an exception to avoid a crash during startup; errors are handled in the functions below

# Check MongoDB connection
def check_db_connection():
    """Check MongoDB connection"""
    try:
        # Issue a ping to confirm a successful connection
        client.admin.command('ping')
        logger.info("MongoDB connection is working")
        return True
    except (ConnectionFailure, ServerSelectionTimeoutError) as e:
        logger.error(f"MongoDB connection failed: {e}")
        return False
    except Exception as e:
        logger.error(f"Unknown error when checking MongoDB connection: {e}")
        return False

# Timezone for Asia/Ho_Chi_Minh
asia_tz = pytz.timezone('Asia/Ho_Chi_Minh')

def get_local_time():
    """Get current time in Asia/Ho_Chi_Minh timezone"""
    return datetime.now(asia_tz).strftime("%Y-%m-%d %H:%M:%S")

def get_local_datetime():
    """Get current datetime object in Asia/Ho_Chi_Minh timezone"""
    return datetime.now(asia_tz)

# For backward compatibility
get_vietnam_time = get_local_time
get_vietnam_datetime = get_local_datetime

# Utility functions
def save_session(session_id, factor, action, first_name, last_name, message, user_id, username, response=None):
    """Save user session to MongoDB"""
    try:
        session_data = {
            "session_id": session_id,
            "factor": factor,
            "action": action,
            "created_at": get_local_time(),
            "created_at_datetime": get_local_datetime(),
            "first_name": first_name,
            "last_name": last_name,
            "message": message,
            "user_id": user_id,
            "username": username,
            "response": response
        }
        result = session_collection.insert_one(session_data)
        logger.info(f"Session saved with ID: {result.inserted_id}")
        return {
            "acknowledged": result.acknowledged,
            "inserted_id": str(result.inserted_id),
            "session_data": session_data
        }
    except Exception as e:
        logger.error(f"Error saving session: {e}")
        raise

def update_session_response(session_id, response):
    """Update a session with a response"""
    try:
        result = session_collection.update_one(
            {"session_id": session_id},
            {"$set": {"response": response}}
        )

        if result.matched_count == 0:
            logger.warning(f"No session found with ID: {session_id}")
            return False

        logger.info(f"Session {session_id} updated with response")
        return True
    except Exception as e:
        logger.error(f"Error updating session response: {e}")
        raise

def get_recent_sessions(user_id, action, n=3):
    """Get the n most recent sessions for a specific user and action"""
    try:
        return list(
            session_collection.find(
                {"user_id": user_id, "action": action},
                {"_id": 0, "message": 1, "response": 1}
            ).sort("created_at_datetime", -1).limit(n)
        )
    except Exception as e:
        logger.error(f"Error getting recent sessions: {e}")
        return []

def get_user_history(user_id, n=3):
    """Get user history for a specific user"""
    try:
        # Find all messages of this user
        user_messages = list(
            session_collection.find(
                {
                    "user_id": user_id,
                    "message": {"$exists": True, "$ne": None},
                    # Include all user messages regardless of action type
                }
            ).sort("created_at_datetime", -1).limit(n * 2)  # Get more to ensure we have enough pairs
        )

        # Group messages by session_id to find pairs
        session_dict = {}
        for msg in user_messages:
            session_id = msg.get("session_id")
            if session_id not in session_dict:
                session_dict[session_id] = {}

            if msg.get("factor", "").lower() == "user":
                session_dict[session_id]["question"] = msg.get("message", "")
                session_dict[session_id]["timestamp"] = msg.get("created_at_datetime")
            elif msg.get("factor", "").lower() == "rag":
                session_dict[session_id]["answer"] = msg.get("response", "")

        # Build history from complete pairs only (with both question and answer)
        history = []
        for session_id, data in session_dict.items():
            if "question" in data and "answer" in data and data.get("answer"):
                history.append({
                    "question": data["question"],
                    "answer": data["answer"],
                    # Keep the timestamp so the list can be sorted; stripped before returning
                    "timestamp": data.get("timestamp") or get_local_datetime()
                })

        # Sort by timestamp (newest first) and limit to n
        history = sorted(history, key=lambda x: x["timestamp"], reverse=True)[:n]
        for item in history:
            item.pop("timestamp", None)

        logger.info(f"Retrieved {len(history)} history items for user {user_id}")
        return history
    except Exception as e:
        logger.error(f"Error getting user history: {e}")
        return []

# Functions from chatbot.py
def get_chat_history(user_id, n=5):
    """Get conversation history for a specific user from MongoDB in a format suitable for an LLM prompt"""
    try:
        history = get_user_history(user_id, n)

        # Format history for prompt context
        formatted_history = ""
        for item in history:
            formatted_history += f"User: {item['question']}\nAssistant: {item['answer']}\n\n"

        return formatted_history
    except Exception as e:
        logger.error(f"Error getting chat history for prompt: {e}")
        return ""

def get_request_history(user_id, n=3):
    """Get the most recent user requests to use as context for retrieval"""
    try:
        history = get_user_history(user_id, n)

        # Just extract the questions for context
        requests = []
        for item in history:
            requests.append(item['question'])

        # Join all recent requests into a single string for context
        return " ".join(requests)
    except Exception as e:
        logger.error(f"Error getting request history: {e}")
        return ""
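A sketch of how the session helpers compose into a write-then-answer flow; the IDs and text are placeholders:

# Hypothetical usage sketch; IDs and text are placeholders.
from app.database.mongodb import save_session, update_session_response, get_chat_history

save_session(
    session_id="demo-session-1",
    factor="user",
    action="asking_freely",
    first_name="Jane",
    last_name="Doe",
    message="What events are on this weekend?",
    user_id="12345",
    username="janedoe",
)
update_session_response("demo-session-1", "There is a street fair on Saturday.")

# Later, format the last few Q/A pairs as LLM prompt context.
context = get_chat_history(user_id="12345", n=5)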
app/database/pinecone.py
ADDED
@@ -0,0 +1,581 @@
import os
from pinecone import Pinecone
from dotenv import load_dotenv
import logging
from typing import Optional, List, Dict, Any, Union, Tuple
import time
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import google.generativeai as genai
from app.utils.utils import cache
from langchain_core.retrievers import BaseRetriever
from langchain.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field

# Configure logging
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Pinecone API key and index name
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Pinecone retrieval configuration
DEFAULT_LIMIT_K = int(os.getenv("PINECONE_DEFAULT_LIMIT_K", "10"))
DEFAULT_TOP_K = int(os.getenv("PINECONE_DEFAULT_TOP_K", "6"))
DEFAULT_SIMILARITY_METRIC = os.getenv("PINECONE_DEFAULT_SIMILARITY_METRIC", "cosine")
DEFAULT_SIMILARITY_THRESHOLD = float(os.getenv("PINECONE_DEFAULT_SIMILARITY_THRESHOLD", "0.75"))
ALLOWED_METRICS = os.getenv("PINECONE_ALLOWED_METRICS", "cosine,dotproduct,euclidean").split(",")

# Export constants for importing elsewhere
__all__ = [
    'get_pinecone_index',
    'check_db_connection',
    'search_vectors',
    'upsert_vectors',
    'delete_vectors',
    'fetch_metadata',
    'get_chain',
    'DEFAULT_TOP_K',
    'DEFAULT_LIMIT_K',
    'DEFAULT_SIMILARITY_METRIC',
    'DEFAULT_SIMILARITY_THRESHOLD',
    'ALLOWED_METRICS',
    'ThresholdRetriever'
]

# Configure Google API
if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)

# Initialize global variables to store instances of Pinecone and the index
pc = None
index = None
_retriever_instance = None

# Check environment variables
if not PINECONE_API_KEY:
    logger.error("PINECONE_API_KEY is not set in environment variables")

if not PINECONE_INDEX_NAME:
    logger.error("PINECONE_INDEX_NAME is not set in environment variables")

# Initialize Pinecone
def init_pinecone():
    """Initialize Pinecone connection using the new API"""
    global pc, index

    try:
        # Only initialize if not already initialized
        if pc is None:
            logger.info(f"Initializing Pinecone connection to index {PINECONE_INDEX_NAME}...")

            # Initialize Pinecone client using the new API
            pc = Pinecone(api_key=PINECONE_API_KEY)

            # Check if the index exists
            index_list = pc.list_indexes()

            if not hasattr(index_list, 'names') or PINECONE_INDEX_NAME not in index_list.names():
                logger.error(f"Index {PINECONE_INDEX_NAME} does not exist in Pinecone")
                return None

            # Get existing index
            index = pc.Index(PINECONE_INDEX_NAME)
            logger.info(f"Pinecone connection established to index {PINECONE_INDEX_NAME}")

        return index
    except Exception as e:
        logger.error(f"Error initializing Pinecone: {e}")
        return None

# Get Pinecone index singleton
def get_pinecone_index():
    """Get Pinecone index"""
    global index
    if index is None:
        index = init_pinecone()
    return index

# Check Pinecone connection
def check_db_connection():
    """Check Pinecone connection"""
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            return False

        # Check index information to confirm the connection is working
        stats = pinecone_index.describe_index_stats()

        # Get total vector count from the new result structure
        total_vectors = stats.get('total_vector_count', 0)
        if hasattr(stats, 'namespaces'):
            # If there are namespaces, calculate total vector count from namespaces
            total_vectors = sum(ns.get('vector_count', 0) for ns in stats.namespaces.values())

        logger.info(f"Pinecone connection is working. Total vectors: {total_vectors}")
        return True
    except Exception as e:
        logger.error(f"Error in Pinecone connection: {e}")
        return False

# Convert similarity score based on the metric
def convert_score(score: float, metric: str) -> float:
    """
    Convert similarity score to a 0-1 scale based on the metric used.
    For metrics like euclidean distance where lower is better, we invert the score.

    Args:
        score: The raw similarity score
        metric: The similarity metric used

    Returns:
        A normalized score between 0-1 where higher means more similar
    """
    if metric.lower() in ["euclidean", "l2"]:
        # For distance metrics (lower is better), we invert and normalize,
        # assuming a max reasonable distance of 2.0 for normalized vectors
        return max(0, 1 - (score / 2.0))
    else:
        # For cosine and dot product (higher is better), return as is
        return score

# Filter results based on similarity threshold
def filter_by_threshold(results, threshold: float, metric: str) -> List[Dict]:
    """
    Filter query results based on similarity threshold.

    Args:
        results: The query results from Pinecone
        threshold: The similarity threshold (0-1)
        metric: The similarity metric used

    Returns:
        Filtered list of matches
    """
    filtered_matches = []

    if not hasattr(results, 'matches'):
        return filtered_matches

    for match in results.matches:
        # Get the score
        score = getattr(match, 'score', 0)

        # Convert score based on metric
        normalized_score = convert_score(score, metric)

        # Filter based on threshold
        if normalized_score >= threshold:
            # Add normalized score as an additional attribute
            match.normalized_score = normalized_score
            filtered_matches.append(match)

    return filtered_matches

# Search vectors in Pinecone with advanced options
async def search_vectors(
    query_vector,
    top_k: int = DEFAULT_TOP_K,
    limit_k: int = DEFAULT_LIMIT_K,
    similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
    similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
    namespace: str = "",
    filter: Optional[Dict] = None
) -> Dict:
    """
    Search for the most similar vectors in Pinecone with advanced filtering options.

    Args:
        query_vector: The query vector
        top_k: Number of results to return (after threshold filtering)
        limit_k: Maximum number of results to retrieve from Pinecone
        similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
        similarity_threshold: Threshold for similarity (0-1)
        namespace: Namespace to search in
        filter: Filter query

    Returns:
        Search results with matches filtered by threshold
    """
    try:
        # Validate parameters
        if similarity_metric not in ALLOWED_METRICS:
            logger.warning(f"Invalid similarity metric: {similarity_metric}. Using default: {DEFAULT_SIMILARITY_METRIC}")
            similarity_metric = DEFAULT_SIMILARITY_METRIC

        if limit_k < top_k:
            logger.warning(f"limit_k ({limit_k}) must be greater than or equal to top_k ({top_k}). Setting limit_k to {top_k}")
            limit_k = top_k

        # Create cache key from parameters
        vector_hash = hash(str(query_vector))
        cache_key = f"pinecone_search:{vector_hash}:{limit_k}:{similarity_metric}:{similarity_threshold}:{namespace}:{filter}"

        # Check cache first
        cached_result = cache.get(cache_key)
        if cached_result is not None:
            logger.info("Returning cached Pinecone search results")
            return cached_result

        # If not in cache, perform the search
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for search")
            return None

        # Query Pinecone with the provided metric and a higher limit_k to allow for threshold filtering
        results = pinecone_index.query(
            vector=query_vector,
            top_k=limit_k,  # Retrieve more results than needed to allow for threshold filtering
            namespace=namespace,
            filter=filter,
            include_metadata=True,
            include_values=False,  # No need to return vector values; saves bandwidth
            metric=similarity_metric  # Specify similarity metric
        )

        # Filter results by threshold
        filtered_matches = filter_by_threshold(results, similarity_threshold, similarity_metric)

        # Limit to top_k after filtering
        filtered_matches = filtered_matches[:top_k]

        # Create a new results object with filtered matches
        results.matches = filtered_matches

        # Log search result metrics
        match_count = len(filtered_matches)
        logger.info(f"Pinecone search returned {match_count} matches after threshold filtering (metric: {similarity_metric}, threshold: {similarity_threshold})")

        # Store result in cache with 5 minute TTL
        cache.set(cache_key, results, ttl=300)

        return results
    except Exception as e:
        logger.error(f"Error searching vectors: {e}")
        return None

# Upsert vectors to Pinecone
async def upsert_vectors(vectors, namespace=""):
    """Upsert vectors to Pinecone index"""
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for upsert")
            return None

        response = pinecone_index.upsert(
            vectors=vectors,
            namespace=namespace
        )

        # Log upsert metrics
        upserted_count = response.get('upserted_count', 0)
        logger.info(f"Upserted {upserted_count} vectors to Pinecone")

        return response
    except Exception as e:
        logger.error(f"Error upserting vectors: {e}")
        return None

# Delete vectors from Pinecone
async def delete_vectors(ids, namespace=""):
    """Delete vectors from Pinecone index"""
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for delete")
            return False

        response = pinecone_index.delete(
            ids=ids,
            namespace=namespace
        )

        logger.info(f"Deleted vectors with IDs {ids} from Pinecone")
        return True
    except Exception as e:
        logger.error(f"Error deleting vectors: {e}")
        return False

# Fetch vector metadata from Pinecone
async def fetch_metadata(ids, namespace=""):
    """Fetch metadata for specific vector IDs"""
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for fetch")
            return None

        response = pinecone_index.fetch(
            ids=ids,
            namespace=namespace
        )

        return response
    except Exception as e:
        logger.error(f"Error fetching vector metadata: {e}")
        return None

# Create a custom retriever class for Langchain integration
class ThresholdRetriever(BaseRetriever):
    """
    Custom retriever that supports threshold-based filtering and multiple similarity metrics.
    This integrates with the Langchain ecosystem while using our advanced retrieval logic.
    """

    vectorstore: Any = Field(description="Vector store to use for retrieval")
    embeddings: Any = Field(description="Embeddings model to use for retrieval")
    search_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Search kwargs for the vectorstore")
    top_k: int = Field(default=DEFAULT_TOP_K, description="Number of results to return after filtering")
    limit_k: int = Field(default=DEFAULT_LIMIT_K, description="Maximum number of results to retrieve from Pinecone")
    similarity_metric: str = Field(default=DEFAULT_SIMILARITY_METRIC, description="Similarity metric to use")
    similarity_threshold: float = Field(default=DEFAULT_SIMILARITY_THRESHOLD, description="Threshold for similarity")

    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True

    async def search_vectors_sync(
        self, query_vector,
        top_k: int = DEFAULT_TOP_K,
        limit_k: int = DEFAULT_LIMIT_K,
        similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        namespace: str = "",
        filter: Optional[Dict] = None
    ) -> Dict:
        """Wrapper for search_vectors that handles event-loop setup (note: a coroutine despite the name)"""
        import asyncio
        try:
            # Get the current event loop or create a new one
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # Use the event loop to run the async function
            if loop.is_running():
                # If we're in an event loop, use asyncio.create_task
                task = asyncio.create_task(search_vectors(
                    query_vector=query_vector,
                    top_k=top_k,
                    limit_k=limit_k,
                    similarity_metric=similarity_metric,
                    similarity_threshold=similarity_threshold,
                    namespace=namespace,
                    filter=filter
                ))
                return await task
            else:
                # If not in an event loop, just await directly
                return await search_vectors(
                    query_vector=query_vector,
                    top_k=top_k,
                    limit_k=limit_k,
                    similarity_metric=similarity_metric,
                    similarity_threshold=similarity_threshold,
                    namespace=namespace,
                    filter=filter
                )
        except Exception as e:
            logger.error(f"Error in search_vectors_sync: {e}")
            return None

    def _get_relevant_documents(
        self, query: str, *, run_manager: Callbacks = None
    ) -> List[Document]:
        """
        Get documents relevant to the query using threshold-based retrieval.

        Args:
            query: The query string
            run_manager: The callbacks manager

        Returns:
            List of relevant documents
        """
        # Generate embedding for the query using the embeddings model
        try:
            # Use the embeddings model we stored in the class
            embedding = self.embeddings.embed_query(query)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            # Fall back to creating a new embedding model if needed
            embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            embedding = embedding_model.embed_query(query)

        # Perform search with advanced options - avoid asyncio.run()
        import asyncio

        # Get or create event loop
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Run the asynchronous search in a safe way
        if loop.is_running():
            # We're inside an existing event loop (like in FastAPI);
            # use a different approach - convert it to a synchronous call
            from concurrent.futures import ThreadPoolExecutor
            import functools

            # Define a wrapper function to run in a thread
            def run_async_in_thread():
                # Create a new event loop for this thread
                thread_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(thread_loop)
                # Run the coroutine and return the result
                return thread_loop.run_until_complete(search_vectors(
                    query_vector=embedding,
                    top_k=self.top_k,
                    limit_k=self.limit_k,
                    similarity_metric=self.similarity_metric,
                    similarity_threshold=self.similarity_threshold,
                    namespace=getattr(self.vectorstore, "namespace", ""),
                    filter=self.search_kwargs.get("filter", None)
                ))

            # Run the async function in a thread
            with ThreadPoolExecutor() as executor:
                search_result = executor.submit(run_async_in_thread).result()
        else:
            # No event loop running, we can use run_until_complete
            search_result = loop.run_until_complete(search_vectors(
                query_vector=embedding,
                top_k=self.top_k,
                limit_k=self.limit_k,
                similarity_metric=self.similarity_metric,
                similarity_threshold=self.similarity_threshold,
                namespace=getattr(self.vectorstore, "namespace", ""),
                filter=self.search_kwargs.get("filter", None)
            ))

        # Convert to documents
        documents = []
        if search_result and hasattr(search_result, 'matches'):
            for match in search_result.matches:
                # Extract metadata
                metadata = {}
                if hasattr(match, 'metadata'):
                    metadata = match.metadata

                # Add score to metadata
                score = getattr(match, 'score', 0)
                normalized_score = getattr(match, 'normalized_score', score)
                metadata['score'] = score
                metadata['normalized_score'] = normalized_score

                # Extract text
                text = metadata.get('text', '')
                if 'text' in metadata:
                    del metadata['text']  # Remove from metadata since it's the content

                # Create Document
                doc = Document(
                    page_content=text,
                    metadata=metadata
                )
                documents.append(doc)

        return documents

# Get the retrieval chain with Pinecone vector store
def get_chain(
    index_name=PINECONE_INDEX_NAME,
    namespace="Default",
    top_k=DEFAULT_TOP_K,
    limit_k=DEFAULT_LIMIT_K,
    similarity_metric=DEFAULT_SIMILARITY_METRIC,
    similarity_threshold=DEFAULT_SIMILARITY_THRESHOLD
):
    """
    Get the retrieval chain with Pinecone vector store using threshold-based retrieval.

    Args:
        index_name: Pinecone index name
        namespace: Pinecone namespace
        top_k: Number of results to return after filtering
        limit_k: Maximum number of results to retrieve from Pinecone
        similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
        similarity_threshold: Threshold for similarity (0-1)

    Returns:
        ThresholdRetriever instance
    """
    global _retriever_instance
    try:
        # If already initialized, return the cached instance
        if _retriever_instance is not None:
            return _retriever_instance

        # Check if the chain has been cached
        cache_key = f"pinecone_retriever:{index_name}:{namespace}:{top_k}:{limit_k}:{similarity_metric}:{similarity_threshold}"
        cached_retriever = cache.get(cache_key)
        if cached_retriever is not None:
            _retriever_instance = cached_retriever
            logger.info("Retrieved cached Pinecone retriever")
            return _retriever_instance

        start_time = time.time()
        logger.info("Initializing new retriever chain with threshold-based filtering")

        # Initialize embeddings model
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

        # Get index
        pinecone_index = get_pinecone_index()
        if not pinecone_index:
            logger.error("Failed to get Pinecone index for retriever chain")
            return None

        # Get statistics for logging
        try:
            stats = pinecone_index.describe_index_stats()
            total_vectors = stats.get('total_vector_count', 0)
            logger.info(f"Pinecone index stats - Total vectors: {total_vectors}")
        except Exception as e:
            logger.error(f"Error getting index stats: {e}")

        # Use Pinecone from langchain_community.vectorstores
        from langchain_community.vectorstores import Pinecone as LangchainPinecone

        logger.info(f"Creating Pinecone vectorstore with index: {index_name}, namespace: {namespace}")
        vectorstore = LangchainPinecone.from_existing_index(
            embedding=embeddings,
            index_name=index_name,
            namespace=namespace,
            text_key="text"
        )

        # Create threshold-based retriever
        logger.info(f"Creating ThresholdRetriever with top_k={top_k}, limit_k={limit_k}, " +
                    f"metric={similarity_metric}, threshold={similarity_threshold}")

        # Create ThresholdRetriever with both vectorstore and embeddings
        _retriever_instance = ThresholdRetriever(
            vectorstore=vectorstore,
            embeddings=embeddings,  # Pass embeddings separately
            top_k=top_k,
            limit_k=limit_k,
            similarity_metric=similarity_metric,
            similarity_threshold=similarity_threshold
        )

        logger.info(f"Pinecone retriever initialized in {time.time() - start_time:.2f} seconds")

        # Cache the retriever with a longer TTL (1 hour) since it rarely changes
        cache.set(cache_key, _retriever_instance, ttl=3600)

        return _retriever_instance
    except Exception as e:
        logger.error(f"Error creating retrieval chain: {e}")
        return None
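Because ThresholdRetriever subclasses Langchain's BaseRetriever, it drops into ordinary retrieval code; a sketch, assuming the PINECONE_* and GOOGLE_API_KEY variables above are set and the index already holds embedded documents (the query string is made up):

# Hypothetical usage sketch; assumes the environment above is configured.
from app.database.pinecone import get_chain

retriever = get_chain(top_k=4, similarity_threshold=0.8)
if retriever is not None:
    docs = retriever.get_relevant_documents("emergency phone numbers")
    for doc in docs:
        # normalized_score was attached during threshold filtering
        print(doc.metadata.get("normalized_score"), doc.page_content[:80])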
app/database/postgresql.py
ADDED
@@ -0,0 +1,104 @@
import os
from sqlalchemy import create_engine, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError, OperationalError
from dotenv import load_dotenv
import logging

# Configure logging
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Get DB connection mode from environment
DB_CONNECTION_MODE = os.getenv("DB_CONNECTION_MODE", "aiven")

# Set connection string based on mode
if DB_CONNECTION_MODE == "aiven":
    DATABASE_URL = os.getenv("AIVEN_DB_URL")
else:
    # Default or other connection modes can be added here
    DATABASE_URL = os.getenv("AIVEN_DB_URL")

if not DATABASE_URL:
    logger.error("No database URL configured. Please set AIVEN_DB_URL environment variable.")
    DATABASE_URL = "postgresql://localhost/test"  # Fallback so startup doesn't crash

# Create SQLAlchemy engine
try:
    engine = create_engine(
        DATABASE_URL,
        pool_pre_ping=True,
        pool_recycle=300,   # Recycle connections every 5 minutes
        pool_size=10,       # Increase pool size from 5 to 10
        max_overflow=20,    # Increase maximum overflow connections from 10 to 20
        connect_args={
            "connect_timeout": 3,        # Reduce timeout from 5 to 3 seconds
            "keepalives": 1,             # Enable keepalive
            "keepalives_idle": 30,       # Idle time before sending a keepalive
            "keepalives_interval": 10,   # Interval between keepalive packets
            "keepalives_count": 5        # Number of retries before closing the connection
        },
        # Additional performance options
        isolation_level="READ COMMITTED",  # Use the lower READ COMMITTED isolation level
        echo=False,       # Disable SQL echo to reduce logging overhead
        echo_pool=False   # Disable pool echo to reduce logging overhead
    )
    logger.info("PostgreSQL engine initialized")
except Exception as e:
    logger.error(f"Failed to initialize PostgreSQL engine: {e}")
    # Don't raise an exception to avoid crashing at startup; errors are handled in the individual functions

# Create session factory with optimized settings
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
    expire_on_commit=False  # Avoid re-querying the DB after commit
)

# Base class for declarative models - use sqlalchemy.orm for SQLAlchemy 2.0 compatibility
from sqlalchemy.orm import declarative_base
Base = declarative_base()

# Check PostgreSQL connection
def check_db_connection():
    """Check PostgreSQL connection"""
    try:
        # Run a simple query to verify the connection
        with engine.connect() as connection:
            connection.execute(text("SELECT 1"))
        logger.info("PostgreSQL connection is working")
        return True
    except OperationalError as e:
        logger.error(f"PostgreSQL connection failed: {e}")
        return False
    except Exception as e:
        logger.error(f"Unknown error when checking PostgreSQL connection: {e}")
        return False

# Dependency to get DB session
def get_db():
    """Get database session dependency for FastAPI endpoints"""
    db = SessionLocal()
    try:
        yield db
    except SQLAlchemyError as e:
        logger.error(f"Database session error: {e}")
        db.rollback()
        raise
    finally:
        db.close()

# Create database tables if they don't already exist
def create_tables():
    """Create the database tables"""
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created or already exist")
        return True
    except SQLAlchemyError as e:
        logger.error(f"Failed to create database tables: {e}")
        return False
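get_db() follows the standard FastAPI yield-dependency pattern: a session is opened per request and the finally block guarantees it is closed. A sketch of an endpoint using it; the route path and query are illustrative only:

# Hypothetical endpoint; the route path is illustrative, not part of this commit.
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.database.postgresql import get_db
from app.database.models import FAQItem

router = APIRouter()

@router.get("/faq")
def list_faq(db: Session = Depends(get_db)):
    # The session opens for this request and is closed by get_db's finally block.
    return db.query(FAQItem).filter(FAQItem.is_active == True).all()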
app/models/__init__.py
ADDED
@@ -0,0 +1 @@
# Pydantic models package
app/models/mongodb_models.py
ADDED
@@ -0,0 +1,55 @@
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
import uuid

class SessionBase(BaseModel):
    """Base model for session data"""
    session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    factor: str
    action: str
    first_name: str
    last_name: Optional[str] = None
    message: Optional[str] = None
    user_id: str
    username: Optional[str] = None

class SessionCreate(SessionBase):
    """Model for creating new session"""
    response: Optional[str] = None

class SessionResponse(SessionBase):
    """Response model for session data"""
    created_at: str
    response: Optional[str] = None

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "session_id": "123e4567-e89b-12d3-a456-426614174000",
                "factor": "user",
                "action": "asking_freely",
                "created_at": "2023-06-01 14:30:45",
                "first_name": "John",
                "last_name": "Doe",
                "message": "How can I find emergency contacts?",
                "user_id": "12345678",
                "username": "johndoe",
                "response": "You can find emergency contacts in the Emergency section..."
            }
        }
    )

class HistoryRequest(BaseModel):
    """Request model for history"""
    user_id: str
    n: int = 3

class QuestionAnswer(BaseModel):
    """Model for question-answer pair"""
    question: str
    answer: str

class HistoryResponse(BaseModel):
    """Response model for history"""
    history: List[QuestionAnswer]
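Usage sketch (not part of the commit): constructing and serializing one of these models; the field values below are illustrative only.

    # Illustrative values only; session_id is auto-generated via uuid4
    from app.models.mongodb_models import SessionCreate

    record = SessionCreate(
        factor="user",
        action="asking_freely",
        first_name="John",
        user_id="12345678",
        message="How can I find emergency contacts?",
    )
    print(record.model_dump_json())  # Pydantic v2 serialization, per requirements.txt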
app/models/rag_models.py
ADDED
@@ -0,0 +1,68 @@
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any

class ChatRequest(BaseModel):
    """Request model for chat endpoint"""
    user_id: str = Field(..., description="User ID from Telegram")
    question: str = Field(..., description="User's question")
    include_history: bool = Field(True, description="Whether to include user history in prompt")
    use_rag: bool = Field(True, description="Whether to use RAG")

    # Advanced retrieval parameters
    similarity_top_k: int = Field(6, description="Number of top similar documents to return (after filtering)")
    limit_k: int = Field(10, description="Maximum number of documents to retrieve from vector store")
    similarity_metric: str = Field("cosine", description="Similarity metric to use (cosine, dotproduct, euclidean)")
    similarity_threshold: float = Field(0.75, description="Threshold for vector similarity (0-1)")

    # User information
    session_id: Optional[str] = Field(None, description="Session ID for tracking conversations")
    first_name: Optional[str] = Field(None, description="User's first name")
    last_name: Optional[str] = Field(None, description="User's last name")
    username: Optional[str] = Field(None, description="User's username")

class SourceDocument(BaseModel):
    """Model for source documents"""
    text: str = Field(..., description="Text content of the document")
    source: Optional[str] = Field(None, description="Source of the document")
    score: Optional[float] = Field(None, description="Raw similarity score of the document")
    normalized_score: Optional[float] = Field(None, description="Normalized similarity score (0-1)")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Metadata of the document")

class ChatResponse(BaseModel):
    """Response model for chat endpoint"""
    answer: str = Field(..., description="Generated answer")
    processing_time: float = Field(..., description="Processing time in seconds")

class ChatResponseInternal(BaseModel):
    """Internal model for chat response with sources - used only for logging"""
    answer: str
    sources: Optional[List[SourceDocument]] = Field(None, description="Source documents used for generating answer")
    processing_time: Optional[float] = None

class EmbeddingRequest(BaseModel):
    """Request model for embedding endpoint"""
    text: str = Field(..., description="Text to generate embedding for")

class EmbeddingResponse(BaseModel):
    """Response model for embedding endpoint"""
    embedding: List[float] = Field(..., description="Generated embedding")
    text: str = Field(..., description="Text that was embedded")
    model: str = Field(..., description="Model used for embedding")

class HealthResponse(BaseModel):
    """Response model for health endpoint"""
    status: str
    services: Dict[str, bool]
    timestamp: str

class UserMessageModel(BaseModel):
    """Model for user messages sent to the RAG API"""
    user_id: str = Field(..., description="User ID from the client application")
    session_id: str = Field(..., description="Session ID for tracking the conversation")
    message: str = Field(..., description="User's message/question")

    # Advanced retrieval parameters (optional)
    similarity_top_k: Optional[int] = Field(None, description="Number of top similar documents to return (after filtering)")
    limit_k: Optional[int] = Field(None, description="Maximum number of documents to retrieve from vector store")
    similarity_metric: Optional[str] = Field(None, description="Similarity metric to use (cosine, dotproduct, euclidean)")
    similarity_threshold: Optional[float] = Field(None, description="Threshold for vector similarity (0-1)")
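Usage sketch (not part of the commit): building a ChatRequest and posting it with httpx. The /rag/chat path is an assumption for illustration; the actual route is defined in app/api/rag_routes.py.

    # The /rag/chat URL below is assumed, not confirmed by this commit
    import httpx
    from app.models.rag_models import ChatRequest

    payload = ChatRequest(
        user_id="12345678",
        question="How can I find emergency contacts?",
        similarity_top_k=4,        # override the default of 6
        similarity_threshold=0.8,  # stricter than the 0.75 default
    )
    response = httpx.post("http://localhost:7860/rag/chat", json=payload.model_dump())
    print(response.json())  # expected to match the ChatResponse schema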
app/utils/__init__.py
ADDED
@@ -0,0 +1 @@
# Utility functions package
app/utils/debug_utils.py
ADDED
@@ -0,0 +1,207 @@
import os
import sys
import logging
import traceback
import json
import time
from datetime import datetime
import platform

# Try to import psutil, provide fallback if not available
try:
    import psutil
    PSUTIL_AVAILABLE = True
except ImportError:
    PSUTIL_AVAILABLE = False
    logging.warning("psutil module not available. System monitoring features will be limited.")

# Configure logging
logger = logging.getLogger(__name__)

class DebugInfo:
    """Class containing debug information"""

    @staticmethod
    def get_system_info():
        """Get system information"""
        try:
            info = {
                "os": platform.system(),
                "os_version": platform.version(),
                "python_version": platform.python_version(),
                "cpu_count": os.cpu_count(),
                "timestamp": datetime.now().isoformat()
            }

            # Add information from psutil if available
            if PSUTIL_AVAILABLE:
                info.update({
                    "total_memory": round(psutil.virtual_memory().total / (1024 * 1024 * 1024), 2),  # GB
                    "available_memory": round(psutil.virtual_memory().available / (1024 * 1024 * 1024), 2),  # GB
                    "cpu_usage": psutil.cpu_percent(interval=0.1),
                    "memory_usage": psutil.virtual_memory().percent,
                    "disk_usage": psutil.disk_usage('/').percent,
                })
            else:
                info.update({
                    "total_memory": "psutil not available",
                    "available_memory": "psutil not available",
                    "cpu_usage": "psutil not available",
                    "memory_usage": "psutil not available",
                    "disk_usage": "psutil not available",
                })

            return info
        except Exception as e:
            logger.error(f"Error getting system info: {e}")
            return {"error": str(e)}

    @staticmethod
    def get_env_info():
        """Get environment variable information (masking sensitive information)"""
        try:
            # List of environment variables to mask values
            sensitive_vars = [
                "API_KEY", "SECRET", "PASSWORD", "TOKEN", "AUTH", "MONGODB_URL",
                "AIVEN_DB_URL", "PINECONE_API_KEY", "GOOGLE_API_KEY"
            ]

            env_vars = {}
            for key, value in os.environ.items():
                # Check if environment variable contains sensitive words
                is_sensitive = any(s in key.upper() for s in sensitive_vars)

                if is_sensitive and value:
                    # Mask value displaying only the first 4 characters
                    masked_value = value[:4] + "****" if len(value) > 4 else "****"
                    env_vars[key] = masked_value
                else:
                    env_vars[key] = value

            return env_vars
        except Exception as e:
            logger.error(f"Error getting environment info: {e}")
            return {"error": str(e)}

    @staticmethod
    def get_database_status():
        """Get database connection status"""
        try:
            from app.database.postgresql import check_db_connection as check_postgresql
            from app.database.mongodb import check_db_connection as check_mongodb
            from app.database.pinecone import check_db_connection as check_pinecone

            return {
                "postgresql": check_postgresql(),
                "mongodb": check_mongodb(),
                "pinecone": check_pinecone(),
                "timestamp": datetime.now().isoformat()
            }
        except Exception as e:
            logger.error(f"Error getting database status: {e}")
            return {"error": str(e)}

class PerformanceMonitor:
    """Performance monitoring class"""

    def __init__(self):
        self.start_time = time.time()
        self.checkpoints = []

    def checkpoint(self, name):
        """Mark a checkpoint and record the time"""
        current_time = time.time()
        elapsed = current_time - self.start_time
        self.checkpoints.append({
            "name": name,
            "time": current_time,
            "elapsed": elapsed
        })
        logger.debug(f"Checkpoint '{name}' at {elapsed:.4f}s")
        return elapsed

    def get_report(self):
        """Generate performance report"""
        if not self.checkpoints:
            return {"error": "No checkpoints recorded"}

        total_time = time.time() - self.start_time

        # Calculate time between checkpoints
        intervals = []
        prev_time = self.start_time

        for checkpoint in self.checkpoints:
            interval = checkpoint["time"] - prev_time
            intervals.append({
                "name": checkpoint["name"],
                "interval": interval,
                "elapsed": checkpoint["elapsed"]
            })
            prev_time = checkpoint["time"]

        return {
            "total_time": total_time,
            "checkpoint_count": len(self.checkpoints),
            "intervals": intervals
        }

class ErrorTracker:
    """Class to track and record errors"""

    def __init__(self, max_errors=100):
        self.errors = []
        self.max_errors = max_errors

    def track_error(self, error, context=None):
        """Record error information"""
        error_info = {
            "error_type": type(error).__name__,
            "error_message": str(error),
            "traceback": traceback.format_exc(),
            "timestamp": datetime.now().isoformat(),
            "context": context or {}
        }

        # Add to error list
        self.errors.append(error_info)

        # Limit the number of stored errors
        if len(self.errors) > self.max_errors:
            self.errors.pop(0)  # Remove oldest error

        return error_info

    def get_errors(self, limit=None):
        """Get list of recorded errors"""
        if limit is None or limit >= len(self.errors):
            return self.errors
        return self.errors[-limit:]  # Return most recent errors

# Initialize global objects
error_tracker = ErrorTracker()
performance_monitor = PerformanceMonitor()

def debug_view(request=None):
    """Create a full debug report"""
    debug_data = {
        "system_info": DebugInfo.get_system_info(),
        "database_status": DebugInfo.get_database_status(),
        "performance": performance_monitor.get_report(),
        "recent_errors": error_tracker.get_errors(limit=10),
        "timestamp": datetime.now().isoformat()
    }

    # Add request information if available
    if request:
        debug_data["request"] = {
            "method": request.method,
            "url": str(request.url),
            "headers": dict(request.headers),
            "client": {
                "host": request.client.host if request.client else "unknown",
                "port": request.client.port if request.client else "unknown"
            }
        }

    return debug_data
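Usage sketch (not part of the commit): exercising the monitor and tracker above with illustrative values.

    # Illustrative only; the demo ValueError stands in for a real failure
    from app.utils.debug_utils import PerformanceMonitor, error_tracker

    monitor = PerformanceMonitor()
    monitor.checkpoint("loaded config")

    try:
        raise ValueError("demo failure")
    except ValueError as exc:
        error_tracker.track_error(exc, context={"stage": "demo"})

    monitor.checkpoint("handled error")
    print(monitor.get_report()["total_time"])
    print(error_tracker.get_errors(limit=1)[0]["error_type"])  # "ValueError"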
app/utils/middleware.py
ADDED
@@ -0,0 +1,109 @@
from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import time
import traceback
import uuid
from .utils import get_local_time

# Configure logging
logger = logging.getLogger(__name__)

class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to log requests and responses"""

    async def dispatch(self, request: Request, call_next):
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # Log request information
        client_host = request.client.host if request.client else "unknown"
        logger.info(f"Request [{request_id}]: {request.method} {request.url.path} from {client_host}")

        # Measure processing time
        start_time = time.time()

        try:
            # Process request
            response = await call_next(request)

            # Calculate processing time
            process_time = time.time() - start_time
            logger.info(f"Response [{request_id}]: {response.status_code} processed in {process_time:.4f}s")

            # Add headers
            response.headers["X-Request-ID"] = request_id
            response.headers["X-Process-Time"] = str(process_time)

            return response

        except Exception as e:
            # Log error
            process_time = time.time() - start_time
            logger.error(f"Error [{request_id}] after {process_time:.4f}s: {str(e)}")
            logger.error(traceback.format_exc())

            # Return error response
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time()
                }
            )

class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware to handle uncaught exceptions in the application"""

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            # Get request_id if available
            request_id = getattr(request.state, "request_id", str(uuid.uuid4()))

            # Log error
            logger.error(f"Uncaught exception [{request_id}]: {str(e)}")
            logger.error(traceback.format_exc())

            # Return error response
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time()
                }
            )

class DatabaseCheckMiddleware(BaseHTTPMiddleware):
    """Middleware to check database connections before each request"""

    async def dispatch(self, request: Request, call_next):
        # Skip paths that don't need database checks
        skip_paths = ["/", "/health", "/docs", "/redoc", "/openapi.json"]
        if request.url.path in skip_paths:
            return await call_next(request)

        # Check database connections
        try:
            # TODO: Add checks for MongoDB and Pinecone if needed
            # PostgreSQL check is already done in route handler with get_db() method

            # Process request normally
            return await call_next(request)

        except Exception as e:
            # Log error
            logger.error(f"Database connection check failed: {str(e)}")

            # Return error response
            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "detail": "Database connection failed",
                    "timestamp": get_local_time()
                }
            )
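Registration sketch (not part of the commit): how these middlewares would typically be attached to a FastAPI app; the wiring in this repo's app.py may differ.

    # Assumed wiring; see app.py for the actual setup
    from fastapi import FastAPI
    from app.utils.middleware import (
        DatabaseCheckMiddleware,
        ErrorHandlingMiddleware,
        RequestLoggingMiddleware,
    )

    app = FastAPI()
    # Starlette runs the last-added middleware first, so registering the
    # logger last lets it stamp request_id before the other two execute.
    app.add_middleware(DatabaseCheckMiddleware)
    app.add_middleware(ErrorHandlingMiddleware)
    app.add_middleware(RequestLoggingMiddleware)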
app/utils/utils.py
ADDED
@@ -0,0 +1,126 @@
import logging
import time
import uuid
import threading
from functools import wraps
from datetime import datetime, timedelta
import pytz
from typing import Callable, Any, Dict, Optional

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Asia/Ho_Chi_Minh timezone
asia_tz = pytz.timezone('Asia/Ho_Chi_Minh')

def generate_uuid():
    """Generate a unique identifier"""
    return str(uuid.uuid4())

def get_current_time():
    """Get current time in ISO format"""
    return datetime.now().isoformat()

def get_local_time():
    """Get current time in Asia/Ho_Chi_Minh timezone"""
    return datetime.now(asia_tz).strftime("%Y-%m-%d %H:%M:%S")

def get_local_datetime():
    """Get current datetime object in Asia/Ho_Chi_Minh timezone"""
    return datetime.now(asia_tz)

# For backward compatibility
get_vietnam_time = get_local_time
get_vietnam_datetime = get_local_datetime

def timer_decorator(func: Callable) -> Callable:
    """
    Decorator to time function execution and log results.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            result = await func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            logger.info(f"Function {func.__name__} executed in {elapsed_time:.4f} seconds")
            return result
        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(f"Function {func.__name__} failed after {elapsed_time:.4f} seconds: {e}")
            raise
    return wrapper

def sanitize_input(text):
    """Sanitize input text"""
    if not text:
        return ""
    # Remove potential dangerous characters or patterns
    return text.strip()

def truncate_text(text, max_length=100):
    """
    Truncate text to given max length and add ellipsis.
    """
    if not text or len(text) <= max_length:
        return text
    return text[:max_length] + "..."

# Simple in-memory cache implementation (replaces Redis dependency)
class SimpleCache:
    def __init__(self):
        self._cache = {}
        self._expiry = {}

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache if it exists and hasn't expired"""
        if key in self._cache:
            # Check if the key has expired
            if key in self._expiry and self._expiry[key] > datetime.now():
                return self._cache[key]
            else:
                # Clean up expired keys
                if key in self._cache:
                    del self._cache[key]
                if key in self._expiry:
                    del self._expiry[key]
        return None

    def set(self, key: str, value: Any, ttl: int = 300) -> None:
        """Set a value in the cache with TTL in seconds"""
        self._cache[key] = value
        # Set expiry time
        self._expiry[key] = datetime.now() + timedelta(seconds=ttl)

    def delete(self, key: str) -> None:
        """Delete a key from the cache"""
        if key in self._cache:
            del self._cache[key]
        if key in self._expiry:
            del self._expiry[key]

    def clear(self) -> None:
        """Clear the entire cache"""
        self._cache.clear()
        self._expiry.clear()

# Initialize cache
cache = SimpleCache()

def get_host_url(request) -> str:
    """
    Get the host URL from a request object.
    """
    host = request.headers.get("host", "localhost")
    scheme = request.headers.get("x-forwarded-proto", "http")
    return f"{scheme}://{host}"

def format_time(timestamp):
    """
    Format a timestamp into a human-readable string.
    """
    return timestamp.strftime("%Y-%m-%d %H:%M:%S")
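Usage sketch (not part of the commit): exercising SimpleCache and timer_decorator; note that the wrapper only supports async functions, since it awaits the wrapped call.

    # Illustrative only; slow_task is a made-up coroutine
    import asyncio
    from app.utils.utils import cache, timer_decorator

    cache.set("greeting", "hello", ttl=60)  # cached for 60 seconds
    assert cache.get("greeting") == "hello"
    cache.delete("greeting")
    assert cache.get("greeting") is None

    @timer_decorator  # logs elapsed time on success or failure
    async def slow_task():
        await asyncio.sleep(0.1)
        return "done"

    print(asyncio.run(slow_task()))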
docker-compose.yml
ADDED
@@ -0,0 +1,21 @@
version: '3'

services:
  backend:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "7860:7860"
    env_file:
      - .env
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s
    volumes:
      - ./app:/app/app
    command: uvicorn app:app --host 0.0.0.0 --port 7860 --reload
pytest.ini
ADDED
@@ -0,0 +1,12 @@
[pytest]
# Ignore warnings from the anyio module and internal deprecation warnings
filterwarnings =
    ignore::pytest.PytestAssertRewriteWarning:.*anyio
    ignore:.*general_plain_validator_function.* is deprecated.*:DeprecationWarning
    ignore:.*with_info_plain_validator_function.*:DeprecationWarning

# Other basic configuration
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
requirements.txt
ADDED
@@ -0,0 +1,43 @@
# FastAPI
fastapi==0.103.1
uvicorn[standard]==0.23.2
pydantic==2.4.2
python-dotenv==1.0.0
websockets==11.0.3

# MongoDB
pymongo==4.6.1
dnspython==2.4.2

# PostgreSQL
sqlalchemy==2.0.20
pydantic-settings==2.0.3
psycopg2-binary==2.9.7

# Pinecone & RAG
pinecone-client==3.0.0
langchain==0.1.4
langchain-core==0.1.19
langchain-community==0.0.14
langchain-google-genai==0.0.5
langchain-pinecone==0.0.1
faiss-cpu==1.7.4
google-generativeai==0.3.1

# Extras
pytz==2023.3
python-multipart==0.0.6
httpx==0.25.1
requests==2.31.0
beautifulsoup4==4.12.2
redis==5.0.1

# Testing
prometheus-client==0.17.1
pytest==7.4.0
pytest-cov==4.1.0
watchfiles==0.21.0

# Core dependencies
starlette==0.27.0
psutil==5.9.6