Using a current method
- .gitignore +142 -0
- README.md +1 -0
- notebooks/chroma_db_test.ipynb +261 -0
- notebooks/rag_test.ipynb +0 -0
- poetry.lock +0 -0
- pyproject.toml +23 -0
- requirements.txt +27 -0
- src/__init__.py +0 -0
- src/api.py +88 -0
- src/llm/__init__.py +0 -0
- src/llm/agents/__init__.py +0 -0
- src/llm/agents/base_agent.py +67 -0
- src/llm/agents/context_agent.py +63 -0
- src/llm/agents/conversation_agent.py +205 -0
- src/llm/agents/emotion_agent.py +125 -0
- src/llm/core/__init__.py +0 -0
- src/llm/core/config.py +36 -0
- src/llm/core/llm.py +107 -0
- src/llm/main.py +46 -0
- src/llm/memory/__init__.py +0 -0
- src/llm/memory/history.py +67 -0
- src/llm/memory/memory_manager.py +71 -0
- src/llm/memory/redis_connection.py +38 -0
- src/llm/memory/session_manager.py +48 -0
- src/llm/memory/vector_store.py +72 -0
- src/llm/models/__init__.py +0 -0
- src/llm/models/schemas.py +31 -0
- src/llm/routes.py +179 -0
- src/llm/utils/__init__.py +0 -0
- src/llm/utils/logging.py +54 -0
- src/music/__init__.py +0 -0
- src/music/clients/__init__.py +0 -0
- src/music/clients/spotify_client.py +76 -0
- src/music/config/__init__.py +0 -0
- src/music/config/settings.py +22 -0
- src/music/fetch.py +18 -0
- src/music/main.py +315 -0
- src/music/models/__init__.py +0 -0
- src/music/models/data_models.py +17 -0
- src/music/services/__init__.py +0 -0
- src/music/services/genre_service.py +33 -0
- src/tele_bot/__init__.py +0 -0
- src/tele_bot/bot.py +100 -0
- src/tele_bot/graph.md +13 -0
- src/utils/__init__.py +0 -0
- src/utils/main.py +7 -0
- src/utils/pdf_splitter.py +37 -0
- src/utils/vector_db.py +44 -0
.gitignore
ADDED
@@ -0,0 +1,142 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+logs/
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# Ignore VSCode settings
+.vscode/
+
+# Ignore dataset folder
+/data/
+/vector_embedding/
+*.rdb
+
+# Ignore venv
+thery/
README.md
CHANGED
@@ -10,3 +10,4 @@ short_description: Virtual AI Mental Health Therapist
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
notebooks/chroma_db_test.ipynb
ADDED
@@ -0,0 +1,261 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.vectorstores import Chroma \n",
+    "from langchain_openai import OpenAIEmbeddings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_core.documents import Document\n",
+    "\n",
+    "documents = [\n",
+    "    Document(\n",
+    "        page_content=\"Dogs are great companions, known for their loyalty and friendliness.\",\n",
+    "        metadata={\"source\": \"mammal-pets-doc\"},\n",
+    "    ),\n",
+    "    Document(\n",
+    "        page_content=\"Cats are independent pets that often enjoy their own space.\",\n",
+    "        metadata={\"source\": \"mammal-pets-doc\"},\n",
+    "    ),\n",
+    "    Document(\n",
+    "        page_content=\"Goldfish are popular pets for beginners, requiring relatively simple care.\",\n",
+    "        metadata={\"source\": \"fish-pets-doc\"},\n",
+    "    ),\n",
+    "    Document(\n",
+    "        page_content=\"Parrots are intelligent birds capable of mimicking human speech.\",\n",
+    "        metadata={\"source\": \"bird-pets-doc\"},\n",
+    "    ),\n",
+    "    Document(\n",
+    "        page_content=\"Rabbits are social animals that need plenty of space to hop around.\",\n",
+    "        metadata={\"source\": \"mammal-pets-doc\"},\n",
+    "    ),\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "db_name = \"test.db\"\n",
+    "persist_directory = os.path.join(\"../vector_embedding\", db_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vector_store = Chroma.from_documents(\n",
+    "    documents,\n",
+    "    embedding=OpenAIEmbeddings(),\n",
+    "    persist_directory= persist_directory\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['1a334a0b-6669-4f14-a970-5a53434dbcdb',\n",
+       " '8e2c9a27-80b3-45ef-bee0-637defe360e6',\n",
+       " '884f0d44-8fb6-45e6-9b6c-256dbb5b8620',\n",
+       " '3cc66429-a94c-41bc-a9b0-e2f5f4a7bfbb',\n",
+       " '3c6af7ac-9070-466a-a35d-a6c151080667']"
+      ]
+     },
+     "execution_count": 37,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "vector_store.add_documents(documents)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/tmp/ipykernel_28475/485603143.py:1: LangChainDeprecationWarning: Since Chroma 0.4.x the manual persistence method is no longer supported as docs are automatically persisted.\n",
+      "  vector_store.persist()\n"
+     ]
+    }
+   ],
+   "source": [
+    "vector_store.persist()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_core.documents import Document\n",
+    "from langchain_core.prompts import ChatPromptTemplate\n",
+    "from langchain_core.runnables import RunnablePassthrough"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "retriever = vector_store.as_retriever(\n",
+    "    search_type = \"similarity\",\n",
+    "    search_kwargs={\n",
+    "        'k': 2,\n",
+    "    }\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[[Document(metadata={'source': 'mammal-pets-doc'}, page_content='Cats are independent pets that often enjoy their own space.'),\n",
+       "  Document(metadata={'source': 'mammal-pets-doc'}, page_content='Dogs are great companions, known for their loyalty and friendliness.')],\n",
+       " [Document(metadata={'source': 'mammal-pets-doc'}, page_content='Dogs are great companions, known for their loyalty and friendliness.'),\n",
+       "  Document(metadata={'source': 'fish-pets-doc'}, page_content='Goldfish are popular pets for beginners, requiring relatively simple care.')]]"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "retriever.batch([\"cats\", \"food\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "message =\"\"\"\n",
+    "\n",
+    "Write more on the type of animals based on this context only.\n",
+    "\n",
+    "{question}\n",
+    "\n",
+    "Context:\n",
+    "{context}\n",
+    "\"\"\"\n",
+    "\n",
+    "prompt = ChatPromptTemplate.from_messages([(\"human\", message)])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_openai import ChatOpenAI\n",
+    "\n",
+    "model = ChatOpenAI(model=\"gpt-4o-mini\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rag_chain = {\"context\": retriever,\"question\": RunnablePassthrough()} | prompt | model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Based on the provided context, cats can be characterized as independent animals that typically prefer to have their own space. This independence sets them apart from other pets, such as dogs, which are known for their companionship and loyalty. \n",
+      "\n",
+      "Cats are often solitary hunters by nature, which contributes to their self-sufficient demeanor. They can entertain themselves for extended periods and may not require constant interaction or attention from their owners. This trait makes them suitable for individuals or families with busy lifestyles who may not be able to dedicate as much time to a pet.\n",
+      "\n",
+      "Additionally, cats are known for their grooming habits and often spend a significant amount of time cleaning themselves. They are also playful animals, enjoying toys and engaging in activities that mimic hunting behavior, such as pouncing and chasing.\n",
+      "\n",
+      "In terms of personality, cats can vary greatly. Some may be more social and enjoy interacting with humans, while others may be more reserved and prefer to observe from a distance. Their unique personalities can provide a rich and rewarding experience for pet owners who appreciate their distinctive traits.\n",
+      "\n",
+      "Overall, cats represent a diverse group of animals that thrive on independence while still providing companionship in a more subtle and often more relaxed manner compared to other pets.\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = rag_chain.invoke(\"tell me about cats\")\n",
+    "\n",
+    "print(response.content)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
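
A side note on the deprecation warning captured in the persist() cell: on current LangChain the langchain_chroma package (listed in requirements.txt below) replaces the old import, and persistence happens automatically. A minimal sketch of the equivalent setup, reusing the documents list and persist directory defined in the notebook:

import os
from langchain_chroma import Chroma           # replaces langchain.vectorstores.Chroma
from langchain_openai import OpenAIEmbeddings

persist_directory = os.path.join("../vector_embedding", "test.db")
vector_store = Chroma.from_documents(
    documents,                                # the Document list from the cell above
    embedding=OpenAIEmbeddings(),
    persist_directory=persist_directory,      # written to disk automatically; no persist() call
)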
notebooks/rag_test.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
pyproject.toml
ADDED
@@ -0,0 +1,23 @@
+[tool.poetry]
+name = "thetherapist"
+version = "0.1.0"
+description = "Using RAG technique to build an LLM model to serve as a therapist for mentally challenged people"
+authors = ["Testimony Adekoya <[email protected]>"]
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = "^3.12"
+openai = "^1.40.1"
+langchain-community = "^0.2.11"
+langchain = "^0.2.12"
+langchain-openai = "^0.1.20"
+langchain-core = "^0.2.29"
+chromadb = "^0.5.5"
+fastapi = "^0.112.0"
+gpt-researcher = "^0.8.4"
+md2pdf = "^1.0.1"
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
requirements.txt
ADDED
@@ -0,0 +1,27 @@
+langchain
+langchain-core
+langchain-community
+langchain-text-splitters
+langchain-openai
+langchain-anthropic
+langchain-google-community
+langchain_chroma
+langchain_huggingface
+langchain_google_genai
+fastapi
+gpt-researcher
+chromadb
+md2pdf
+openai
+faiss-cpu
+torch
+transformers
+sentence-transformers
+numpy
+pandas
+langserve
+langsmith
+pydantic
+spotipy
+tavily-python
+python-telegram-bot
src/__init__.py
ADDED
File without changes
src/api.py
ADDED
@@ -0,0 +1,88 @@
+from fastapi import FastAPI, APIRouter, Depends, HTTPException, Request, Response
+from fastapi.responses import JSONResponse
+import schedule
+import time
+import requests
+import threading
+import asyncio
+import uvicorn
+from multiprocessing import Process
+
+from src.llm.routes import router as conversation_router
+from src.tele_bot.bot import main as run_telegram_bot
+from src.llm.core.config import settings
+from src.llm.agents.conversation_agent import ConversationAgent
+
+def on_startup():
+    global conversation_agent
+    conversation_agent = ConversationAgent()
+
+
+app = FastAPI(
+    title="TheryAI API",
+    description="API for TheryAI",
+    version="0.1.0",
+    docs_url="/docs",
+    redoc_url="/redoc",
+    openapi_url="/openapi.json",
+    debug=True,
+    on_startup=[on_startup]
+)
+
+
+app.include_router(conversation_router)
+
+@app.get("/")
+async def home():
+    return {"message": "Welcome to TheryAI API"}
+
+@app.get("/health")
+async def health():
+    return {"status": "ok"}
+
+def ping_server():
+    try:
+        print("Pinging server")
+        response = requests.get("https://theryai-api./")
+    except requests.exceptions.RequestException as e:
+        print("Server is down")
+        # send email to admin
+
+schedule.every(10).minutes.do(ping_server)
+
+
+def run_schedule():
+    while True:
+        schedule.run_pending()
+        time.sleep(1)
+
+
+thread = threading.Thread(target=run_schedule)
+thread.daemon = True
+thread.start()
+
+
+def run_bot():
+    uvicorn.run(
+        "src.api:app",
+        host="0.0.0.0",
+        port=8000,
+        log_level="info",
+        reload=True,
+    )
+
+# Run the imported Telegram bot coroutine. The original version redefined
+# run_telegram_bot here (shadowing the import) and called the asyncio module
+# itself, which would recurse and raise a TypeError; this wrapper fixes both.
+def run_telegram():
+    asyncio.run(run_telegram_bot())
+
+
+if __name__ == "__main__":
+    fastapi_process = Process(target=run_bot)
+    telegram_process = Process(target=run_telegram)
+
+    fastapi_process.start()
+    telegram_process.start()
+
+    fastapi_process.join()
+    telegram_process.join()
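
Once the module is started (python -m src.api, or uvicorn src.api:app for the API alone), the two GET endpoints can be smoke-tested. A small sketch, with the localhost address as an assumption for a local run:

import requests

BASE_URL = "http://localhost:8000"  # assumed local address; port 8000 matches run_bot()

print(requests.get(f"{BASE_URL}/").json())        # {'message': 'Welcome to TheryAI API'}
print(requests.get(f"{BASE_URL}/health").json())  # {'status': 'ok'}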
src/llm/__init__.py
ADDED
File without changes
src/llm/agents/__init__.py
ADDED
File without changes
src/llm/agents/base_agent.py
ADDED
@@ -0,0 +1,67 @@
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import Any, Dict, Optional
+import logging
+from src.llm.core.llm import TheryLLM
+from src.llm.utils.logging import TheryBotLogger
+from src.llm.memory.history import RedisHistory
+from src.llm.memory.session_manager import SessionManager
+
+class BaseAgent(ABC):
+    def __init__(
+        self,
+        llm: Optional[TheryLLM] = None,
+        history: Optional[RedisHistory] = None,
+        session_manager: Optional[SessionManager] = None
+    ):
+        self.llm = llm or TheryLLM()
+        self.logger = TheryBotLogger()
+        self.history = history or RedisHistory()
+        self.memory_manager = session_manager or SessionManager()
+
+    @abstractmethod
+    def process(self, *args, **kwargs) -> Any:
+        """Process the input and return response"""
+        pass
+
+
+    def _log_action(
+        self,
+        action: str,
+        metadata: Dict[str, Any],
+        level: int = logging.INFO,
+        session_id: Optional[str] = None,
+        user_id: Optional[str] = None
+    ) -> None:
+        """
+        Log agent actions with metadata and optional session/user context.
+
+        Args:
+            action: The action being logged (e.g., "llm_generation_attempt").
+            metadata: A dictionary of metadata related to the action.
+            level: The log level (e.g., logging.INFO, logging.ERROR); note the
+                callers pass logging module constants, which are ints.
+            session_id: Optional session ID for context.
+            user_id: Optional user ID for context.
+        """
+        # Prepare log data
+        log_data = {
+            "action": action,
+            "metadata": metadata,
+            "timestamp": datetime.utcnow().isoformat(),
+        }
+
+        # Add session and user context if provided
+        if session_id:
+            log_data["session_id"] = session_id
+        if user_id:
+            log_data["user_id"] = user_id
+
+        # Log the data using the existing logger
+        if hasattr(self, "logger"):
+            self.logger.log_interaction(
+                interaction_type="agent_action",
+                data=log_data,
+                level=level
+            )
+        else:
+            print(f"Logging failed: {log_data}")
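
Since BaseAgent leaves only process() abstract, a concrete agent needs a single override; the LLM, logger, history, and session manager are all wired by __init__. A minimal illustrative sketch (EchoAgent is hypothetical and not part of the repo; constructing it touches the real TheryLLM and Redis backends):

import logging
from typing import Any

from src.llm.agents.base_agent import BaseAgent

class EchoAgent(BaseAgent):
    """Hypothetical subclass: logs the call, then echoes the input back."""

    def process(self, text: str) -> Any:
        # _log_action routes through TheryBotLogger with a timestamp attached
        self._log_action(action="echo", metadata={"text": text}, level=logging.INFO)
        return text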
src/llm/agents/context_agent.py
ADDED
@@ -0,0 +1,63 @@
+import os
+import asyncio
+import logging
+from typing import Dict, Any
+from .base_agent import BaseAgent
+from src.llm.core.config import settings
+from src.llm.memory.vector_store import FAISSVectorSearch
+from src.llm.models.schemas import ContextInfo
+from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
+
+class ContextAgent(BaseAgent):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._initialize_tools()
+
+    def _initialize_tools(self) -> None:
+        """Lazy-load expensive resources"""
+        self.web_search = TavilySearchResults(
+            max_results=settings.TAVILY_MAX_RESULTS,
+            include_answer=settings.TAVILY_INCLUDE_ANSWER,
+            include_images=settings.TAVILY_INCLUDE_IMAGES,
+            api_wrapper=TavilySearchAPIWrapper(tavily_api_key=settings.TAVILY_API_KEY)
+        )
+
+        self.vector_search = FAISSVectorSearch()
+
+    def process(self, query: str) -> ContextInfo:
+        """Gather context from multiple sources"""
+        web_context = self._get_web_context(query)
+        vector_context = self._get_vector_context(query)
+
+        combined_context = f"{web_context}\n\n{vector_context}"
+
+        self._log_action(action="context_gathered", metadata={"query": query, "web_context": web_context, "vector_context": vector_context}, level=logging.INFO)
+        return ContextInfo(
+            query=query,
+            web_context=web_context,
+            vector_context=vector_context,
+            combined_context=combined_context,
+        )
+
+    def _get_web_context(self, query: str) -> str:
+        try:
+            results = self.web_search.invoke(query)
+            return "\n".join([res["content"] for res in results])
+        except Exception as e:
+            self._log_action(action="web_search_error", metadata={"error": str(e)}, level=logging.ERROR)
+            return "Web search unavailable"
+
+    def _get_vector_context(self, query: str) -> str:
+        try:
+            return self.vector_search.search(query)
+        except Exception as e:
+            self._log_action(action="vector_search_error", metadata={"error": str(e)}, level=logging.ERROR)
+            return "Vector search unavailable"
+
+
+    async def process_async(self, query: str) -> ContextInfo:
+        return await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: self.process(query)
+        )
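
Both retrieval paths fail soft: an exception in either source is logged and replaced by a placeholder string, so combined_context is always populated. A usage sketch, assuming TAVILY_API_KEY and the FAISS index are in place:

import asyncio
from src.llm.agents.context_agent import ContextAgent

agent = ContextAgent()

info = agent.process("coping strategies for work stress")
print(info.combined_context)  # web results, a blank line, then vector-store hits

# process_async just offloads the same synchronous call to the default executor
info = asyncio.run(agent.process_async("coping strategies for work stress"))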
src/llm/agents/conversation_agent.py
ADDED
@@ -0,0 +1,205 @@
+import uuid
+import textwrap
+import logging
+import asyncio
+from typing import Dict, Any, Optional, List
+from src.llm.agents.base_agent import BaseAgent
+from src.llm.agents.emotion_agent import EmotionAgent
+from src.llm.agents.context_agent import ContextAgent
+from src.llm.models.schemas import ConversationResponse, EmotionalAnalysis, ContextInfo
+from src.llm.models.schemas import SessionData
+from src.llm.memory.memory_manager import RedisMemoryManager
+from src.llm.memory.session_manager import SessionManager
+from src.llm.memory.history import RedisHistory
+
+class ConversationAgent(BaseAgent):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.emotion_agent = EmotionAgent(llm=self.llm, history=self.history)
+        self.context_agent = ContextAgent(llm=self.llm, history=self.history)
+        self.memory_manager = RedisMemoryManager()
+        self.session_manager = SessionManager()
+        self.history = RedisHistory()
+
+    def process(
+        self,
+        query: str,
+        session_data: Optional[SessionData] = None
+    ) -> ConversationResponse:
+        """Process user query with emotional awareness and context"""
+        # Generate or validate IDs
+        if session_data:
+            user_id = self.session_manager.validate_session(session_data.session_id)
+            if user_id:
+                # Existing valid session
+                session_id = session_data.session_id
+                is_new_session = False
+            else:
+                # Expired session, create new
+                user_id, session_id = self.session_manager.generate_ids(session_data.user_id)
+                is_new_session = True
+        else:
+            # New conversation
+            user_id, session_id = self.session_manager.generate_ids()
+            is_new_session = True
+
+        chat_id = str(uuid.uuid4())
+        # Analyze emotion
+        emotion_analysis = self.emotion_agent.process(query)
+
+        # Gather context
+        context = self.context_agent.process(query)
+        context = ContextInfo(
+            query=context.query,
+            web_context=context.web_context,
+            vector_context=context.vector_context,
+            combined_context=context.combined_context
+        )
+
+        history_context = self.history.get_full_context(session_id)
+
+        combined_context = context.combined_context if context else None
+
+        # Generate response
+        response = self._generate_response(
+            query=query,
+            emotion_analysis=emotion_analysis,
+            context=combined_context,
+            chat_history=history_context
+        )
+
+        conversation_response = ConversationResponse(
+            session_data=SessionData(
+                user_id=user_id,
+                session_id=session_id,
+                is_new_user=(session_data is None),
+                is_new_session=is_new_session
+            ),
+            response=response,
+            emotion_analysis=emotion_analysis,
+            context=context,
+            query=query,
+            safety_level="unknown",
+            suggested_resources=[]
+        )
+
+        self.memory_manager.store_conversation(session_id, chat_id, conversation_response)
+        self.history.add_conversation(session_id, chat_id, conversation_response)
+
+        self._log_action(action="conversation", metadata={"query": query, "response": response}, level=logging.INFO, session_id=session_id, user_id=user_id)
+
+        return conversation_response
+
+    def _generate_response(
+        self,
+        query: str,
+        emotion_analysis: Optional[EmotionalAnalysis],
+        context: Optional[ContextInfo],
+        chat_history: Optional[List[Dict]]
+    ) -> str:
+
+        prompt = self._construct_response_prompt(
+            query=query,
+            emotion_analysis=emotion_analysis,
+            context=context,
+            chat_history=chat_history
+        )
+
+        response = self.llm.generate(prompt)
+        return response.content.strip()
+
+    def _construct_response_prompt(self, **kwargs) -> str:
+        # Implement sophisticated prompt construction
+        prompt = f"""
+        You are Thery AI, a compassionate virtual therapist who provides supportive, evidence-based advice and empathetic conversation. Your goal is to create a safe, non-judgmental, and empathetic environment for users to share their concerns. When generating your response, follow these steps internally:
+
+        Chain of Thoughts:
+
+        1. Acknowledge the Emotional State:
+        - Identify and validate the emotion expressed by the user.
+        - Use language that shows understanding and empathy.
+
+        2. Select Relevant Therapeutic Approach:
+        - Consider the user's concern, emotional state, and context to determine the most suitable therapeutic modality (e.g., Cognitive-Behavioral Therapy (CBT), Mindfulness-Based Stress Reduction (MBSR), Acceptance and Commitment Therapy (ACT), or Psychodynamic Therapy).
+        - Tailor your response to incorporate principles and techniques from the chosen approach.
+
+        3. Provide Evidence-Based Support:
+        - Incorporate relevant research or common therapeutic techniques where applicable.
+        - Ensure that your advice is grounded in best practices.
+
+        4. Incorporate Context Appropriately:
+        - Use the provided context (from previous interactions or additional background) to make your response more personalized and relevant.
+
+        5. Maintain a Supportive and Empathetic Tone:
+        - Craft your response as if you were speaking with a friend who cares deeply about the user's well-being.
+        - Avoid clinical jargon; use accessible, warm, and encouraging language.
+
+        6. Include Specific Coping Strategies When Appropriate:
+        - Offer actionable suggestions (like deep breathing, mindfulness, journaling, or seeking additional support) that the user can try.
+        - Ask gentle follow-up questions to invite the user to share more, if needed.
+
+        Key Attributes:
+
+        1. Empathy: Understand and share feelings with users.
+        2. Active listening: Give full attention to users, understanding their concerns, and responding thoughtfully.
+        3. Non-judgmental: Avoid criticism or judgment, creating a safe and accepting environment.
+        4. Confidentiality: Maintain users' trust by keeping their information private.
+        5. Cultural competence: Understand and respect users' diverse backgrounds, values, and beliefs.
+
+        Conversation Guidelines:
+
+        1. Begin with an open-ended question to encourage users to share their concerns.
+        2. Use reflective listening to ensure understanding and show empathy.
+        3. Avoid giving direct advice; instead, guide users to explore their own thoughts and feelings.
+        4. Focus on empowering users to make their own decisions.
+        5. Manage conversations to maintain a calm and composed tone.
+
+        Important Instructions:
+
+        1. Do not attempt to diagnose or treat mental health conditions. You are not a licensed therapist.
+        2. Avoid providing explicit or graphic responses.
+        3. Do not share personal experiences or opinions.
+        4. Maintain a neutral and respectful tone.
+        5. If a user expresses suicidal thoughts or intentions, provide resources for immediate support (e.g., crisis hotlines, emergency services).
+
+        Example Response:
+
+        User: "I'm feeling overwhelmed with work and personal life."
+
+        You: "I can sense your frustration. Can you tell me more about what's been going on, and how you've been coping with these challenges?"
+
+        Please respond as a therapist would, using the guidelines and attributes above.
+
+        Input Variables:
+
+        - Chat History: {kwargs['chat_history']}
+        - User Query: {kwargs['query']}
+        - Emotional Analysis: {kwargs['emotion_analysis']}
+        - Context: {kwargs['context']}
+
+        Response Example:
+
+        - If the user says, "Hello," start with a friendly greeting: "Hi there, I'm Thery AI. How can I help you today?"
+        - If the user later says, "I feel sad," continue with: "I'm sorry to hear you're feeling sad. Can you tell me a bit more about what's been going on? Sometimes sharing details can help in understanding and easing your feelings."
+
+        User: "I'm feeling overwhelmed with work and personal life."
+
+        You: "I can sense your frustration. Can you tell me more about what's been going on, and how you've been coping with these challenges?"
+
+        Please respond as a therapist would, using the guidelines and attributes above. Make sure your responses are not overly long. BE NATURAL, SUPPORTIVE, AND EMPATHIZING.
+
+        """
+
+        return textwrap.dedent(prompt).strip()
+
+
+    async def process_async(
+        self,
+        query: str,
+        session_data: Optional[SessionData] = None
+    ) -> ConversationResponse:
+
+        return await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: self.process(query, session_data)
+        )
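
The SessionData returned from the first turn is the continuity handle: passing it back sends the next turn through validate_session instead of minting fresh IDs. A two-turn sketch, assuming Redis and the LLM keys are configured:

from src.llm.agents.conversation_agent import ConversationAgent

agent = ConversationAgent()

# Turn 1: no session yet, so new user and session IDs are generated
first = agent.process("I feel overwhelmed lately")
print(first.response)

# Turn 2: reuse the returned SessionData so stored history feeds the prompt
second = agent.process("It got worse this week", session_data=first.session_data)
print(second.session_data.is_new_session)  # False while the session TTL holds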
src/llm/agents/emotion_agent.py
ADDED
@@ -0,0 +1,125 @@
+import textwrap
+import asyncio
+from typing import Dict, Any
+import logging
+from .base_agent import BaseAgent
+from src.llm.models.schemas import EmotionalAnalysis
+
+class EmotionAgent(BaseAgent):
+    def process(self, text: str) -> EmotionalAnalysis:
+        """Process text for emotional content"""
+        prompt = self._construct_emotion_prompt(text)
+        response = self.llm.generate(prompt)
+        analysis = self._parse_emotion_response(response.content)
+        self._log_action(action="emotion_analysis", metadata={"text": text, "analysis": analysis}, level=logging.INFO)
+
+        return EmotionalAnalysis(
+            primary_emotion=analysis['primary_emotion'],
+            intensity=analysis['intensity'],
+            secondary_emotions=analysis['secondary_emotions'],
+            triggers=analysis['emotional_triggers'],
+            coping_strategies=analysis['coping_strategies'],
+            confidence_score=analysis['confidence_score']
+        )
+
+    def _construct_emotion_prompt(self, text: str) -> str:
+        emotion_prompt = f"""
+        Analyze the emotional content in the following text:
+        Text: {text}
+
+        Provide analysis in the following format:
+        1. Primary emotion: [single emotion]
+        2. Intensity: [number between 1 and 10]
+        3. Secondary emotions: [comma-separated list of emotions]
+        4. Emotional triggers: [comma-separated list of triggers]
+        5. Suggested coping strategies: [comma-separated list of strategies]
+        6. Confidence score: [number between 0 and 1]
+
+        Example:
+        1. Primary emotion: Anxiety
+        2. Intensity: 7
+        3. Secondary emotions: Fear, Worry
+        4. Emotional triggers: Work deadline, Family conflict
+        5. Suggested coping strategies: Deep breathing, Journaling, Talking to a friend
+        6. Confidence score: 0.8
+        """
+
+        return textwrap.dedent(emotion_prompt).strip()
+
+    def _parse_emotion_response(self, response: str) -> dict:
+        try:
+            analysis = {
+                'primary_emotion': '',
+                'intensity': 0,
+                'secondary_emotions': [],
+                'emotional_triggers': [],
+                'coping_strategies': [],
+                'confidence_score': 0.0
+            }
+
+            for line in response.split('\n'):
+                # Convert the line to string explicitly in case it's not
+                line = str(line).strip()
+                if not line:
+                    continue
+
+                # Split on first colon only
+                parts = line.split(':', 1)
+                if len(parts) != 2:
+                    continue
+
+
+                self._log_action(action="emotion_analysis_debug", metadata={"line": line}, level=logging.DEBUG)
+
+                # Ensure key is a string before calling lower()
+                key = str(parts[0]).strip().lower()  # Explicitly convert to string
+                value = str(parts[1]).strip()
+
+                self._log_action(action="emotion_analysis_debug", metadata={"line": line, "key": key, "value": value}, level=logging.DEBUG)
+
+                if 'primary emotion' in key:
+                    analysis['primary_emotion'] = value
+                elif 'intensity' in key:
+                    # Convert intensity to integer safely
+                    try:
+                        analysis['intensity'] = int(value.strip('[]'))
+                    except ValueError:
+                        analysis['intensity'] = 5  # default value
+                elif 'secondary emotions' in key:
+                    analysis['secondary_emotions'] = [
+                        s.strip() for s in value.split(',') if s.strip()
+                    ]
+                elif 'emotional triggers' in key:
+                    analysis['emotional_triggers'] = [
+                        t.strip() for t in value.split(',') if t.strip()
+                    ]
+                elif 'suggested coping strategies' in key:
+                    analysis['coping_strategies'] = [
+                        c.strip() for c in value.split(',') if c.strip()
+                    ]
+                elif 'confidence score' in key:
+                    # Convert confidence score to float safely
+                    try:
+                        analysis['confidence_score'] = float(value.strip('[]'))
+                    except ValueError:
+                        analysis['confidence_score'] = 0.5
+
+            if not analysis['primary_emotion']:
+                raise ValueError("Primary emotion not found in response")
+
+            self._log_action(action="emotion_analysis_success", metadata={"response": response, "analysis": analysis}, level=logging.INFO)
+            return analysis
+
+        except Exception as e:
+            self._log_action(
+                action="emotion_analysis_error",
+                metadata={"response": str(response), "error": str(e)},
+                level=logging.ERROR
+            )
+            raise ValueError(f"Failed to parse emotion response: {str(e)}")
+
+    async def process_async(self, text: str) -> EmotionalAnalysis:
+        return await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: self.process(text)
+        )
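
The parser keys off substrings such as 'primary emotion' after splitting each line on its first colon, so any reply that follows the numbered template parses cleanly. A sketch that feeds the private helper a hand-written response (reaching into _parse_emotion_response purely to illustrate the format; constructing the agent touches the real LLM and Redis backends):

from src.llm.agents.emotion_agent import EmotionAgent

sample_reply = """1. Primary emotion: Anxiety
2. Intensity: 7
3. Secondary emotions: Fear, Worry
4. Emotional triggers: Work deadline
5. Suggested coping strategies: Deep breathing, Journaling
6. Confidence score: 0.8"""

agent = EmotionAgent()
analysis = agent._parse_emotion_response(sample_reply)
print(analysis["primary_emotion"], analysis["intensity"])  # Anxiety 7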
src/llm/core/__init__.py
ADDED
File without changes
src/llm/core/config.py
ADDED
@@ -0,0 +1,36 @@
+from dotenv import load_dotenv
+import os
+from pydantic_settings import BaseSettings
+
+load_dotenv()
+
+class Settings(BaseSettings):
+    GOOGLE_API_KEY: str = os.getenv("GOOGLE_API_KEY")
+    TAVILY_API_KEY: str = os.getenv("TAVILY_API_KEY")
+    REDIS_HOST: str = "localhost"
+    REDIS_PORT: int = 6379
+    REDIS_DB: int = 0
+    REDIS_USER: str = "redis"
+    REDIS_PASSWORD: str = ""
+    SESSION_TTL: int = 86400
+    MAX_RETRIES: int = 3
+    MAX_TOKENS: int = 200
+    SAFETY_THRESHOLD: float = 0.95
+    TAVILY_MAX_RESULTS: int = 3
+    TAVILY_INCLUDE_IMAGES: bool = False
+    TAVILY_INCLUDE_ANSWER: bool = True
+    LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY")
+    LANGCHAIN_TRACING_V2: str = os.getenv("LANGCHAIN_TRACING_V2")
+    LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT")
+    LANGSMITH_API_KEY: str = os.getenv("LANGSMITH_API_KEY")
+    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY")
+    CLAUDE_API_KEY: str = os.getenv("CLAUDE_API_KEY")
+    SPOTIFY_CLIENT_ID: str = os.getenv("SPOTIFY_CLIENT_ID")
+    SPOTIFY_CLIENT_SECRET: str = os.getenv("SPOTIFY_CLIENT_SECRET")
+    SPOTIFY_REDIRECT_URI: str = os.getenv("SPOTIFY_REDIRECT_URI")
+    TELEGRAM_BOT_TOKEN: str = os.getenv("TELEGRAM_BOT_TOKEN")
+
+    class Config:
+        env_file = ".env"
+
+settings = Settings()
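
One caveat: os.getenv returns None when a variable is unset, and every env-backed field above is annotated str, so Settings() raises a pydantic ValidationError at import time unless all the listed keys exist. A sketch of working around that in a scratch environment (placeholder values only; real keys belong in the untracked .env file):

import os

# Every os.getenv-backed field in Settings must resolve to a string,
# so a scratch run needs placeholders for all of them, not just these two.
os.environ.setdefault("GOOGLE_API_KEY", "placeholder")
os.environ.setdefault("TAVILY_API_KEY", "placeholder")

from src.llm.core.config import settings
print(settings.REDIS_HOST, settings.REDIS_PORT)  # localhost 6379 by default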
src/llm/core/llm.py
ADDED
@@ -0,0 +1,107 @@
+from typing import Optional, Dict, Any
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_core.messages import AIMessage
+import logging
+from src.llm.core.config import settings
+from src.llm.utils.logging import TheryBotLogger
+
+class LLMError(Exception):
+    """Custom exception for LLM-related errors"""
+    pass
+
+class TheryLLM:
+    """Enhanced LLM wrapper with safety checks and response validation"""
+
+    def __init__(
+        self,
+        model_name: str = "gemini-1.5-flash",
+        temperature: float = 0.3,
+        max_retries: int = 3,
+        safety_threshold: float = 0.75,
+        logger: Optional[TheryBotLogger] = None
+    ):
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_retries = max_retries
+        self.safety_threshold = safety_threshold
+        self.logger = logger or TheryBotLogger()
+        self._initialize_llm()
+
+    def _initialize_llm(self) -> None:
+        """Initialize the LLM with proper error handling"""
+        try:
+            self.llm = ChatGoogleGenerativeAI(
+                model=self.model_name,
+                temperature=self.temperature,
+                max_retries=self.max_retries,
+                google_api_key=settings.GOOGLE_API_KEY,
+                max_tokens=settings.MAX_TOKENS
+            )
+            self._session_active = True
+        except Exception as e:
+            self._session_active = False
+            self.logger.log_interaction(
+                interaction_type="llm_initialization_failed",
+                data={"error": str(e)},
+                level=logging.ERROR
+            )
+            raise LLMError(f"LLM initialization failed: {str(e)}")
+
+    def generate(self, prompt: str, **kwargs) -> AIMessage:
+        """Generate a response with safety checks and validation"""
+        if not self._session_active:
+            self._initialize_llm()
+
+        try:
+            # Log the generation attempt
+            self.logger.log_interaction(
+                interaction_type="llm_generation_attempt",
+                data={"prompt": prompt, "kwargs": kwargs},
+                level=logging.INFO
+            )
+
+            # Generate response
+            response = self.llm.invoke(prompt)
+
+            # Validate response
+            validated_response = self._validate_response(response)
+
+            # Log successful generation
+            self.logger.log_interaction(
+                interaction_type="llm_generation_success",
+                data={"prompt": prompt, "response": str(validated_response)},
+                level=logging.INFO
+            )
+
+            return validated_response
+
+        except Exception as e:
+            self.logger.log_interaction(
+                interaction_type="llm_generation_error",
+                data={"prompt": prompt, "error": str(e)},
+                level=logging.ERROR
+            )
+            raise LLMError(f"Generation failed: {str(e)}")
+
+    def _validate_response(
+        self,
+        response: AIMessage
+    ) -> AIMessage:
+        """Validate response content and format"""
+        if not isinstance(response, AIMessage):
+            self.logger.log_interaction(
+                interaction_type="llm_invalid_response_type",
+                data={"response": response},
+                level=logging.ERROR
+            )
+            raise LLMError("Invalid response type")
+
+        if not response.content.strip():
+            self.logger.log_interaction(
+                interaction_type="llm_empty_response",
+                data={"response": response},
+                level=logging.ERROR
+            )
+            raise LLMError("Empty response content")
+
+        return response
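
Callers see either a validated AIMessage or an LLMError; there is no silent-failure path. A minimal usage sketch, assuming GOOGLE_API_KEY is set:

from src.llm.core.llm import TheryLLM, LLMError

llm = TheryLLM(temperature=0.3)
try:
    message = llm.generate("Suggest one grounding exercise for anxiety.")
    print(message.content)
except LLMError as exc:
    # Raised on init failure, generation failure, wrong type, or empty content
    print(f"LLM unavailable: {exc}")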
src/llm/main.py
ADDED
@@ -0,0 +1,46 @@
+import logging
+from src.llm.agents.conversation_agent import ConversationAgent
+from src.llm.utils.logging import TheryBotLogger
+from src.llm.core.config import settings
+
+def main():
+    # Initialize logger
+    logger = TheryBotLogger()
+
+    # Initialize main conversation agent
+    agent = ConversationAgent()
+
+    # Example interaction
+    query = "But I have been try to do this for quite a while now and I am still not able to get it right."
+
+    try:
+        # Process query
+        response = agent.process(query)
+
+        # Log interaction
+        logger.log_interaction(
+            interaction_type="user_interaction",
+            data={
+                "query": query,
+                "response": response.response,
+                "status": "success"
+            },
+            level=logging.INFO
+        )
+
+        # Print response
+        print(f"Thery AI: {response.response}")
+
+    except Exception as e:
+        logger.log_interaction(
+            interaction_type="error",
+            data={
+                "query": query,
+                "error": str(e)
+            },
+            level=logging.ERROR
+        )
+        print("An error occurred. Please try again.")
+
+if __name__ == "__main__":
+    main()
src/llm/memory/__init__.py
ADDED
File without changes
src/llm/memory/history.py
ADDED
@@ -0,0 +1,67 @@
+import json
+import time
+from datetime import timedelta
+from typing import List, Dict, Any, Optional
+from .redis_connection import RedisConnection
+from src.llm.models.schemas import ConversationResponse
+from src.llm.core.config import settings
+
+class RedisHistory:
+    def __init__(self, session_ttl: int = settings.SESSION_TTL):
+        self.redis = RedisConnection().client
+        self.session_ttl = session_ttl
+
+    def add_conversation(self, session_id: str, chat_id: str, response: ConversationResponse) -> None:
+        """
+        Store complete conversation response in history
+        """
+        # Store in session-specific list
+        self.redis.rpush(
+            f"session:{session_id}:history",
+            json.dumps({
+                'chat_id': chat_id,
+                'response': response.dict(),
+                'timestamp': time.time()
+            })
+        )
+
+        # Set TTL for session history
+        self.redis.expire(f"session:{session_id}:history", self.session_ttl)
+
+    def get_conversation_history(self, session_id: str, limit: int = 10) -> List[Dict[str, Any]]:
+        """
+        Retrieve conversation history with optional limit
+        """
+        messages = self.redis.lrange(f"session:{session_id}:history", -limit, -1)
+        return [
+            {
+                'chat_id': json.loads(msg)['chat_id'],
+                'response': ConversationResponse(**json.loads(msg)['response']),
+                'timestamp': json.loads(msg)['timestamp']
+            }
+            for msg in messages
+        ]
+
+    def get_full_context(self, session_id: str) -> str:
+        """
+        Generate conversation context string for LLM prompts
+        """
+        history = self.get_conversation_history(session_id)
+        context_lines = []
+
+        for entry in history:
+            response = entry['response']
+            context_lines.append(
+                f"User: {response.query}\n"
+                f"Therapist: {response.response}\n"
+                f"Emotions: {response.emotion_analysis.primary_emotion} "
+                f"(Intensity: {response.emotion_analysis.intensity})\n"
+            )
+
+        return "\n".join(context_lines)
+
+    def clear_history(self, session_id: str) -> None:
+        """
+        Clear session history
+        """
+        self.redis.delete(f"session:{session_id}:history")
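
A sketch of the read path: entries come back as dicts with rehydrated ConversationResponse objects, and get_full_context flattens them into prompt-ready text (the session ID here is illustrative; real ones come from SessionManager):

from src.llm.memory.history import RedisHistory

history = RedisHistory()
session_id = "example-session-id"  # illustrative

for entry in history.get_conversation_history(session_id, limit=5):
    print(entry["chat_id"], entry["response"].query)

print(history.get_full_context(session_id))  # "User: ... / Therapist: ..." blocks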
src/llm/memory/memory_manager.py
ADDED
@@ -0,0 +1,71 @@
+from typing import Dict, Any, Optional
+from .redis_connection import RedisConnection
+from src.llm.models.schemas import ConversationResponse
+import json
+import time
+
+class RedisMemoryManager:
+    def __init__(self):
+        self.redis = RedisConnection().client
+
+    def store_conversation(self, session_id: str, chat_id: str, response: ConversationResponse) -> None:
+        """
+        Store complete conversation response with metadata
+        """
+        response_data = response.dict()
+        timestamp = time.time()
+
+        # Store in session-specific hash
+        self.redis.hset(
+            f"session:{session_id}:chats",
+            chat_id,
+            json.dumps({
+                'response': response_data,
+                'timestamp': timestamp
+            })
+        )
+
+        # Update session metadata
+        self.redis.hset(
+            f"session:{session_id}",
+            mapping={
+                'last_chat_id': chat_id,
+                'last_updated': str(timestamp)
+            }
+        )
+
+    def get_conversation(self, session_id: str, chat_id: str) -> Optional[ConversationResponse]:
+        """
+        Retrieve specific conversation response
+        """
+        data = self.redis.hget(f"session:{session_id}:chats", chat_id)
+        if data:
+            return ConversationResponse(**json.loads(data)['response'])
+        return None
+
+    def get_session_conversations(self, session_id: str) -> Dict[str, Any]:
+        """
+        Get all conversations for a session
+        """
+        conversations = self.redis.hgetall(f"session:{session_id}:chats")
+        return {
+            chat_id: ConversationResponse(**json.loads(data)['response'])
+            for chat_id, data in conversations.items()
+        }
+
+    def update_emotional_state(self, session_id: str, emotions: Dict[str, Any]) -> None:
+        """
+        Update emotional state tracking
+        """
+        self.redis.hset(
+            f"session:{session_id}:state",
+            'emotions',
+            json.dumps(emotions)
+        )
+
+    def get_emotional_state(self, session_id: str) -> Dict[str, Any]:
+        """
+        Retrieve current emotional state
+        """
+        data = self.redis.hget(f"session:{session_id}:state", 'emotions')
+        return json.loads(data) if data else {}
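
Where RedisHistory keeps an ordered, TTL-bound list, this manager keys each turn by chat ID in a hash, so individual turns stay addressable. A retrieval sketch with illustrative IDs:

from src.llm.memory.memory_manager import RedisMemoryManager

manager = RedisMemoryManager()
session_id, chat_id = "example-session-id", "example-chat-id"  # illustrative

one_turn = manager.get_conversation(session_id, chat_id)   # ConversationResponse or None
all_turns = manager.get_session_conversations(session_id)  # {chat_id: ConversationResponse}

manager.update_emotional_state(session_id, {"primary": "anxiety", "intensity": 6})
print(manager.get_emotional_state(session_id))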
src/llm/memory/redis_connection.py
ADDED
@@ -0,0 +1,38 @@
import redis
import logging
from src.llm.core.config import settings
from src.llm.utils.logging import TheryBotLogger

class RedisConnection:
    _instance = None

    def __new__(cls):
        if not cls._instance:
            cls._instance = super().__new__(cls)
            cls._instance._initialize_self()
        return cls._instance

    def _initialize_self(self) -> None:
        self.logger = TheryBotLogger()
        self.redis = redis.Redis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB,
            password=settings.REDIS_PASSWORD,
            decode_responses=True
        )
        try:
            self.redis.ping()
        except redis.ConnectionError as e:
            self.logger.log_interaction(
                interaction_type="redis_connection_failed",
                data={"error": str(e)},
                level=logging.ERROR
            )
            raise RuntimeError(f"Redis connection failed: {str(e)}")

    @property
    def client(self):
        # ping() raises ConnectionError on failure rather than returning False,
        # so catch the exception and reconnect instead of testing the return value.
        try:
            self.redis.ping()
        except redis.ConnectionError:
            self._initialize_self()
        return self.redis
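Because `__new__` caches `_instance`, every module that constructs `RedisConnection` shares one client. A quick sketch of the expected behavior, assuming Redis is reachable:

```python
# Hedged sketch: both constructions return the same cached instance.
from src.llm.memory.redis_connection import RedisConnection

a = RedisConnection()
b = RedisConnection()
assert a is b                  # singleton via __new__
a.client.set("probe", "1")     # same underlying connection for both names
print(b.client.get("probe"))   # -> "1" (decode_responses=True yields str)
```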
src/llm/memory/session_manager.py
ADDED
@@ -0,0 +1,48 @@
import time
import uuid
from typing import Optional, Tuple
from .redis_connection import RedisConnection
from src.llm.core.config import settings

class SessionManager:
    def __init__(self):
        self.redis = RedisConnection().client

    def generate_ids(self, existing_user_id: Optional[str] = None) -> Tuple[str, str]:
        """
        Generate or validate user/session IDs.
        Returns: (user_id, session_id)
        """
        user_id = self._get_or_create_user_id(existing_user_id)
        session_id = self._create_session(user_id)
        return user_id, session_id

    def _get_or_create_user_id(self, existing_user_id: Optional[str]) -> str:
        if existing_user_id and self.redis.exists(f"user:{existing_user_id}"):
            return existing_user_id
        # Unknown or missing ID: generate a new one
        return str(uuid.uuid4())

    def _create_session(self, user_id: str) -> str:
        session_id = str(uuid.uuid4())
        # Store session metadata
        self.redis.hset(f"session:{session_id}", mapping={
            "user_id": user_id,
            "created_at": str(time.time()),
            "activity": str(time.time())
        })
        # Set TTL (24 hours by default)
        self.redis.expire(f"session:{session_id}", settings.SESSION_TTL)
        # Link to user
        self.redis.sadd(f"user:{user_id}:sessions", session_id)
        return session_id

    def validate_session(self, session_id: str) -> Optional[str]:
        """Return the user_id if the session is valid, else None."""
        if self.redis.exists(f"session:{session_id}"):
            # Update last activity
            self.redis.hset(f"session:{session_id}", "activity", str(time.time()))
            return self.redis.hget(f"session:{session_id}", "user_id")
        return None
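Note that `src/llm/routes.py` below calls `session_manager.get_user_sessions(...)` and `session_manager.end_session(...)`, which `SessionManager` does not define yet. A hedged sketch of what those methods could look like given the key layout above; these are hypothetical additions, not part of the commit:

```python
# Hypothetical SessionManager methods assumed by routes.py; illustrative only.
from typing import Dict, List

def get_user_sessions(self, user_id: str) -> List[Dict[str, str]]:
    """Return metadata for every non-expired session linked to a user."""
    session_ids = self.redis.smembers(f"user:{user_id}:sessions")
    return [
        {"session_id": sid, **self.redis.hgetall(f"session:{sid}")}
        for sid in session_ids
        if self.redis.exists(f"session:{sid}")  # TTL may have expired the hash
    ]

def end_session(self, session_id: str) -> None:
    """Delete a session's keys and unlink it from its user."""
    user_id = self.redis.hget(f"session:{session_id}", "user_id")
    if user_id:
        self.redis.srem(f"user:{user_id}:sessions", session_id)
    self.redis.delete(f"session:{session_id}",
                      f"session:{session_id}:chats",
                      f"session:{session_id}:state")
```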
src/llm/memory/vector_store.py
ADDED
@@ -0,0 +1,72 @@
from pathlib import Path
from typing import List, Optional
import logging
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from src.llm.utils.logging import TheryBotLogger

class FAISSVectorSearch:
    def __init__(
        self,
        embedding_model: Optional[HuggingFaceEmbeddings] = None,
        db_path: Path = Path("vector_embedding/mental_health_vector_db"),
        k: int = 5,
        logger: Optional[TheryBotLogger] = None
    ):
        self.embedding_model = embedding_model or self._get_default_embedding_model()
        self.db_path = db_path
        self.k = k
        self.logger = logger or TheryBotLogger()
        self._initialize_store()

    def _get_default_embedding_model(self) -> HuggingFaceEmbeddings:
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
            encode_kwargs={
                "padding": "max_length",
                "max_length": 512,
                "truncation": True,
                "normalize_embeddings": True
            }
        )

    def _initialize_store(self) -> None:
        if self.db_path.exists():
            self.vectorstore = FAISS.load_local(
                str(self.db_path),
                self.embedding_model,
                allow_dangerous_deserialization=True
            )
        else:
            # Initialize with an empty placeholder text
            self.vectorstore = FAISS.from_texts(
                [""], self.embedding_model
            )

    def search(self, query: str, k: Optional[int] = None) -> List[str]:
        try:
            results = self.vectorstore.similarity_search(
                query,
                k=(k or self.k)
            )
            return [res.page_content for res in results]
        except Exception as e:
            # Log the error and return empty results
            self.logger.log_interaction(
                interaction_type="vector_search_error",
                data={"error": str(e)},
                level=logging.ERROR
            )
            return []

    def add_texts(self, texts: List[str]) -> None:
        """Add new texts to the vector store and persist the index."""
        self.vectorstore.add_texts(texts)
        self.save()

    def save(self) -> None:
        """Save the vector store to disk."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.vectorstore.save_local(str(self.db_path))
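A brief usage sketch for `FAISSVectorSearch`; the first run downloads the MiniLM model, and the texts here are illustrative:

```python
# Hedged sketch: builds an index in memory and persists it under db_path.
from src.llm.memory.vector_store import FAISSVectorSearch

store = FAISSVectorSearch(k=3)
store.add_texts([
    "Grounding exercises can reduce acute anxiety.",
    "Consistent sleep hygiene supports mood regulation.",
])
for passage in store.search("how to calm down when anxious"):
    print(passage)
```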
src/llm/models/__init__.py
ADDED
File without changes
src/llm/models/schemas.py
ADDED
@@ -0,0 +1,31 @@
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional, List

class EmotionalAnalysis(BaseModel):
    primary_emotion: str
    intensity: int = Field(..., ge=1, le=10)
    secondary_emotions: List[str]
    triggers: List[str]
    coping_strategies: List[str] = []
    confidence_score: float = Field(..., ge=0, le=1)

class ContextInfo(BaseModel):
    # `query` defaults to "" so that ContextInfo() is constructible, which
    # ConversationResponse relies on via default_factory below.
    query: str = ""
    web_context: str = ""
    vector_context: List[str] = Field(default_factory=list)
    combined_context: str = ""

class SessionData(BaseModel):
    user_id: str = Field(..., description="Unique user identifier")
    session_id: str = Field(..., description="Current session identifier")
    is_new_user: bool = Field(False, description="Flag for new user detection")
    is_new_session: bool = Field(False, description="Flag for new session detection")

class ConversationResponse(BaseModel):
    session_data: SessionData
    response: str = Field(..., description="Primary assistant response")
    emotion_analysis: EmotionalAnalysis
    context: ContextInfo = Field(default_factory=ContextInfo)
    query: str
    safety_level: str = Field("unknown", description="Assessment of response safety")
    suggested_resources: List[str] = Field(default_factory=list)
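The `Field` constraints make invalid payloads fail loudly at construction time, for example:

```python
# Hedged sketch: intensity is bounded by ge=1, le=10.
from pydantic import ValidationError
from src.llm.models.schemas import EmotionalAnalysis

try:
    EmotionalAnalysis(
        primary_emotion="angry", intensity=15,   # out of range -> rejected
        secondary_emotions=[], triggers=[], confidence_score=0.9,
    )
except ValidationError as e:
    print(e)  # reports that intensity must be less than or equal to 10
```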
src/llm/routes.py
ADDED
@@ -0,0 +1,179 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from typing import List, Optional
from datetime import datetime

from src.llm.agents.conversation_agent import ConversationAgent
from src.llm.models.schemas import ConversationResponse, SessionData
from src.llm.utils.logging import TheryBotLogger
from src.llm.memory.history import RedisHistory
from src.llm.memory.memory_manager import RedisMemoryManager
from src.llm.memory.session_manager import SessionManager
from src.llm.core.config import settings

router = APIRouter(
    prefix="/api/v1",
    tags=["TheryAI Services"],
    responses={
        200: {"description": "Success"},
        400: {"description": "Bad Request"},
        404: {"description": "Not found"},
        500: {"description": "Internal Server Error"}
    },
)

# Initialize managers
session_manager = SessionManager()
memory_manager = RedisMemoryManager()
history = RedisHistory()
logger = TheryBotLogger()  # structured interaction logging
log = logging.getLogger(__name__)  # TheryBotLogger has no .error(); use a stdlib logger for error reporting
conversation_agent = ConversationAgent()

@router.post("/users", response_model=dict)
async def create_user():
    """Create a new user and return user_id"""
    try:
        user_id, _ = session_manager.generate_ids()
        return {"user_id": user_id}
    except Exception as e:
        log.error(f"Error creating user: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to create user")

@router.get("/users/{user_id}/sessions", response_model=List[dict])
async def get_user_sessions(user_id: str):
    """Get all sessions for a user"""
    try:
        sessions = session_manager.get_user_sessions(user_id)
        return sessions
    except Exception as e:
        log.error(f"Error fetching sessions for user {user_id}: {str(e)}")
        raise HTTPException(status_code=404, detail="User not found")

@router.post("/sessions", response_model=SessionData)
async def create_session(user_id: Optional[str] = None):
    """Create a new session"""
    try:
        is_new_user = user_id is None  # capture before user_id is reassigned below
        user_id, session_id = session_manager.generate_ids(existing_user_id=user_id)
        return SessionData(
            user_id=user_id,
            session_id=session_id,
            is_new_user=is_new_user,
            is_new_session=True
        )
    except Exception as e:
        log.error(f"Error creating session: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to create session")

@router.get("/sessions/{session_id}", response_model=SessionData)
async def get_session(session_id: str):
    """Get session metadata"""
    try:
        user_id = session_manager.validate_session(session_id)
        if not user_id:
            raise HTTPException(status_code=404, detail="Session not found")

        return SessionData(
            user_id=user_id,
            session_id=session_id,
            is_new_user=False,
            is_new_session=False
        )
    except HTTPException as he:
        raise he
    except Exception as e:
        log.error(f"Error fetching session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch session")

@router.get("/sessions/{session_id}/messages", response_model=List[ConversationResponse])
async def get_session_messages(
    session_id: str,
    limit: Optional[int] = 10
):
    """Get messages from a session"""
    try:
        if not session_manager.validate_session(session_id):
            raise HTTPException(status_code=404, detail="Session not found")

        messages = history.get_conversation_history(session_id, limit=limit)
        return [msg["response"] for msg in messages]
    except HTTPException as he:
        raise he
    except Exception as e:
        log.error(f"Error fetching messages for session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch messages")

@router.post("/sessions/{session_id}/messages", response_model=ConversationResponse)
async def create_message(
    session_id: str,
    message: str,
    background_tasks: BackgroundTasks
):
    """Create a new message in a session"""
    try:
        user_id = session_manager.validate_session(session_id)
        if not user_id:
            raise HTTPException(status_code=404, detail="Session not found")

        session_data = SessionData(
            user_id=user_id,
            session_id=session_id,
            is_new_user=False,
            is_new_session=False
        )

        response = await conversation_agent.process_async(
            query=message,
            session_data=session_data
        )

        # Store the conversation asynchronously
        background_tasks.add_task(
            memory_manager.store_conversation,
            session_id,
            str(datetime.now().timestamp()),
            response
        )

        return response
    except HTTPException as he:
        raise he
    except Exception as e:
        log.error(f"Error processing message in session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to process message")

@router.get("/sessions/{session_id}/memory", response_model=dict)
async def get_session_memory(session_id: str):
    """Get all memory data for a session"""
    try:
        if not session_manager.validate_session(session_id):
            raise HTTPException(status_code=404, detail="Session not found")

        conversations = memory_manager.get_session_conversations(session_id)
        emotional_state = memory_manager.get_emotional_state(session_id)

        return {
            "conversations": conversations,
            "emotional_state": emotional_state
        }
    except HTTPException as he:
        raise he
    except Exception as e:
        log.error(f"Error fetching memory for session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch memory")

@router.delete("/sessions/{session_id}", response_model=dict)
async def end_session(session_id: str):
    """End a session and clean up resources"""
    try:
        if not session_manager.validate_session(session_id):
            raise HTTPException(status_code=404, detail="Session not found")

        history.clear_history(session_id)
        session_manager.end_session(session_id)

        return {"message": "Session ended successfully"}
    except HTTPException as he:
        raise he
    except Exception as e:
        log.error(f"Error ending session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to end session")
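A hedged sketch of the request flow against these routes, assuming the FastAPI app is served locally on port 8000 (the base URL is an assumption):

```python
# Hedged sketch using httpx; `message` is a query parameter because the
# route declares it as a bare `str` rather than a request body.
import httpx

base = "http://localhost:8000/api/v1"
with httpx.Client() as client:
    session = client.post(f"{base}/sessions").json()           # SessionData
    sid = session["session_id"]
    reply = client.post(
        f"{base}/sessions/{sid}/messages",
        params={"message": "I've been feeling anxious lately"},
    ).json()                                                    # ConversationResponse
    print(reply["response"], reply["emotion_analysis"]["primary_emotion"])
    client.delete(f"{base}/sessions/{sid}")                     # end the session
```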
src/llm/utils/__init__.py
ADDED
File without changes
src/llm/utils/logging.py
ADDED
@@ -0,0 +1,54 @@
import logging
import json
from datetime import datetime
from typing import Any, Dict
from pathlib import Path

class TheryBotLogger:
    def __init__(self, log_dir: Path = Path("logs")):
        self.log_dir = log_dir
        self._setup_logging()

    def _setup_logging(self) -> None:
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Guard against duplicate handlers: several modules instantiate
        # TheryBotLogger, and each call would otherwise re-add handlers
        # to the root logger.
        if logging.root.handlers:
            return

        # Set up the file handler
        file_handler = logging.FileHandler(
            self.log_dir / f"thery_bot_{datetime.now():%Y%m%d}.log"
        )
        file_handler.setLevel(logging.INFO)

        # Set up the console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.WARNING)

        # Set up formatters
        file_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        console_formatter = logging.Formatter(
            '%(levelname)s: %(message)s'
        )

        file_handler.setFormatter(file_formatter)
        console_handler.setFormatter(console_formatter)

        # Configure the root logger
        logging.root.setLevel(logging.INFO)
        logging.root.addHandler(file_handler)
        logging.root.addHandler(console_handler)

    def log_interaction(
        self,
        interaction_type: str,
        data: Dict[str, Any],
        level: int = logging.INFO
    ) -> None:
        """Log an interaction with structured data"""
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "type": interaction_type,
            "data": data
        }

        logging.log(level, json.dumps(log_entry))
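Each call to `log_interaction` emits one JSON payload through the root logger; a sketch of the effect:

```python
# Hedged sketch: the printed output format shown below is approximate.
import logging
from src.llm.utils.logging import TheryBotLogger

logger = TheryBotLogger()
logger.log_interaction(
    interaction_type="vector_search",
    data={"query": "coping with stress", "hits": 3},
)
# A log-file line resembles:
# 2025-01-01 12:00:00,000 - root - INFO - {"timestamp": "...", "type": "vector_search", "data": {...}}
```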
src/music/__init__.py
ADDED
File without changes
src/music/clients/__init__.py
ADDED
File without changes
src/music/clients/spotify_client.py
ADDED
@@ -0,0 +1,76 @@
from spotipy import Spotify
from spotipy.oauth2 import SpotifyClientCredentials
from typing import Dict, List
from src.music.config.settings import SpotifyConfig, logger
from src.music.models.data_models import RecommendationParameters

class SpotifyClient:
    """Handles Spotify authentication and API operations"""

    def __init__(self, config: SpotifyConfig):
        self.config = config
        self.config.validate()
        self.authenticate()

    def authenticate(self):
        """Authenticate and set up the Spotify client"""
        self.client_credentials_manager = SpotifyClientCredentials(
            client_id=self.config.client_id,
            client_secret=self.config.client_secret
        )
        self.client = Spotify(client_credentials_manager=self.client_credentials_manager)

    def get_recommendations(self, params: RecommendationParameters) -> List[Dict]:
        """Get track recommendations from Spotify"""
        try:
            # spotipy's recommendations() takes target_* audio attributes as
            # keyword arguments and uses `country` rather than `market`.
            response = self.client.recommendations(
                seed_genres=params.seed_genres,
                limit=params.limit,
                country=params.market,
                **params.target_features
            )
            return response['tracks']
        except Exception as e:
            logger.error(f"Recommendation failed: {str(e)}")
            raise

    def refresh_authentication(self):
        """Refresh authentication and reinitialize the client"""
        try:
            self.authenticate()
        except Exception as e:
            logger.error(f"Failed to refresh authentication: {str(e)}")
            raise

    def get_available_genres(self) -> List[str]:
        """Get available genre seeds from Spotify"""
        try:
            self.refresh_authentication()  # Ensure fresh authentication
            response = self.client.recommendation_genre_seeds()
            logger.info(f"Available genres: {response}")
            return response.get('genres', [])
        except Exception as e:
            logger.error(f"Failed to fetch genres: {str(e)}")
            return []

    def get_audio_features(self, track_uri: str) -> Dict:
        """Get audio features for a track"""
        try:
            return self.client.audio_features(track_uri)[0] or {}
        except Exception as e:
            logger.error(f"Failed to fetch audio features: {str(e)}")
            return {}


# Usage (guarded so that importing this module has no side effects):
if __name__ == "__main__":
    spotify_client = SpotifyClient(config=SpotifyConfig())
    params = RecommendationParameters(
        seed_genres=["pop", "rock"],
        # keys use spotipy's target_* naming so they can be passed as kwargs
        target_features={"target_danceability": 0.7, "target_energy": 0.6},
        limit=10,
        market="US"
    )

    print(params)  # Debugging
    recommendations = spotify_client.get_recommendations(params)
    print(recommendations)
src/music/config/__init__.py
ADDED
File without changes
src/music/config/settings.py
ADDED
@@ -0,0 +1,22 @@
import os
from dataclasses import dataclass
from dotenv import load_dotenv
import logging

load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class SpotifyConfig:
    client_id: str = os.getenv("SPOTIFY_CLIENT_ID")
    client_secret: str = os.getenv("SPOTIFY_CLIENT_SECRET")
    market: str = "US"
    max_retries: int = 3

    def validate(self):
        if not self.client_id or not self.client_secret:
            raise ValueError("Missing Spotify credentials in environment")
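One gotcha worth noting: the dataclass defaults read the environment when the module is imported, so the variables must be set (e.g. via `.env` and `load_dotenv()`) before import, or the credentials passed explicitly. A sketch with placeholder values:

```python
# Hedged sketch: placeholder credentials, not real ones.
from src.music.config.settings import SpotifyConfig

config = SpotifyConfig(
    client_id="your-client-id",          # placeholder
    client_secret="your-client-secret",  # placeholder
)
config.validate()  # raises ValueError if either credential is falsy
```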
src/music/fetch.py
ADDED
@@ -0,0 +1,18 @@
# TODO: Work on fetching music from Spotify. Figure out a way to find music based on the emotion detected during the conversation.
# TODO: Implement a smooth algorithm to sync music without fail.


import json
import time
import spotipy
import random
from spotipy.oauth2 import SpotifyClientCredentials
import os

client_id = os.getenv("SPOTIFY_CLIENT_ID")
# The rest of the codebase reads SPOTIFY_CLIENT_SECRET; SPOTIFY_SECRET was a typo.
client_secret = os.getenv("SPOTIFY_CLIENT_SECRET")

client_credentials_manager = SpotifyClientCredentials(client_id=client_id, client_secret=client_secret)

sp = spotipy.Spotify(client_credentials_manager=client_credentials_manager)
src/music/main.py
ADDED
@@ -0,0 +1,315 @@
from abc import ABC, abstractmethod
from spotipy import Spotify
from spotipy.oauth2 import SpotifyClientCredentials
from typing import Dict, List, Optional
import os
import logging
from dataclasses import dataclass
import json
from datetime import datetime
from dotenv import load_dotenv

load_dotenv()

# Configure logging
logger = logging.getLogger(__name__)

# -------------------------
# Data Structures
# -------------------------

@dataclass
class TrackRecommendation:
    uri: str
    name: str
    artist: str
    preview_url: Optional[str]
    audio_features: Dict

@dataclass
class RecommendationParameters:
    seed_genres: List[str]
    target_features: Dict
    limit: int = 20
    market: str = "US"

# -------------------------
# Core Interfaces
# -------------------------

class IMusicRecommendationStrategy(ABC):
    @abstractmethod
    def generate_recommendations(self, emotion: str, context: Dict) -> List[TrackRecommendation]:
        pass

class IAudioAnalyzer(ABC):
    @abstractmethod
    def analyze_track(self, track_uri: str) -> Dict:
        pass

# -------------------------
# Spotify Client
# -------------------------

class SpotifyClient:
    """Handles Spotify authentication and basic API operations"""

    def __init__(self):
        self.client_credentials_manager = SpotifyClientCredentials(
            client_id=os.getenv("SPOTIFY_CLIENT_ID"),
            client_secret=os.getenv("SPOTIFY_CLIENT_SECRET")
        )
        self.client = Spotify(client_credentials_manager=self.client_credentials_manager)

    def get_recommendations(self, params: RecommendationParameters) -> List[Dict]:
        """Base recommendation API call"""
        try:
            # spotipy's recommendations() takes target_* audio attributes as
            # keyword arguments and uses `country` rather than `market`.
            response = self.client.recommendations(
                seed_genres=params.seed_genres,
                limit=params.limit,
                country=params.market,
                **params.target_features
            )
            return response['tracks']
        except Exception as e:
            logger.error(f"Recommendation failed: {str(e)}")
            raise

# -------------------------
# Emotion Mapping System
# -------------------------

class EmotionAudioProfile:
    """Maps emotions to audio characteristics with cultural adaptation"""

    def __init__(self):
        self.base_profiles = {
            "sad": {"target_valence": 0.2, "target_energy": 0.3},
            "happy": {"target_valence": 0.8, "target_energy": 0.7},
            "anxious": {"target_valence": 0.5, "target_energy": 0.4},
            "angry": {"target_valence": 0.3, "target_energy": 0.8}
        }

        self.cultural_adjustments = {
            "US": {"happy": {"target_danceability": 0.8}},
            "JP": {"happy": {"target_danceability": 0.6}}
        }

    def get_profile(self, emotion: str, country: str = "US") -> Dict:
        """Get a culturally adjusted audio profile"""
        profile = self.base_profiles.get(emotion, {}).copy()
        profile.update(self.cultural_adjustments.get(country, {}).get(emotion, {}))
        return profile


class GenreMapper:
    """Hierarchical genre mapping system with fallbacks"""

    def __init__(self, spotify_client: SpotifyClient):
        self.spotify = spotify_client
        self.genre_hierarchy = {
            "sad": ["blues", "soul", "acoustic"],
            "happy": ["pop", "dance", "disco"],
            "anxious": ["ambient", "classical"],
            "angry": ["rock", "metal"]
        }
        self.available_genres = self._load_available_genres()

    def _load_available_genres(self) -> List[str]:
        """Get valid Spotify genre seeds"""
        return self.spotify.client.recommendation_genre_seeds()['genres']

    def get_genres(self, emotion: str) -> List[str]:
        """Get the best available genre for an emotion, falling back to pop"""
        for genre in self.genre_hierarchy.get(emotion, []):
            if genre in self.available_genres:
                return [genre]
        return ["pop"]

# -------------------------
# AI Integration
# -------------------------

class LLMEnhancer:
    """Enhances recommendations using LLM context analysis"""

    def __init__(self):
        from langchain_google_genai import ChatGoogleGenerativeAI
        self.llm = ChatGoogleGenerativeAI(model="gemini-pro")

    def enhance_params(self, context: Dict) -> Dict:
        """Analyze conversation context for musical attributes"""
        prompt = f"""
        Analyze this therapeutic context to suggest music parameters:
        {json.dumps(context, indent=2)}

        Return JSON with:
        - target_energy (0-1)
        - target_danceability (0-1)
        - target_tempo
        - seed_artist (main artist name)
        - seed_track (main track name)
        """

        try:
            response = self.llm.invoke(prompt)
            return json.loads(response.content)
        except Exception as e:
            logger.warning(f"LLM enhancement failed: {str(e)}")
            return {}

# -------------------------
# Recommendation Engine
# -------------------------

class TherapeuticMusicRecommender(IMusicRecommendationStrategy):
    """Main recommendation engine with multiple strategies"""

    def __init__(self):
        self.spotify = SpotifyClient()
        self.audio_profiler = EmotionAudioProfile()
        self.genre_mapper = GenreMapper(self.spotify)
        self.llm_enhancer = LLMEnhancer()
        self.cache = RecommendationCache()

    def generate_recommendations(self, emotion: str, context: Dict) -> List[TrackRecommendation]:
        """Generate context-aware recommendations"""
        # Check the cache first
        cache_key = self._generate_cache_key(emotion, context)
        if cached := self.cache.get(cache_key):
            return cached

        # Build parameters
        params = self._build_recommendation_params(emotion, context)

        # Get raw recommendations
        raw_tracks = self.spotify.get_recommendations(params)

        # Process and enrich tracks
        processed = self._process_tracks(raw_tracks)

        # Cache the results
        self.cache.store(cache_key, processed)

        return processed

    def _build_recommendation_params(self, emotion: str, context: Dict) -> RecommendationParameters:
        """Construct recommendation parameters"""
        base_features = self.audio_profiler.get_profile(
            emotion,
            context.get('user', {}).get('country', 'US')
        )

        llm_features = self.llm_enhancer.enhance_params(context)

        return RecommendationParameters(
            seed_genres=self.genre_mapper.get_genres(emotion),
            target_features={**base_features, **llm_features},
            market=context.get('user', {}).get('country', 'US'),
            limit=context.get('limit', 20)
        )

    def _process_tracks(self, raw_tracks: List[Dict]) -> List[TrackRecommendation]:
        """Convert raw tracks to enriched recommendations"""
        return [
            TrackRecommendation(
                uri=track['uri'],
                name=track['name'],
                artist=track['artists'][0]['name'],
                preview_url=track.get('preview_url'),
                audio_features=self.spotify.client.audio_features(track['uri'])[0]
            ) for track in raw_tracks
        ]

    def _generate_cache_key(self, emotion: str, context: Dict) -> str:
        """Generate a unique cache key"""
        return f"{emotion}-{context.get('user', {}).get('id', 'anonymous')}"

# -------------------------
# Advanced Features
# -------------------------

class RecommendationCache:
    """LRU cache for recommendations"""

    def __init__(self, max_size: int = 100):
        self.cache = {}
        self.max_size = max_size
        self.order = []  # least recently used keys first

    def get(self, key: str) -> Optional[List[TrackRecommendation]]:
        if key in self.cache:
            self.order.remove(key)
            self.order.append(key)
            return self.cache[key]
        return None

    def store(self, key: str, recommendations: List[TrackRecommendation]):
        if len(self.cache) >= self.max_size:
            oldest = self.order.pop(0)
            del self.cache[oldest]
        self.cache[key] = recommendations
        self.order.append(key)

class MoodTransitionEngine:
    """Creates playlists that transition between emotional states"""

    def __init__(self, recommender: TherapeuticMusicRecommender):
        self.recommender = recommender

    def create_transition_playlist(self, start_emotion: str, end_emotion: str, context: Dict) -> List[TrackRecommendation]:
        """Generate a mood transition sequence"""
        steps = self._calculate_transition_steps(start_emotion, end_emotion)
        playlist = []

        for step in steps:
            context['transition_step'] = step
            playlist += self.recommender.generate_recommendations(
                emotion=step['emotion'],
                context=context
            )

        return playlist

    def _calculate_transition_steps(self, start: str, end: str) -> List[Dict]:
        """Determine intermediate emotional states"""
        transitions = {
            ('sad', 'happy'): [{'emotion': 'sad', 'intensity': 0.8},
                               {'emotion': 'neutral', 'intensity': 0.5},
                               {'emotion': 'happy', 'intensity': 0.7}],
            # Add other transition paths
        }
        return transitions.get((start, end), [])

# -------------------------
# Usage Example
# -------------------------

if __name__ == "__main__":
    # Initialize the system
    recommender = TherapeuticMusicRecommender()

    # Sample context from a therapy session
    context = {
        "user": {
            "id": "user123",
            "country": "US",
            "time_of_day": datetime.now().hour
        },
        "conversation": {
            "emotion": "anxious",
            "key_phrases": ["work stress", "sleep issues"],
            "therapist_notes": "Needs calming music with nature sounds"
        }
    }

    # Generate recommendations
    recommendations = recommender.generate_recommendations(
        emotion="anxious",
        context=context
    )

    # Output results
    print(f"Generated {len(recommendations)} tracks:")
    for track in recommendations[:3]:
        print(f"- {track.artist}: {track.name} ({track.audio_features['tempo']} BPM)")
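`MoodTransitionEngine` is defined but never exercised in the `__main__` block; a hedged sketch of how it would be driven, reusing the recommender above:

```python
# Hedged sketch: requires Spotify and Google credentials like the main example.
from src.music.main import TherapeuticMusicRecommender, MoodTransitionEngine

recommender = TherapeuticMusicRecommender()
engine = MoodTransitionEngine(recommender)

playlist = engine.create_transition_playlist(
    start_emotion="sad",
    end_emotion="happy",
    context={"user": {"id": "user123", "country": "US"}},
)
# The intermediate 'neutral' step has no base profile or genre entry, so it
# falls back to an empty feature set and the "pop" genre.
for track in playlist[:5]:
    print(f"{track.artist} - {track.name}")
```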
src/music/models/__init__.py
ADDED
File without changes
src/music/models/data_models.py
ADDED
@@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class TrackRecommendation:
    uri: str
    name: str
    artist: str
    preview_url: Optional[str]
    audio_features: Dict

@dataclass
class RecommendationParameters:
    seed_genres: List[str]
    target_features: Dict
    limit: int = 20
    market: str = "US"
src/music/services/__init__.py
ADDED
File without changes
src/music/services/genre_service.py
ADDED
@@ -0,0 +1,33 @@
from typing import Dict, List
from src.music.clients.spotify_client import SpotifyClient
from src.music.config.settings import logger, SpotifyConfig

class GenreMapper:
    """Maps emotions to appropriate genres"""

    def __init__(self, spotify_client: SpotifyClient):
        self.spotify = spotify_client
        self.genre_hierarchy = {
            "sad": ["blues", "soul", "acoustic"],
            "happy": ["pop", "dance", "disco"],
            "anxious": ["ambient", "classical"],
            "angry": ["rock", "metal"]
        }
        self.available_genres = self._load_available_genres()

    def _load_available_genres(self) -> List[str]:
        """Load available genres from Spotify"""
        return self.spotify.get_available_genres()

    def get_genres(self, emotion: str) -> List[str]:
        """Get appropriate genres for an emotion"""
        emotion_genres = self.genre_hierarchy.get(emotion, [])
        available = [genre for genre in emotion_genres if genre in self.available_genres]
        return available if available else ["pop"]


# Usage (guarded so that importing this module has no side effects):
if __name__ == "__main__":
    client = SpotifyClient(config=SpotifyConfig())
    mapper = GenreMapper(client)
    print(mapper.get_genres("happy"))
src/tele_bot/__init__.py
ADDED
File without changes
src/tele_bot/bot.py
ADDED
@@ -0,0 +1,100 @@
# TODO: Collate the necessary features for the Telegram bot, including parts like buttons, other services, Web Search Feature (*Might be nice to try this out.*)

import logging
from telegram import Update, ReplyKeyboardMarkup
from telegram.ext import (
    Application,
    CommandHandler,
    MessageHandler,
    filters,
    ContextTypes,
    CallbackContext
)

from src.llm.agents.conversation_agent import ConversationAgent
from src.llm.models.schemas import SessionData
from src.llm.core.config import settings

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO
)
logger = logging.getLogger(__name__)

conversation_agent = ConversationAgent()

async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle the /start command with an interactive keyboard"""
    keyboard = [
        ["💬 Start Chatting"],
        ["ℹ️ About", "🛠 Help"]
    ]
    reply_markup = ReplyKeyboardMarkup(keyboard, resize_keyboard=True)

    welcome_msg = (
        "🌟 Welcome to Thery AI! 🌟\n\n"
        "I'm here to provide compassionate mental health support. "
        "How can I help you today?"
    )

    await update.message.reply_text(welcome_msg, reply_markup=reply_markup)

async def handle_message(update: Update, context: CallbackContext):
    """Process user messages with conversation context"""
    try:
        user = update.effective_user
        text = update.message.text

        # Get the stored session, if any; None on first contact, in which
        # case the conversation agent is expected to create one.
        session_data = context.user_data.get('session_data')

        # Process the query
        response = await conversation_agent.process_async(
            query=text,
            session_data=session_data
        )

        # Update session data
        context.user_data['session_data'] = response.session_data

        # Send the response with a typing indicator
        await context.bot.send_chat_action(
            chat_id=update.effective_chat.id,
            action="typing"
        )
        await update.message.reply_text(response.response)

    except Exception as e:
        logger.error(f"Error processing message: {str(e)}")
        await update.message.reply_text("I'm having trouble understanding. Let's try that again.")


async def error_handler(update: Update, context: CallbackContext):
    """Handle errors in the bot"""
    logger.error(f"Update {update} caused error: {context.error}")
    if update and update.message:  # update can be None for some error types
        await update.message.reply_text("Oops! Something went wrong. Please try again.")


def main():
    """Configure and start the bot"""
    application = Application.builder().token(settings.TELEGRAM_BOT_TOKEN).build()

    # Add handlers
    application.add_handler(CommandHandler("start", start))
    application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message))

    # Add the error handler
    application.add_error_handler(error_handler)

    # Start polling
    logger.info("Starting Thery AI Telegram bot...")
    application.run_polling(
        poll_interval=1,
        allowed_updates=Update.ALL_TYPES,
        drop_pending_updates=True
    )

if __name__ == "__main__":
    main()
src/tele_bot/graph.md
ADDED
@@ -0,0 +1,13 @@
```mermaid
graph TD
    A[Telegram User] --> B[Telegram Bot]
    C[Web Frontend] --> D[FastAPI HTTP]
    C --> E[FastAPI WebSocket]
    B & D & E --> F[Conversation Agent]
    F --> G[Emotion Agent]
    F --> H[Context Agent]
    G & H --> I[LLM Core]
    F --> J[Redis Memory]
    J --> K[Session Management]
    J --> L[Conversation History]
```
src/utils/__init__.py
ADDED
File without changes
src/utils/main.py
ADDED
@@ -0,0 +1,7 @@
import src.utils.vector_db as vector_db
import src.utils.pdf_splitter as pdf_splitter

# pdf_splitter defines DataExtractor (there is no PDFProcessor class), and
# extraction and splitting are two explicit steps rather than a single run().
db = vector_db.VectorDatabase("mental_health_db")
extractor = pdf_splitter.DataExtractor("../data/mental_health")
documents = extractor.extract_text()
split_docs = extractor.clean_and_split_text(documents)
db.create_db(split_docs)
src/utils/pdf_splitter.py
ADDED
@@ -0,0 +1,37 @@
import glob
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

class DataExtractor:
    def __init__(self, pdf_directory):
        self.pdf_directory = pdf_directory
        self.pdf_text = []
        self.docs = []
        self.split_docs = None

    def extract_text(self):
        """Load every PDF in the directory into a list of document lists."""
        print(f'Extracting text from pdf files in {self.pdf_directory}')
        pdf_files = glob.glob(f'{self.pdf_directory}/*.pdf')
        print(pdf_files)
        for pdf_file in pdf_files:
            print(pdf_file)
            loader = PyPDFLoader(pdf_file)
            documents = loader.load()
            self.pdf_text.append(documents)

        return self.pdf_text

    def clean_and_split_text(self, documents):
        """Clean and split the loaded documents into overlapping chunks."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        split_docs = []
        print(f'Cleaning and splitting text from {len(documents)} documents')
        for doc in documents:
            # Split each per-file document list individually
            split_docs.extend(splitter.split_documents(doc))
        print(f'Number of documents after splitting: {len(split_docs)}')
        return split_docs
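A short sketch of the two-step extract-then-split flow; the directory path is taken from `src/utils/main.py` and may differ per machine:

```python
# Hedged sketch: expects PDF files under the given directory.
from src.utils.pdf_splitter import DataExtractor

extractor = DataExtractor("./data/mental_health")
documents = extractor.extract_text()            # one document list per PDF
chunks = extractor.clean_and_split_text(documents)
print(f"{len(chunks)} chunks of ~1000 chars with 200-char overlap")
```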
src/utils/vector_db.py
ADDED
@@ -0,0 +1,44 @@
import os
from src.utils.pdf_splitter import DataExtractor
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

class VectorDatabase:
    def __init__(self, db_name):
        self.db_name = db_name
        self.persist_directory = os.path.join("vector_embedding", self.db_name)

        # Embeddings configuration for the sentence-transformers model
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
            encode_kwargs={
                "padding": "max_length",
                "max_length": 512,
                "truncation": True,
                "normalize_embeddings": True
            }
        )

    def create_db(self, pdf_data):
        # Build the index and persist it in one step; from_documents makes a
        # separate add_documents() or explicit persist() call unnecessary.
        self.vectDB = FAISS.from_documents(
            documents=pdf_data,
            embedding=self.embeddings
        )
        self.vectDB.save_local(self.persist_directory)

def main():
    pdf_directory = './data/mental_health'
    data_extractor = DataExtractor(pdf_directory)
    text_data = data_extractor.extract_text()
    text_data = data_extractor.clean_and_split_text(text_data)

    # Step 2: Create and persist the vector database
    vector_db = VectorDatabase(db_name="mental_health_vector_db")
    vector_db.create_db(text_data)
    print("Vector embeddings have been generated and saved successfully.")

if __name__ == "__main__":
    main()