v2 of public chat
* use ReAct agent
* clean up tool code
- README.md +2 -2
- ask_candid/base/api_base.py +3 -3
- ask_candid/base/api_base_async.py +3 -3
- ask_candid/base/config/base.py +10 -0
- ask_candid/base/config/connections.py +6 -14
- ask_candid/base/config/models.py +1 -0
- ask_candid/base/config/rest.py +49 -10
- ask_candid/base/lambda_base.py +3 -3
- ask_candid/base/retrieval/__init__.py +0 -0
- ask_candid/base/retrieval/elastic.py +205 -0
- ask_candid/base/retrieval/knowledge_base.py +362 -0
- ask_candid/base/retrieval/schemas.py +23 -0
- ask_candid/base/retrieval/sources.py +40 -0
- ask_candid/base/retrieval/sparse_lexical.py +98 -0
- ask_candid/base/utils.py +52 -0
- ask_candid/chat.py +68 -55
- ask_candid/services/small_lm.py +32 -6
- ask_candid/tools/general.py +17 -0
- ask_candid/tools/org_search.py +182 -0
- ask_candid/tools/search.py +56 -111
- ask_candid/tools/utils.py +14 -0
- chat_v2.py +265 -0
- requirements.txt +5 -5
README.md
CHANGED
@@ -6,8 +6,8 @@ colorFrom: blue
 colorTo: purple
 python_version: 3.12
 sdk: gradio
-sdk_version: 5.
-app_file:
+sdk_version: 5.42.0
+app_file: chat_v2.py
 pinned: true
 license: mit
 ---
ask_candid/base/api_base.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Any

 from urllib3.util.retry import Retry
 from requests.adapters import HTTPAdapter
@@ -10,7 +10,7 @@ class BaseAPI:
     def __init__(
         self,
         url: str,
-        headers:
+        headers: dict[str, Any] | None = None,
         total_retries: int = 3,
         backoff_factor: int = 2
     ) -> None:
@@ -36,7 +36,7 @@ class BaseAPI:
         r.raise_for_status()
         return r.json()

-    def post(self, payload:
+    def post(self, payload: dict[str, Any]):
         r = self.session.post(url=self.__url, headers=self.__headers, json=payload, timeout=30)
         r.raise_for_status()
         return r.json()
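A minimal usage sketch of the updated `BaseAPI` signature; the endpoint URL and header key below are placeholders for illustration, not part of this diff:

    from ask_candid.base.api_base import BaseAPI

    # hypothetical endpoint and auth header, just to show the new types
    api = BaseAPI(
        url="https://api.example.org/v1/search",
        headers={"x-api-key": "..."},  # now typed dict[str, Any] | None
        total_retries=3,
        backoff_factor=2
    )
    result = api.post(payload={"query": "community foundations"})  # payload: dict[str, Any]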
ask_candid/base/api_base_async.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Any
 import json

 import aiohttp
@@ -6,7 +6,7 @@ import aiohttp

 class BaseAsyncAPI:

-    def __init__(self, url: str, headers:
+    def __init__(self, url: str, headers: dict[str, Any] | None = None, retries: int = 3) -> None:
         self.__url = url
         self.__headers = headers
         self.__retries = max(retries, 5)
@@ -29,7 +29,7 @@ class BaseAsyncAPI:
                 break
         return output

-    async def post(self, payload:
+    async def post(self, payload: dict[str, Any]):
         session_timeout = aiohttp.ClientTimeout(total=30)
         async with aiohttp.ClientSession(headers=self.__headers, timeout=session_timeout) as session:
             output = {}
ask_candid/base/config/base.py
ADDED
@@ -0,0 +1,10 @@
+import os
+
+from dotenv import dotenv_values, find_dotenv
+
+__env_values__ = dotenv_values(
+    dotenv_path=find_dotenv(".env", raise_error_if_not_found=False)
+)
+
+def _load_value(key: str):
+    return __env_values__.get(key) or os.getenv(key)
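A quick sketch of how the new shared `_load_value` helper resolves configuration: the project `.env` file wins, then the process environment, else `None` (the key name here is illustrative):

    from ask_candid.base.config.base import _load_value

    # resolution order: .env value -> os.environ -> None
    api_key = _load_value("CDS_API_KEY")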
ask_candid/base/config/connections.py
CHANGED
@@ -1,33 +1,25 @@
 from dataclasses import dataclass, field
-import os

-from dotenv import dotenv_values, find_dotenv
+from ask_candid.base.config.base import _load_value


 @dataclass
 class BaseElasticSearchConnection:
     """Elasticsearch connection dataclass
     """
-    url: str = field(default_factory=str)
-    username: str = field(default_factory=str)
-    password: str = field(default_factory=str)
+    url: str | None = field(default_factory=str)
+    username: str | None = field(default_factory=str)
+    password: str | None = field(default_factory=str)


 @dataclass
 class BaseElasticAPIKeyCredential:
     """Cloud ID/API key data class
     """
-    cloud_id: str = field(default_factory=str)
-    api_key: str = field(default_factory=str)
+    cloud_id: str | None = field(default_factory=str)
+    api_key: str | None = field(default_factory=str)


-__env_values__ = dotenv_values(
-    dotenv_path=find_dotenv(".env", raise_error_if_not_found=False)
-)
-
-def _load_value(key: str):
-    return __env_values__.get(key) or os.getenv(key)
-
 SEMANTIC_ELASTIC_QA = BaseElasticAPIKeyCredential(
     cloud_id=_load_value("SEMANTIC_ELASTIC_CLOUD_ID"),
     api_key=_load_value("SEMANTIC_ELASTIC_API_KEY"),
ask_candid/base/config/models.py
CHANGED
@@ -3,6 +3,7 @@ from types import MappingProxyType
 Name2Endpoint = MappingProxyType({
     "gpt-4o": "gpt-4o",
     "claude-3.5-haiku": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
+    "claude-4-sonnet": "us.anthropic.claude-sonnet-4-20250514-v1:0",
     # "llama-3.1-70b-instruct": "us.meta.llama3-1-70b-instruct-v1:0",
     # "mistral-large": "mistral.mistral-large-2402-v1:0",
     # "mixtral-8x7B": "mistral.mixtral-8x7b-instruct-v0:1",
ask_candid/base/config/rest.py
CHANGED
@@ -1,25 +1,64 @@
-from typing import TypedDict
-import os
+from typing import TypedDict, NamedTuple

-from dotenv import dotenv_values, find_dotenv
+from ask_candid.base.config.base import _load_value


 class Api(TypedDict):
     """REST API configuration template
     """
-    url: str
-    key: str
+    url: str | None
+    key: str | None

+class ApiConfig(NamedTuple):
+    url: str | None
+    key: str | None
+
+    @property
+    def header(self) -> dict[str, str | None]:
+        return {"x-api-key": self.key}
+
+    def endpoint(self, route: str):
+        return f"{self.url}/{route}"

-__env_values__ = dotenv_values(
-    dotenv_path=find_dotenv(".env", raise_error_if_not_found=False)
-)
-
-def _load_value(key: str):
-    return __env_values__.get(key) or os.getenv(key)

 CDS_API = Api(
     url=_load_value("CDS_API_URL"),
     key=_load_value("CDS_API_KEY")
 )

+CANDID_SEARCH_API = Api(
+    url=_load_value("CANDID_SEARCH_API_URL"),
+    key=_load_value("CANDID_SEARCH_API_KEY")
+)
+
 OPENAI = Api(url=None, key=_load_value("OPENAI_API_KEY"))
+
+SEARCH = ApiConfig(
+    url="https://ajr9jccwf0.execute-api.us-east-1.amazonaws.com/Prod",
+    key=_load_value("SEARCH_API_KEY")
+)
+
+AUTOCODING = ApiConfig(
+    url="https://auto-coding-api.candid.org",
+    key=_load_value("AUTOCODING_API_KEY")
+)
+
+DOCUMENT = ApiConfig(
+    url="https://dtntz2p635.execute-api.us-east-1.amazonaws.com/Prod",
+    key=_load_value("GEOCODING_API_KEY")
+)
+
+FUNDER_RECOMMENDATION = ApiConfig(
+    url="https://r6g59fxbie.execute-api.us-east-1.amazonaws.com/Prod",
+    key=_load_value("FUNDER_RECS_API_KEY")
+)
+
+LOI_WRITER = ApiConfig(
+    url="https://tc2ir1o7ne.execute-api.us-east-1.amazonaws.com/Prod",
+    key=_load_value("LOI_WRITER_API_KEY")
+)
+
+GOLDEN_ORG = ApiConfig(
+    url="https://qfdur742ih.execute-api.us-east-1.amazonaws.com/Prod",
+    key=_load_value("GOLDEN_RECORD_API_KEY")
+)
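A small sketch of how the new `ApiConfig` named tuple is presumably meant to be consumed, combining `endpoint` and `header` when calling one of the configured services (the route name "classify" is hypothetical):

    from ask_candid.base.config.rest import AUTOCODING

    url = AUTOCODING.endpoint("classify")  # -> "https://auto-coding-api.candid.org/classify"
    headers = AUTOCODING.header            # -> {"x-api-key": <AUTOCODING_API_KEY or None>}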
ask_candid/base/lambda_base.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Any
 from time import sleep
 import json

@@ -25,7 +25,7 @@ class LambdaInvokeBase:

     def __init__(
             self, function_name: str,
-            access_key:
+            access_key: str | None = None, secret_key: str | None = None,
     ) -> None:
         if access_key is not None and secret_key is not None:
             self._client = boto3.client(
@@ -39,7 +39,7 @@ class LambdaInvokeBase:

         self.function_name = function_name

-    def _submit_request(self, payload:
+    def _submit_request(self, payload: dict[str, Any]) -> dict[str, Any] | list[Any]:
         response = self._client.invoke(
             FunctionName=self.function_name,
             InvocationType="RequestResponse",
ask_candid/base/retrieval/__init__.py
ADDED
File without changes
ask_candid/base/retrieval/elastic.py
ADDED
@@ -0,0 +1,205 @@
+from typing import Any
+from collections.abc import Iterator
+
+from elasticsearch import Elasticsearch
+
+from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
+from ask_candid.base.config.connections import BaseElasticAPIKeyCredential, BaseElasticSearchConnection
+
+NEWS_TRUST_SCORE_THRESHOLD = 0.8
+SPARSE_ENCODING_SCORE_THRESHOLD = 0.4
+
+
+def build_sparse_vector_query(
+    query: str,
+    fields: tuple[str, ...],
+    inference_id: str = ".elser-2-elasticsearch"
+) -> dict[str, Any]:
+    """Builds a valid Elasticsearch text expansion query payload
+
+    Parameters
+    ----------
+    query : str
+        Search context string
+    fields : Tuple[str, ...]
+        Semantic text field names
+    inference_id : str, optional
+        ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
+
+    Returns
+    -------
+    Dict[str, Any]
+    """
+
+    output = []
+
+    for f in fields:
+        output.append({
+            "nested": {
+                "path": f"embeddings.{f}.chunks",
+                "query": {
+                    "sparse_vector": {
+                        "field": f"embeddings.{f}.chunks.vector",
+                        "inference_id": inference_id,
+                        "prune": True,
+                        "query": query,
+                        # "boost": 1 / len(fields)
+                    }
+                },
+                "inner_hits": {
+                    "_source": False,
+                    "size": 2,
+                    "fields": [f"embeddings.{f}.chunks.chunk"]
+                }
+            }
+        })
+    return {"query": {"bool": {"should": output}}}
+
+
+def build_sparse_vector_and_text_query(
+    query: str,
+    semantic_fields: tuple[str, ...],
+    text_fields: tuple[str, ...] | None,
+    highlight_fields: tuple[str, ...] | None,
+    excluded_fields: tuple[str, ...] | None,
+    inference_id: str = ".elser-2-elasticsearch"
+) -> dict[str, Any]:
+    """Builds Elasticsearch sparse vector and text query payload
+
+    Parameters
+    ----------
+    query : str
+        Search context string
+    semantic_fields : Tuple[str]
+        Semantic text field names
+    highlight_fields: Tuple[str]
+        Fields which relevant chunks will be helpful for the agent to read
+    text_fields : Tuple[str]
+        Regular text fields
+    excluded_fields : Tuple[str]
+        Fields to exclude from the source
+    inference_id : str, optional
+        ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
+
+    Returns
+    -------
+    Dict[str, Any]
+    """
+
+    output = []
+    final_query = {}
+
+    for f in semantic_fields:
+        output.append({
+            "sparse_vector": {
+                "field": f"{f}",
+                "inference_id": inference_id,
+                "query": query,
+                "boost": 1,
+                "prune": True  # doesn't seem it changes anything if we use text queries additionally
+            }
+        })
+
+    if text_fields:
+        output.append({
+            "multi_match": {
+                "fields": text_fields,
+                "query": query,
+                "boost": 3
+            }
+        })
+
+
+    final_query = {
+        "track_total_hits": False,
+        "query": {
+            "bool": {"should": output}
+        }
+    }
+
+    if highlight_fields:
+        final_query["highlight"] = {
+            "fields": {
+                f"{f}": {
+                    "type": "semantic",  # ensures that highlighting is applied exclusively to semantic_text fields.
+                    "number_of_fragments": 2,  # number of chunks
+                    "order": "none"  # can be "score", but we have only two and hope for context
+                }
+                for f in highlight_fields
+            }
+        }
+
+    if excluded_fields:
+        final_query["_source"] = {"excludes": list(excluded_fields)}
+    return final_query
+
+
+def news_query_builder(
+    query: str,
+    fields: tuple[str, ...],
+    encoder: SpladeEncoder,
+    days_ago: int = 60,
+) -> dict[str, Any]:
+    """Builds a valid Elasticsearch query against Candid news, simulating a token expansion.
+
+    Parameters
+    ----------
+    query : str
+        Search context string
+
+    Returns
+    -------
+    Dict[str, Any]
+    """
+
+    tokens = encoder.token_expand(query)
+
+    elastic_query = {
+        "_source": ["id", "link", "title", "content", "site_name"],
+        "query": {
+            "bool": {
+                "filter": [
+                    {"range": {"event_date": {"gte": f"now-{days_ago}d/d"}}},
+                    {"range": {"insert_date": {"gte": f"now-{days_ago}d/d"}}},
+                    {"range": {"article_trust_worthiness": {"gt": NEWS_TRUST_SCORE_THRESHOLD}}}
+                ],
+                "should": []
+            }
+        }
+    }
+
+    for token, score in tokens.items():
+        if score > SPARSE_ENCODING_SCORE_THRESHOLD:
+            elastic_query["query"]["bool"]["should"].append({
+                "multi_match": {
+                    "query": token,
+                    "fields": fields,
+                    "boost": score
+                }
+            })
+    return elastic_query
+
+
+def multi_search_base(
+    queries: list[dict[str, Any]],
+    credentials: BaseElasticSearchConnection | BaseElasticAPIKeyCredential,
+    timeout: int = 180
+) -> Iterator[dict[str, Any]]:
+    if isinstance(credentials, BaseElasticAPIKeyCredential):
+        es = Elasticsearch(
+            cloud_id=credentials.cloud_id,
+            api_key=credentials.api_key,
+            verify_certs=False,
+            request_timeout=timeout
+        )
+    elif isinstance(credentials, BaseElasticSearchConnection):
+        es = Elasticsearch(
+            credentials.url,
+            http_auth=(credentials.username, credentials.password),
+            timeout=timeout
+        )
+    else:
+        raise TypeError(f"Invalid credentials of type `{type(credentials)}")
+
+    yield from es.msearch(body=queries).get("responses", [])
+    es.close()
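To illustrate the shape of the payload that `build_sparse_vector_query` produces, a minimal sketch (the query string is arbitrary; the field name matches the source configs added below):

    from ask_candid.base.retrieval.elastic import build_sparse_vector_query

    q = build_sparse_vector_query(query="capacity building grants", fields=("content",))
    # q is a bool/should query with one nested sparse_vector clause per field:
    # {"query": {"bool": {"should": [{"nested": {"path": "embeddings.content.chunks", ...}}]}}}
    q["size"] = 5
    q["_source"] = {"excludes": ["embeddings"]}  # as generate_queries does in knowledge_base.py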
ask_candid/base/retrieval/knowledge_base.py
ADDED
@@ -0,0 +1,362 @@
+from typing import Literal, Any
+from collections.abc import Iterator, Iterable
+from itertools import groupby
+import logging
+
+from langchain_core.documents import Document
+
+from ask_candid.base.retrieval.elastic import (
+    build_sparse_vector_query,
+    build_sparse_vector_and_text_query,
+    news_query_builder,
+    multi_search_base
+)
+from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
+from ask_candid.base.retrieval.schemas import ElasticHitsResult
+import ask_candid.base.retrieval.sources as S
+from ask_candid.services.small_lm import CandidSLM
+
+from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
+
+SourceNames = Literal[
+    "Candid Blog",
+    "Candid Help",
+    "Candid Learning",
+    "Candid News",
+    "IssueLab Research Reports",
+    "YouTube Training"
+]
+sparse_encoder = SpladeEncoder()
+logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+# TODO remove
+def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
+    """Pads the relevant chunk of text with context before and after
+
+    Parameters
+    ----------
+    field_name : str
+        a field with the long text that was chunked into pieces
+    hit : ElasticHitsResult
+    context_length : int, optional
+        length of text to add before and after the chunk, by default 1024
+    add_context : bool, optional
+        Set to `False` to expand the text context by searching for the Elastic inner hit inside the larger document
+        , by default True
+
+    Returns
+    -------
+    str
+        longer chunks stuffed together
+    """
+
+    chunks = []
+    # NOTE chunks have tokens, long text is a string, but may contain html which affects tokenization
+    long_text = hit.source.get(field_name) or ""
+    long_text = long_text.lower()
+
+    inner_hits_field = f"embeddings.{field_name}.chunks"
+    found_chunks = hit.inner_hits.get(inner_hits_field, {}) if hit.inner_hits else None
+    if found_chunks:
+        for h in found_chunks.get("hits", {}).get("hits") or []:
+            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
+
+            # cutting the middle because we may have tokenizing artifacts there
+            chunk = chunk[3: -3]
+
+            if add_context:
+                # Find the start and end indices of the chunk in the large text
+                start_index = long_text.find(chunk[:20])
+
+                # Chunk is found
+                if start_index != -1:
+                    end_index = start_index + len(chunk)
+                    pre_start_index = max(0, start_index - context_length)
+                    post_end_index = min(len(long_text), end_index + context_length)
+                    chunks.append(long_text[pre_start_index:post_end_index])
+            else:
+                chunks.append(chunk)
+    return '\n\n'.join(chunks)
+
+
+def generate_queries(
+    query: str,
+    sources: list[SourceNames],
+    news_days_ago: int = 60
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    """Builds Elastic queries against indices which do or do not support sparse vector queries.
+
+    Parameters
+    ----------
+    query : str
+        Text describing a user's question or a description of investigative work which requires support from Candid's
+        knowledge base
+    sources : list[SourceNames]
+        One or more sources of knowledge from different areas at Candid.
+        * Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
+            illuminate ongoing work
+        * Candid Help: Candid FAQs to help user's get started with Candid's product platform and learning resources
+        * Candid Learning: Training documents from Candid's subject matter experts
+        * Candid News: News articles and press releases about real-time activity in the philanthropic sector
+        * IssueLab Research Reports: Academic research reports about the social/philanthropic sector
+        * YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
+    news_days_ago : int, optional
+        How many days in the past to search for news articles, if a user is asking for recent trends then this value
+        should be set lower >~ 10, by default 60
+
+    Returns
+    -------
+    tuple[list[dict[str, Any]], list[dict[str, Any]]]
+        (sparse vector queries, queries for indices which do not support sparse vectors)
+    """
+
+    vector_queries = []
+    quasi_vector_queries = []
+
+    for source_name in sources:
+        if source_name == "Candid Blog":
+            q = build_sparse_vector_query(query=query, fields=S.CandidBlogConfig.semantic_fields)
+            q["_source"] = {"excludes": ["embeddings"]}
+            q["size"] = 5
+            vector_queries.extend([{"index": S.CandidBlogConfig.index_name}, q])
+        elif source_name == "Candid Help":
+            q = build_sparse_vector_query(query=query, fields=S.CandidHelpConfig.semantic_fields)
+            q["_source"] = {"excludes": ["embeddings"]}
+            q["size"] = 5
+            vector_queries.extend([{"index": S.CandidHelpConfig.index_name}, q])
+        elif source_name == "Candid Learning":
+            q = build_sparse_vector_query(query=query, fields=S.CandidLearningConfig.semantic_fields)
+            q["_source"] = {"excludes": ["embeddings"]}
+            q["size"] = 5
+            vector_queries.extend([{"index": S.CandidLearningConfig.index_name}, q])
+        elif source_name == "Candid News":
+            q = news_query_builder(
+                query=query,
+                fields=S.CandidNewsConfig.semantic_fields,
+                encoder=sparse_encoder,
+                days_ago=news_days_ago
+            )
+            q["size"] = 5
+            quasi_vector_queries.extend([{"index": S.CandidNewsConfig.index_name}, q])
+        elif source_name == "IssueLab Research Reports":
+            q = build_sparse_vector_query(query=query, fields=S.IssueLabConfig.semantic_fields)
+            q["_source"] = {"excludes": ["embeddings"]}
+            q["size"] = 1
+            vector_queries.extend([{"index": S.IssueLabConfig.index_name}, q])
+        elif source_name == "YouTube Training":
+            q = build_sparse_vector_and_text_query(
+                query=query,
+                semantic_fields=S.YoutubeConfig.semantic_fields,
+                text_fields=S.YoutubeConfig.text_fields,
+                highlight_fields=S.YoutubeConfig.highlight_fields,
+                excluded_fields=S.YoutubeConfig.excluded_fields
+            )
+            q["size"] = 5
+            vector_queries.extend([{"index": S.YoutubeConfig.index_name}, q])
+
+    return vector_queries, quasi_vector_queries
+
+
+def run_search(
+    vector_searches: list[dict[str, Any]] | None = None,
+    non_vector_searches: list[dict[str, Any]] | None = None,
+) -> list[ElasticHitsResult]:
+    def _msearch_response_generator(responses: Iterable[dict[str, Any]]) -> Iterator[ElasticHitsResult]:
+        for query_group in responses:
+            for h in query_group.get("hits", {}).get("hits", []):
+                inner_hits = h.get("inner_hits", {})
+
+                if not inner_hits and "news" in h.get("_index"):
+                    inner_hits = {"text": h.get("_source", {}).get("content")}
+
+                yield ElasticHitsResult(
+                    index=h["_index"],
+                    id=h["_id"],
+                    score=h["_score"],
+                    source=h["_source"],
+                    inner_hits=inner_hits,
+                    highlight=h.get("highlight", {})
+                )
+
+    results = []
+    if vector_searches is not None and len(vector_searches) > 0:
+        hits = multi_search_base(queries=vector_searches, credentials=SEMANTIC_ELASTIC_QA)
+        for hit in _msearch_response_generator(responses=hits):
+            results.append(hit)
+    if non_vector_searches is not None and len(non_vector_searches) > 0:
+        hits = multi_search_base(queries=non_vector_searches, credentials=NEWS_ELASTIC)
+        for hit in _msearch_response_generator(responses=hits):
+            results.append(hit)
+    return results
+
+
+def retrieved_text(hits: dict[str, Any]) -> str:
+    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
+    re-scoring by a secondary language model.
+
+    Parameters
+    ----------
+    hits : Dict[str, Any]
+
+    Returns
+    -------
+    str
+    """
+
+    nlp = CandidSLM()
+
+    text = []
+    for _, v in hits.items():
+        if _ == "text":
+            s = nlp.summarize(v, top_k=3)
+            text.append(s.summary)
+            # text.append(v)
+            continue
+
+        for h in (v.get("hits", {}).get("hits") or []):
+            for _, field in h.get("fields", {}).items():
+                for chunk in field:
+                    if chunk.get("chunk"):
+                        text.extend(chunk["chunk"])
+    return '\n'.join(text)
+
+
+def reranker(
+    query_results: Iterable[ElasticHitsResult],
+    search_text: str | None = None,
+    max_num_results: int = 5
+) -> Iterator[ElasticHitsResult]:
+    """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
+    This will shuffle results
+
+    Parameters
+    ----------
+    query_results : Iterable[ElasticHitsResult]
+
+    Yields
+    ------
+    Iterator[ElasticHitsResult]
+    """
+
+    results: list[ElasticHitsResult] = []
+    texts: list[str] = []
+    for _, data in groupby(query_results, key=lambda x: x.index):
+        data = list(data)  # noqa: PLW2901
+        max_score = max(data, key=lambda x: x.score).score
+        min_score = min(data, key=lambda x: x.score).score
+
+        for d in data:
+            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
+            results.append(d)
+
+            if search_text:
+                if d.inner_hits:
+                    text = retrieved_text(d.inner_hits)
+                if d.highlight:
+                    highlight_texts = []
+                    for k,v in d.highlight.items():
+                        v_text = '\n'.join(v)
+                        highlight_texts.append(v_text)
+                    text = '\n'.join(highlight_texts)
+                texts.append(text)
+
+    if search_text and len(texts) == len(results) and len(texts) > 1:
+        logger.info("Re-ranking %d retrieval results", len(results))
+        scores = sparse_encoder.query_reranking(query=search_text, documents=texts)
+        for r, s in zip(results, scores):
+            r.score = s
+
+    yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
+
+
+def process_hit(hit: ElasticHitsResult) -> Document:
+    if "issuelab-elser" in hit.index:
+        doc = Document(
+            page_content='\n\n'.join([
+                hit.source.get("combined_item_description", ""),
+                hit.source.get("description", ""),
+                hit.source.get("combined_issuelab_findings", ""),
+                get_context("content", hit, context_length=12)
+            ]),
+            metadata={
+                "title": hit.source["title"],
+                "source": "IssueLab",
+                "source_id": hit.source["resource_id"],
+                "url": hit.source.get("permalink", "")
+            }
+        )
+    elif "youtube" in hit.index:
+        highlight = hit.highlight or {}
+        doc = Document(
+            page_content='\n\n'.join([
+                hit.source.get("title", ""),
+                hit.source.get("semantic_description", ""),
+                ' '.join(highlight.get("semantic_cc_text", []))
+            ]),
+            metadata={
+                "title": hit.source.get("title", ""),
+                "source": "Candid YouTube",
+                "source_id": hit.source['video_id'],
+                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
+            }
+        )
+    elif "candid-blog" in hit.index:
+        doc = Document(
+            page_content='\n\n'.join([
+                hit.source.get("title", ""),
+                hit.source.get("excerpt", ""),
+                get_context("content", hit, context_length=12, add_context=False),
+                get_context("authors_text", hit, context_length=12, add_context=False),
+                hit.source.get("title_summary_tags", "")
+            ]),
+            metadata={
+                "title": hit.source.get("title", ""),
+                "source": "Candid Blog",
+                "source_id": hit.source["id"],
+                "url": hit.source["link"]
+            }
+        )
+    elif "candid-learning" in hit.index:
+        doc = Document(
+            page_content='\n\n'.join([
+                hit.source.get("title", ""),
+                hit.source.get("staff_recommendations", ""),
+                hit.source.get("training_topics", ""),
+                get_context("content", hit, context_length=12)
+            ]),
+            metadata={
+                "title": hit.source["title"],
+                "source": "Candid Learning",
+                "source_id": hit.source["post_id"],
+                "url": hit.source.get("url", "")
+            }
+        )
+    elif "candid-help" in hit.index:
+        doc = Document(
+            page_content='\n\n'.join([
+                hit.source.get("combined_article_description", ""),
+                get_context("content", hit, context_length=12)
+            ]),
+            metadata={
+                "title": hit.source.get("title", ""),
+                "source": "Candid Help",
+                "source_id": hit.source["id"],
+                "url": hit.source.get("link", "")
+            }
+        )
+    elif "news" in hit.index:
+        doc = Document(
+            page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
+            metadata={
+                "title": hit.source.get("title", ""),
+                "source": hit.source.get("site_name") or "Candid News",
+                "source_id": hit.source["id"],
+                "url": hit.source.get("link", "")
+            }
+        )
+    else:
+        raise ValueError(f"Unknown source result from index {hit.index}")
+    return doc
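Putting the pieces together, the intended retrieval flow appears to be: build per-source queries, run them via msearch, rerank across indices, then convert hits to LangChain documents. A hedged end-to-end sketch (query text is arbitrary):

    from ask_candid.base.retrieval.knowledge_base import (
        generate_queries, run_search, reranker, process_hit
    )

    question = "trends in rural health funding"
    vector_qs, quasi_qs = generate_queries(
        query=question,
        sources=["Candid News", "IssueLab Research Reports"],
        news_days_ago=30
    )
    hits = run_search(vector_searches=vector_qs, non_vector_searches=quasi_qs)
    docs = [process_hit(h) for h in reranker(hits, search_text=question)]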
ask_candid/base/retrieval/schemas.py
ADDED
@@ -0,0 +1,23 @@
+from typing import Any
+from dataclasses import dataclass, field
+
+
+@dataclass
+class ElasticSourceConfig:
+    index_name: str
+    semantic_fields: tuple[str,...] = field(default_factory=tuple)
+    text_fields: tuple[str,...] | None = field(default_factory=tuple)
+    highlight_fields: tuple[str,...] | None = field(default_factory=tuple)
+    excluded_fields: tuple[str,...] | None = field(default_factory=tuple)
+
+
+@dataclass
+class ElasticHitsResult:
+    """Dataclass for Elasticsearch hits results
+    """
+    index: str
+    id: Any
+    score: float
+    source: dict[str, Any]
+    inner_hits: dict[str, Any] | None
+    highlight: dict[str, list[str]] | None
ask_candid/base/retrieval/sources.py
ADDED
@@ -0,0 +1,40 @@
+from ask_candid.base.retrieval.schemas import ElasticSourceConfig
+
+
+CandidBlogConfig = ElasticSourceConfig(
+    index_name="search-semantic-candid-blog",
+    semantic_fields=("content", "authors_text", "title_summary_tags")
+)
+
+
+CandidHelpConfig = ElasticSourceConfig(
+    index_name="search-semantic-candid-help-elser_ve1",
+    semantic_fields=("content", "combined_article_description")
+)
+
+
+CandidLearningConfig = ElasticSourceConfig(
+    index_name="search-semantic-candid-learning_ve1",
+    semantic_fields=("content", "title", "training_topics", "staff_recommendations")
+)
+
+
+CandidNewsConfig = ElasticSourceConfig(
+    index_name="news_1",
+    semantic_fields=("title", "content")
+)
+
+
+IssueLabConfig = ElasticSourceConfig(
+    index_name="search-semantic-issuelab-elser_ve2",
+    semantic_fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
+)
+
+
+YoutubeConfig = ElasticSourceConfig(
+    index_name="search-semantic-youtube",
+    semantic_fields=("semantic_title", "semantic_description","semantic_cc_text"),
+    text_fields=("title", "description", "cc_text"),
+    highlight_fields=("semantic_cc_text",),
+    excluded_fields=("cc_text", "semantic_cc_text", "semantic_title")
+)
ask_candid/base/retrieval/sparse_lexical.py
ADDED
@@ -0,0 +1,98 @@
+from tqdm.auto import tqdm
+
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+from transformers.tokenization_utils_base import BatchEncoding
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+from torch import Tensor
+import torch
+
+
+class SpladeEncoder:
+    batch_size = 8
+    model_id = "naver/splade-v3"
+
+    def __init__(self):
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id)
+        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
+
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.mps.is_available():
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+        self.model.to(self.device)
+
+    @torch.no_grad()
+    def forward(self, inputs: BatchEncoding) -> Tensor:
+        output = self.model(**inputs.to(self.device))
+
+        logits: Tensor = output.logits
+        mask: Tensor = inputs.attention_mask
+
+        vec = (logits.relu() + 1).log() * mask.unsqueeze(dim=-1)
+        return vec.max(dim=1)[0].squeeze()
+
+    def encode(self, texts: list[str]) -> Tensor:
+        """Forward pass to get dense vectors
+
+        Parameters
+        ----------
+        texts : list[str]
+
+        Returns
+        -------
+        torch.Tensor
+            Dense vectors
+        """
+
+        vectors = []
+        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Encoding"):  # type: ignore
+            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
+            vec = self.forward(inputs=tokens)
+            vectors.append(vec)
+        return torch.vstack(vectors)
+
+    def query_reranking(self, query: str, documents: list[str]) -> list[float]:
+        """Cosine similarity re-ranking.
+
+        Parameters
+        ----------
+        query : str
+            Retrieval query
+        documents : list[str]
+            Retrieved documents
+
+        Returns
+        -------
+        list[float]
+            Cosine values
+        """
+
+        vec = self.encode([query, *documents])
+        xQ = F.normalize(vec[:1], dim=-1, p=2.)
+        xD = F.normalize(vec[1:], dim=-1, p=2.)
+        return (xQ * xD).sum(dim=-1).cpu().tolist()
+
+    def token_expand(self, query: str) -> dict[str, float]:
+        """Sparse lexical token expansion.
+
+        Parameters
+        ----------
+        query : str
+            Retrieval query
+
+        Returns
+        -------
+        dict[str, float]
+        """
+
+        vec = self.encode([query]).squeeze()
+        cols = vec.nonzero().squeeze().cpu().tolist()
+        weights = vec[cols].cpu().tolist()
+
+        sparse_dict_tokens = {self.idx2token[idx]: round(weight, 3) for idx, weight in zip(cols, weights) if weight > 0}
+        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
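A short usage sketch of `SpladeEncoder` showing the two entry points used elsewhere in this commit (`token_expand` by `news_query_builder`, `query_reranking` by `reranker`); the outputs shown in comments are indicative only:

    from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder

    encoder = SpladeEncoder()  # downloads naver/splade-v3 on first use
    tokens = encoder.token_expand("nonprofit disaster relief")
    # e.g. {"relief": 2.1, "disaster": 1.9, "nonprofit": 1.5, ...} sorted by weight
    scores = encoder.query_reranking(
        query="nonprofit disaster relief",
        documents=["Red Cross responds to floods", "Annual gala fundraiser"]
    )  # one cosine similarity per document; SPLADE vectors are non-negative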
ask_candid/base/utils.py
CHANGED
@@ -1,3 +1,6 @@
+from collections.abc import Callable
+from functools import wraps
+from time import sleep
 import asyncio


@@ -12,3 +15,52 @@ def async_tasks(*tasks):
     loop.stop()
     loop.close()
     return results
+
+
+def retry_on_status(
+    num_retries: int = 3,
+    backoff_factor: float = 0.5,
+    max_backoff: float | None = None,
+    retry_statuses: tuple[int, ...] = (501, 503)
+):
+    """
+    Retry decorator for functions making httpx requests.
+    Retries on specific HTTP status codes with exponential backoff.
+
+    Args:
+        num_retries (int): Max number of retries.
+        backoff_factor (float): Multiplier for delay (e.g., 0.5, 1, etc.).
+        max_backoff (float, optional): Cap on the backoff delay in seconds.
+        retry_statuses (tuple): HTTP status codes to retry on.
+    """
+
+    def decorator(func: Callable):
+
+        if asyncio.iscoroutinefunction(func):
+            # Async version
+            @wraps(func)
+            async def async_wrapper(*args, **kwargs):
+                for attempt in range(num_retries + 1):
+                    response = await func(*args, **kwargs)
+                    if response.status_code not in retry_statuses:
+                        return response
+                    if attempt < num_retries:
+                        delay = min(backoff_factor * (2 ** attempt), max_backoff or float('inf'))
+                        await asyncio.sleep(delay)
+                return response
+            return async_wrapper
+
+        # Sync version
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            for attempt in range(num_retries + 1):
+                response = func(*args, **kwargs)
+                if response.status_code not in retry_statuses:
+                    return response
+                if attempt < num_retries:
+                    delay = min(backoff_factor * (2 ** attempt), max_backoff or float('inf'))
+                    sleep(delay)
+            return response
+        return sync_wrapper
+
+    return decorator
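A minimal sketch of the new `retry_on_status` decorator on a synchronous httpx call; the URL is a placeholder, and the decorated function only needs to return an object exposing `status_code`:

    import httpx
    from ask_candid.base.utils import retry_on_status

    @retry_on_status(num_retries=3, backoff_factor=0.5, retry_statuses=(501, 503))
    def fetch_profile(org_id: str) -> httpx.Response:
        return httpx.get(f"https://api.example.org/orgs/{org_id}")  # placeholder endpoint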
ask_candid/chat.py
CHANGED
@@ -1,66 +1,79 @@
-from typing import
-
-from
-from langgraph.checkpoint.memory import MemorySaver
-
-    indices: Optional[List[str]] = None,
-    premium_features: Optional[List[str]] = None,
-) -> Tuple[gr.MultimodalTextbox, List[Dict[str, Any]], str]:
-    if premium_features is None:
-        premium_features = []
-    if len(history) == 0:
-        history.append({"role": "system", "content": START_SYSTEM_PROMPT})
-
-    history.append({"role": "user", "content": user_input["text"]})
-    inputs = {"messages": history}
-    # thread_id can be an email https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
-    thread_id = get_session_id(thread_id)
-    config = {"configurable": {"thread_id": thread_id}}
-
-        enable_recommendations=enable_recommendations
-    )
-
-    memory = MemorySaver()  # TODO: don't use for Prod
-    graph = workflow.compile(checkpointer=memory)
-    response = graph.invoke(inputs, config=config)
-    messages = response["messages"]
-
-    last_message = messages[-1]
-    ai_answer = last_message.content
-
-    sources_html = ""
-    for message in messages[-2:]:
-        if message.type == "HTML":
-            sources_html = message.content
-
-        "metadata": {"title": "Sources HTML"},
-    })
+from typing import TypedDict, Literal, Any
+from collections.abc import Iterator
+from dataclasses import asdict
+import logging
+import json
+
+from langchain_core.messages.tool import ToolMessage
+from gradio import ChatMessage
+
+logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class ToolInput(TypedDict):
+    name: str
+    args: dict[str, Any]
+    id: str
+    type: Literal["tool_call"]
+
+
+class CalledTool(TypedDict):
+    id: str
+    name: Literal["tools"]
+    input: list[ToolInput]
+    triggers: tuple[str, ...]
+
+
+class ToolResult(TypedDict):
+    id: str
+    name: Literal["tools"]
+    error: bool | None
+    result: list[tuple[str, list[ToolMessage]]]
+    interrupts: list
+
+
+def convert_history_for_graph_agent(history: list[dict | ChatMessage]) -> list[dict]:
+    _hist = []
+    for h in history:
+        if isinstance(h, ChatMessage):
+            h = asdict(h)
+
+        if h.get("content"):
+            # if h.get("metadata"):
+            #     # skip if it's a tool-call
+            #     continue
+            _hist.append(h)
+    return _hist
+
+
+def format_tool_call(input_chunk: CalledTool) -> Iterator[ChatMessage]:
+    for graph_input in input_chunk["input"]:
+        yield ChatMessage(
+            role="assistant",
+            content=json.dumps(graph_input["args"]),
+            metadata={
+                "title": f"Using tool `{graph_input.get('name')}`",
+                "status": "done",
+                "id": input_chunk["id"],
+                "parent_id": input_chunk["id"]
+            }
+        )
+
+
+def format_tool_response(result_chunk: ToolResult) -> Iterator[ChatMessage]:
+    for _, outputs in result_chunk["result"]:
+        for tool in outputs:
+            logger.info("Called tool `%s`", tool.name)
+            yield ChatMessage(
+                role="assistant",
+                content=tool.content,
+                metadata={
+                    "title": f"Results from tool `{tool.name}`",
+                    "tool_name": tool.name,
+                    "documents": tool.artifact,
+                    "status": "done",
+                    "parent_id": result_chunk["id"]
+                }  # pyright: ignore[reportArgumentType]
+            )
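The two formatter helpers translate agent tool events into Gradio chat messages. A hedged sketch of how they might be consumed while iterating over streamed events, assuming chunks shaped like the `CalledTool`/`ToolResult` TypedDicts above (the event source itself is outside this diff):

    # illustrative consumption loop; `agent_events` yields CalledTool / ToolResult dicts
    for chunk in agent_events:
        if chunk.get("input"):
            history.extend(format_tool_call(chunk))      # show tool-call arguments
        elif chunk.get("result"):
            history.extend(format_tool_response(chunk))  # show tool outputs and artifacts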
ask_candid/services/small_lm.py
CHANGED
@@ -1,4 +1,3 @@
-from typing import List, Optional
 from dataclasses import dataclass
 from enum import Enum

@@ -9,10 +8,26 @@ from ask_candid.base.lambda_base import LambdaInvokeBase

 @dataclass(slots=True)
 class Encoding:
-    inputs:
+    inputs: list[str]
     vectors: torch.Tensor


+@dataclass(slots=True)
+class SummaryItem:
+    rank: int
+    score: float
+    text: str
+
+
+@dataclass(slots=True)
+class TextSummary:
+    snippets: list[SummaryItem]
+
+    @property
+    def summary(self) -> str:
+        return ' '.join([_.text for _ in self.snippets])
+
+
 class CandidSLM(LambdaInvokeBase):
     """Wrapper around Candid's custom small language model.
     For more details see https://dev.azure.com/guidestar/DataScience/_git/graph-ai?path=/releases/language.
@@ -35,7 +50,7 @@ class CandidSLM(LambdaInvokeBase):
     DOCUMENT_NER_SALIENCE = "/document/entitySalience"

     def __init__(
-        self, access_key:
+        self, access_key: str | None = None, secret_key: str | None = None
     ) -> None:
         super().__init__(
             function_name="small-lm",
@@ -43,11 +58,22 @@ class CandidSLM(LambdaInvokeBase):
             secret_key=secret_key
         )

-    def encode(self, text:
+    def encode(self, text: list[str]) -> Encoding:
         response = self._submit_request({"text": text, "path": self.Tasks.ENCODE.value})
+        assert isinstance(response, dict)

+        return Encoding(
             inputs=(response.get("inputs") or []),
             vectors=torch.tensor((response.get("vectors") or []), dtype=torch.float32)
         )
+
+    def summarize(self, text: list[str], top_k: int) -> TextSummary:
+        response = self._submit_request({"text": text, "path": self.Tasks.DOCUMENT_SUMMARIZE.value})
+        assert isinstance(response, dict)
+
+        return TextSummary(
+            snippets=[
+                SummaryItem(rank=item["rank"], score=item["score"], text=item["value"])
+                for item in (response.get("summary") or [])[:top_k]
+            ]
+        )
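A brief sketch of the new `summarize` wrapper; since the service runs behind the `small-lm` Lambda, the actual snippets depend on the deployed model:

    from ask_candid.services.small_lm import CandidSLM

    nlp = CandidSLM()
    result = nlp.summarize(["a long retrieved passage ..."], top_k=3)
    print(result.summary)  # top-k ranked snippets joined into one string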
ask_candid/tools/general.py
ADDED
@@ -0,0 +1,17 @@
+from datetime import date
+
+from langchain_core.tools import tool
+
+
+@tool
+def get_current_day() -> date:
+    """Get the current day to reference for any time-sensitive data requests. This might be useful for information
+    searches through news data, where more current articles may be more relevant.
+
+    Returns
+    -------
+    date
+        Today's date
+    """
+
+    return date.today()
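Since `get_current_day` is wrapped with LangChain's `@tool` decorator, it is normally invoked through the tool interface rather than called directly; a minimal sketch:

    from ask_candid.tools.general import get_current_day

    today = get_current_day.invoke({})  # no arguments; returns a datetime.date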
ask_candid/tools/org_search.py
ADDED
@@ -0,0 +1,182 @@
+from typing import Any
+
+from pydantic import BaseModel, Field
+from langchain_core.output_parsers.pydantic import PydanticOutputParser
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.runnables import RunnableSequence
+from langchain_core.prompts import PromptTemplate
+from langchain_core.tools import tool, BaseTool
+
+from thefuzz import fuzz
+
+from ask_candid.tools.utils import format_candid_profile_link
+from ask_candid.base.api_base import BaseAPI
+from ask_candid.base.config.rest import CANDID_SEARCH_API
+
+
+class OrganizationNames(BaseModel):
+    """List of names of social-sector organizations, such as nonprofits and foundations."""
+    orgnames: list[str] = Field(..., description="List of organization names.")
+
+
+class OrganizationIdentifierArgs(BaseModel):
+    text: str = Field(..., description="Chat model response text which contains named organizations.")
+
+
+class OrganizationIdentifier(BaseTool):
+    llm: BaseChatModel
+    parser: PydanticOutputParser = PydanticOutputParser(pydantic_object=OrganizationNames)
+    template: str = """Extract only the names of officially recognized organizations, foundations, and government
+entities from the text below. Do not include any entries that contain descriptions, regional identifiers, or
+explanations within parentheses or following the name. Strictly exclude databases, resources, crowdfunding
+platforms, and general terms. Provide the output only in the specified JSON format.
+
+input text: ```{chatbot_output}```
+output format: ```{format_instructions}```
+"""
+
+    name: str = "organization-identifier"
+    description: str = """
+Identify the names of nonprofits and foundations from chat model responses. If it is likely that a response contains
+proper names then it should be processed through this tool.
+
+Examples
+--------
+>>> `organization_identifier('My Favorite Foundation awarded a grant to My Favorite Nonprofit.')`
+>>> `organization_identifier('The LoremIpsum Nonprofit will be running a community event this Thursday')`
+"""
+    args_schema: type[OrganizationIdentifierArgs] = OrganizationIdentifierArgs
+
+    def _build_pipeline(self) -> RunnableSequence:
+        prompt = PromptTemplate(
+            template=self.template,
+            input_variables=["chatbot_output"],
+            partial_variables={"format_instructions": self.parser.get_format_instructions()}
+        )
+        return RunnableSequence(prompt, self.llm, self.parser)
+
+    def _run(self, text: str) -> list[str]:
+        chain = self._build_pipeline()
+        result: OrganizationNames = chain.invoke({"chatbot_output": text})
+        return result.orgnames
+
+    async def _arun(self, text: str) -> list[str]:
+        chain = self._build_pipeline()
+        result: OrganizationNames = await chain.ainvoke({"chatbot_output": text})
+        return result.orgnames
+
+
+def name_search(name: str) -> list[dict[str, Any]]:
+    candid_org_search = BaseAPI(
+        url=f'{CANDID_SEARCH_API["url"]}/v1/search',
+        headers={"x-api-key": CANDID_SEARCH_API["key"]}
+    )
+    results = candid_org_search.get(
+        query=f"'{name}'",
+        searchMode="organization_only",
+        rowCount=5
+    )
+    return results.get("returnedOrgs") or []
+
+
+def find_similar(name: str, potential_matches: list[dict[str, Any]], threshold: int = 80):
+    for org in potential_matches:
+        similarity = max(
+            fuzz.ratio(name.lower(), (org["orgName"] or "").lower()),
+            fuzz.ratio(name.lower(), (org["akaName"] or "").lower()),
+            fuzz.ratio(name.lower(), (org["dbaName"] or "").lower()),
+        )
+        if similarity >= threshold:
+            yield org, similarity
+
+
+@tool(response_format="content_and_artifact")
+def find_mentioned_organizations(organizations: list[str]) -> tuple[str, dict[str, str]]:
+    """Match organization names found in a chat response to official organizations tracked by Candid. This involves
+    using the Candid Search API in a lookup mode, and then finding the best result(s) using a heuristic string
+    similarity search.
+
+    This tool is focused on getting links to the organization's Candid profile for the user to click and explore in
+    more detail.
+
+    Use the URLs here to replace organization names in the chat response with links to the organization's profile.
+    Links to Candid profiles **MUST** be used to do the following:
+    1. Generate direct links to Candid organization profiles
+    2. Provide a mechanism for users to easily access detailed organizational information
+    3. Enhance responses with authoritative source links
+
+    Key Usage Requirements:
+    - Always incorporate returned profile URLs directly into the response text
+    - Replace organization name mentions with hyperlinked Candid profile URLs
+    - Prioritize creating a seamless user experience by making URLs contextually relevant
+
+    Example Desired Output:
+    Instead of: 'The Gates Foundation does impressive work.'
+    Use: 'The [Gates Foundation](https://app.candid.org/profile/XXXXX) does impressive work.'
+
+    The function returns a tuple with:
+    - A link information text (optional)
+    - A dictionary mapping input names to their best Candid Search profile URL
+
+    Failure to integrate the URLs into the response is considered an incomplete implementation.
+
+    Examples
+    --------
+    >>> find_mentioned_organizations(organizations=['Gates Foundation', 'Candid'])
+
+    Parameters
+    ----------
+    organizations : list[str]
+        A list of organization name strings found in a chat response message which need to be matched
+
+    Returns
+    -------
+    tuple[str, dict[str, str]]
+        (Link information text, mapping input name --> Candid Search profile URL of the best potential match)
+    """
+
+    output = {}
+    for name in organizations:
+        search_results = name_search(name)
+        try:
+            best_result, _ = max(find_similar(name=name, potential_matches=search_results), key=lambda x: x[-1])
+        except ValueError:
+            # no similar organizations could be found for this one, keep going
+            continue
+        output[name] = format_candid_profile_link(best_result["candidEntityID"])
+
+    response = [f"The Candid profile link for {name} is {url}" for name, url in output.items()]
+    return '. '.join(response), output
+
+
+@tool
+def find_mentioned_organizations_detailed(organizations: list[str]) -> dict[str, dict[str, Any]]:
+    """Match organization names found in a chat response to official organizations tracked by Candid. This involves
+    using the Candid Search API in a lookup mode, and then finding the best result(s) using a heuristic string
+    similarity search.
+
+    Examples
+    --------
+    >>> find_mentioned_organizations_detailed(organizations=['Gates Foundation', 'Candid'])
+
+    Parameters
+    ----------
+    organizations : list[str]
+        A list of organization name strings found in a chat response message which need to be matched
+
+    Returns
+    -------
+    dict[str, dict[str, Any]]
+        Mapping from the input name(s) to the best potential match.
+    """
+
+    output = {}
+    for name in organizations:
+        search_results = name_search(name)
+        try:
+            best_result, _ = max(find_similar(name=name, potential_matches=search_results), key=lambda x: x[-1])
+        except ValueError:
+            # no similar organizations could be found for this one, keep going
+            continue
+        output[name] = best_result
+    return output
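
The matching flow can be sanity-checked without an API key by feeding find_similar a mocked Candid Search response (the hit dictionaries below are illustrative, but use the same keys the code reads):

from ask_candid.tools.org_search import find_similar

mock_hits = [
    {"orgName": "The Gates Foundation", "akaName": "", "dbaName": "", "candidEntityID": 1},
    {"orgName": "Gates Family Charity", "akaName": "", "dbaName": "", "candidEntityID": 2},
]
# keep the hit whose best fuzzy ratio across the name variants clears the threshold
best, score = max(find_similar("Gates Foundation", mock_hits), key=lambda x: x[-1])
print(best["candidEntityID"], score)  # -> 1, with a ratio comfortably above 80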
ask_candid/tools/search.py
CHANGED
@@ -1,122 +1,67 @@
-from typing import List, Tuple, Callable, Optional, Any
-from functools import partial
-import logging
-
-from pydantic import BaseModel, Field
-from langchain_core.language_models.llms import LLM
 from langchain_core.documents import Document
-from langchain_core.tools import …
-
-from ask_candid.retrieval.…
…
-        logger.warning("User callback was passed in but failed: %s", ex)
-
-    output = ["Search didn't return any Candid sources"]
-    page_content = []
-    content = "Search didn't return any Candid sources"
-    results = get_query_results(search_text=user_input, indices=indices)
-    if results:
-        output = get_reranked_results(results, search_text=user_input)
-        for doc in output:
-            page_content.append(doc.page_content)
-        content = "\n\n".join(page_content)
-
-    # for the tool we need to return a tuple for content_and_artifact type
-    return content, output
-
-
-def retriever_tool(
-    indices: List[DataIndices],
-    user_callback: Optional[Callable[[str], Any]] = None
-) -> Tool:
-    """Tool component for use in conditional edge building for RAG execution graph.
-    Cannot use `create_retriever_tool` because it only provides content, losing all metadata on the way:
-    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
-
-    Parameters
-    ----------
-    indices : List[DataIndices]
-        Semantic index names to search over
-    user_callback : Optional[Callable[[str], Any]], optional
-        Optional UI callback to inform the user of app states, by default None
-
-    Returns
-    -------
-    Tool
-    """
-
-    return Tool(
-        name="retrieve_social_sector_information",
-        func=partial(get_search_results, indices=indices, user_callback=user_callback),
-        description=(
-            "Return additional information about social and philanthropic sector, "
-            "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
-        ),
-        args_schema=RetrieverInput,
-        response_format="content_and_artifact"
-    )
…
-    state : _type_
-        The current state
-    llm : LLM
-    tools : List[Tool]
-
-    Returns
-    -------
-    AgentState
-        The updated state with the agent response appended to messages
-    """
-
-    logger.info("---SEARCH AGENT---")
-    messages = state["messages"]
-    question = messages[-1].content
-
-    model = llm.bind_tools(tools)
-    response = model.invoke(messages)
-    # return a list, because this will get added to the existing list
-    return {"messages": [response], "user_input": question}
+from langchain_core.tools import tool
+
+from ask_candid.base.retrieval.knowledge_base import (
+    SourceNames,
+    generate_queries,
+    run_search,
+    reranker,
+    process_hit
+)
+
+
+@tool(response_format="content_and_artifact")
+def search_candid_knowledge_base(
+    query: str,
+    sources: list[SourceNames],
+    news_days_ago: int = 60
+) -> tuple[str, list[Document]]:
+    """Search Candid's subject matter expert knowledge base to find answers about the social and philanthropic
+    sector. This knowledge includes help articles and video training sessions from Candid's subject matter experts,
+    blog posts about the sector from Candid staff and trusted partner authors, research documents about the sector
+    and news articles curated about activity happening in the sector around the world.
+
+    Searches are performed through a combination of vector and keyword searching. Results are then re-ranked against
+    the original query to get the best results.
+
+    Search results often come back with specific organizations named, especially if referencing the news. In these
+    cases the organizations should be identified in Candid's data and links to their profiles **MUST** be included
+    in the final chat response to the user.
+
+    Parameters
+    ----------
+    query : str
+        Text describing a user's question or a description of investigative work which requires support from
+        Candid's knowledge base
+    sources : list[SourceNames]
+        One or more sources of knowledge from different areas at Candid.
+        * Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
+          illuminate ongoing work
+        * Candid Help: Candid FAQs to help users get started with Candid's product platform and learning resources
+        * Candid Learning: Training documents from Candid's subject matter experts
+        * Candid News: News articles and press releases about real-time activity in the philanthropic sector
+        * IssueLab Research Reports: Academic research reports about the social/philanthropic sector
+        * YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
+    news_days_ago : int, optional
+        How many days in the past to search for news articles; if a user is asking for recent trends then this value
+        should be set lower (around 10), by default 60
+
+    Returns
+    -------
+    tuple[str, list[Document]]
+        (Re-ranked document text, source documents)
+    """
+
+    vector_queries, quasi_vector_queries = generate_queries(
+        query=query,
+        sources=sources,
+        news_days_ago=news_days_ago
+    )
+
+    results = run_search(vector_searches=vector_queries, non_vector_searches=quasi_vector_queries)
+    text_response = []
+    response_sources = []
+    for hit in map(process_hit, reranker(results, search_text=query)):
+        text_response.append(hit.page_content)
+        response_sources.append(hit)
+    return '\n\n'.join(text_response), response_sources
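
Because the tool uses response_format="content_and_artifact", invoking it directly with an argument dict returns only the content string; the Document artifacts surface when it runs as an agent tool call. A rough invocation sketch (the source identifier is an assumed value — the real members live in SourceNames in knowledge_base.py):

from ask_candid.tools.search import search_candid_knowledge_base

content = search_candid_knowledge_base.invoke({
    "query": "recent trends in climate philanthropy",
    "sources": ["Candid News"],  # assumed SourceNames value; check knowledge_base.py for the real identifiers
    "news_days_ago": 30,
})
print(content)  # re-ranked snippets separated by blank lines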
ask_candid/tools/utils.py
ADDED
@@ -0,0 +1,14 @@
+def format_candid_profile_link(candid_entity_id: int | str) -> str:
+    """Format the Candid Search organization profile link.
+
+    Parameters
+    ----------
+    candid_entity_id : int | str
+
+    Returns
+    -------
+    str
+        URL
+    """
+
+    return f"https://app.candid.org/profile/{candid_entity_id}"
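
The helper only interpolates the entity ID into the profile URL (the ID below is illustrative):

from ask_candid.tools.utils import format_candid_profile_link

assert format_candid_profile_link(12345) == "https://app.candid.org/profile/12345"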
chat_v2.py
ADDED
@@ -0,0 +1,265 @@
+from typing import TypedDict, Any
+from collections.abc import Iterator, AsyncIterator
+import os
+
+import gradio as gr
+
+from langgraph.graph.state import CompiledStateGraph
+from langgraph.prebuilt import create_react_agent
+from langchain_aws import ChatBedrock
+import boto3
+
+from ask_candid.tools.org_search import OrganizationIdentifier, find_mentioned_organizations
+from ask_candid.tools.search import search_candid_knowledge_base
+from ask_candid.tools.general import get_current_day
+from ask_candid.utils import html_format_docs_chat
+from ask_candid.base.config.constants import START_SYSTEM_PROMPT
+from ask_candid.base.config.models import Name2Endpoint
+from ask_candid.chat import convert_history_for_graph_agent, format_tool_call, format_tool_response
+
+try:
+    from feedback import FeedbackApi
+    ROOT = "."
+except ImportError:
+    from demos.feedback import FeedbackApi
+    ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
+
+BOT_LOGO = os.path.join(ROOT, "static", "candid_logo_yellow.png")
+if not os.path.isfile(BOT_LOGO):
+    BOT_LOGO = os.path.join(ROOT, "..", "..", "static", "candid_logo_yellow.png")
+
+
+class LoggedComponents(TypedDict):
+    context: list[gr.Component]
+    found_helpful: gr.Component
+    will_recommend: gr.Component
+    comments: gr.Component
+    email: gr.Component
+
+
+def build_execution_graph() -> CompiledStateGraph:
+    llm = ChatBedrock(
+        client=boto3.client("bedrock-runtime", region_name="us-east-1"),
+        model=Name2Endpoint["claude-3.5-haiku"]
+    )
+    org_name_recognition = OrganizationIdentifier(llm=llm)  # bind the main chat model to the tool
+    return create_react_agent(
+        model=llm,
+        tools=[
+            get_current_day,
+            org_name_recognition,
+            find_mentioned_organizations,
+            search_candid_knowledge_base
+        ],
+    )
+
+
+def generate_postscript_messages(history: list[gr.ChatMessage]) -> Iterator[gr.ChatMessage]:
+    for record in history:
+        title = record.metadata.get("tool_name")
+        if title == search_candid_knowledge_base.name:
+            yield gr.ChatMessage(
+                role="assistant",
+                content=html_format_docs_chat(record.metadata.get("documents")),
+                metadata={
+                    "title": "Source citations",
+                }
+            )
+        elif title == find_mentioned_organizations.name:
+            pass
+
+
+async def execute(
+    user_input: dict[str, Any],
+    history: list[gr.ChatMessage]
+) -> AsyncIterator[tuple[gr.Component, list[gr.ChatMessage]]]:
+    if len(history) == 0:
+        history.append(gr.ChatMessage(role="system", content=START_SYSTEM_PROMPT))
+
+    history.append(gr.ChatMessage(role="user", content=user_input["text"]))
+    for fname in user_input.get("files") or []:
+        fname: str
+        if fname.endswith('.txt'):
+            with open(fname, 'r', encoding='utf8') as f:
+                history.append(gr.ChatMessage(role="user", content=f.read()))
+    yield gr.MultimodalTextbox(value=None, interactive=True), history
+
+    horizon = len(history)
+    inputs = {"messages": convert_history_for_graph_agent(history)}
+
+    graph = build_execution_graph()
+
+    history.append(gr.ChatMessage(role="assistant", content=""))
+    async for stream_mode, chunk in graph.astream(inputs, stream_mode=["messages", "tasks"]):
+        if stream_mode == "messages" and chunk[0].content:
+            for msg in chunk[0].content:
+                if 'text' in msg:
+                    history[-1].content += msg["text"]
+            yield gr.MultimodalTextbox(value=None, interactive=True), history
+
+        elif stream_mode == "tasks" and chunk.get("name") == "tools" and chunk.get("error") is None:
+            if "input" in chunk:
+                for msg in format_tool_call(chunk):
+                    history.append(msg)
+                    yield gr.MultimodalTextbox(value=None, interactive=True), history
+            elif "result" in chunk:
+                for msg in format_tool_response(chunk):
+                    history.append(msg)
+                    yield gr.MultimodalTextbox(value=None, interactive=True), history
+                history.append(gr.ChatMessage(role="assistant", content=""))
+
+    for post_msg in generate_postscript_messages(history=history[horizon:]):
+        history.append(post_msg)
+        yield gr.MultimodalTextbox(value=None, interactive=True), history
+
+
+def send_feedback(
+    chat_context,
+    found_helpful,
+    will_recommend,
+    comments,
+    email
+):
+    api = FeedbackApi()
+    total_submissions = 0
+
+    try:
+        response = api(
+            context=chat_context,
+            found_helpful=found_helpful,
+            will_recommend=will_recommend,
+            comments=comments,
+            email=email
+        )
+        total_submissions = response.get("response", 0)
+        gr.Info("Thank you for submitting feedback")
+    except Exception as ex:
+        raise gr.Error(f"Error submitting feedback: {ex}")
+    return total_submissions
+
+
+def build_chat_app():
+    with gr.Blocks(theme=gr.themes.Soft(), title="Chat") as demo:
+
+        gr.Markdown(
+            """
+            <h1>Candid's AI assistant</h1>
+
+            <p>
+                Please read the <a
+                    href='https://info.candid.org/chatbot-reference-guide'
+                    target="_blank"
+                    rel="noopener noreferrer"
+                >guide</a> to get started.
+            </p>
+            <hr>
+            """
+        )
+
+        with gr.Column():
+            chatbot = gr.Chatbot(
+                label="AskCandid",
+                elem_id="chatbot",
+                editable="user",
+                avatar_images=(
+                    None,      # user
+                    BOT_LOGO,  # bot
+                ),
+                height="60vh",
+                type="messages",
+                show_label=False,
+                show_copy_button=True,
+                autoscroll=True,
+                layout="panel",
+            )
+            msg = gr.MultimodalTextbox(label="Your message", interactive=True)
+            gr.ClearButton(components=[msg, chatbot], size="sm")
+
+        # pylint: disable=no-member
+        # chatbot.like(fn=like_callback, inputs=chatbot, outputs=None)
+        msg.submit(
+            fn=execute,
+            inputs=[msg, chatbot],
+            outputs=[msg, chatbot],
+            show_api=False
+        )
+        logged = LoggedComponents(context=chatbot)
+
+    return demo, logged
+
+
+def build_feedback(components: LoggedComponents) -> gr.Blocks:
+    with gr.Blocks(theme=gr.themes.Soft(), title="Candid AI demo") as demo:
+        gr.Markdown("<h1>Help us improve this tool with your valuable feedback</h1>")
+
+        with gr.Row():
+            with gr.Column():
+                found_helpful = gr.Radio(
+                    [True, False], label="Did you find what you were looking for?"
+                )
+                will_recommend = gr.Radio(
+                    [True, False],
+                    label="Will you recommend this Chatbot to others?",
+                )
+                comment = gr.Textbox(label="Additional comments (optional)", lines=4)
+                email = gr.Textbox(label="Your email (optional)", lines=1)
+                submit = gr.Button("Submit Feedback")
+
+        components["found_helpful"] = found_helpful
+        components["will_recommend"] = will_recommend
+        components["comments"] = comment
+        components["email"] = email
+
+        # pylint: disable=no-member
+        submit.click(
+            fn=send_feedback,
+            inputs=[
+                components["context"],
+                components["found_helpful"],
+                components["will_recommend"],
+                components["comments"],
+                components["email"]
+            ],
+            outputs=None,
+            show_api=False,
+            api_name=False,
+            preprocess=False,
+        )
+
+    return demo
+
+
+def build_app():
+    candid_chat, logger = build_chat_app()
+    feedback = build_feedback(logger)
+
+    with open(os.path.join(ROOT, "static", "chatStyle.css"), "r", encoding="utf8") as f:
+        css_chat = f.read()
+
+    demo = gr.TabbedInterface(
+        interface_list=[
+            candid_chat,
+            feedback
+        ],
+        tab_names=[
+            "Candid's AI assistant",
+            "Feedback"
+        ],
+        title="Candid's AI assistant",
+        theme=gr.themes.Soft(),
+        css=css_chat,
+    )
+    return demo
+
+
+if __name__ == "__main__":
+    app = build_app()
+    app.queue(max_size=5).launch(
+        show_api=False,
+        mcp_server=False,
+        auth=[
+            (os.getenv("APP_USERNAME"), os.getenv("APP_PASSWORD")),
+            (os.getenv("APP_PUBLIC_USERNAME"), os.getenv("APP_PUBLIC_PASSWORD")),
+        ],
+        auth_message="Login to Candid's AI assistant",
+    )
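
The ReAct agent can also be exercised headlessly, outside of Gradio — a sketch assuming Bedrock credentials are configured; the production app would additionally prepend START_SYSTEM_PROMPT:

import asyncio

from chat_v2 import build_execution_graph

async def main():
    graph = build_execution_graph()
    # create_react_agent state carries a "messages" list; a (role, text) tuple is accepted
    state = await graph.ainvoke(
        {"messages": [("user", "What resources does Candid offer for first-time grantseekers?")]}
    )
    print(state["messages"][-1].content)

asyncio.run(main())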
requirements.txt
CHANGED
@@ -1,11 +1,11 @@
 boto3
 elasticsearch==7.17.6
 thefuzz
-gradio==5.…
-langchain
-langchain-aws
-…
-langgraph
+gradio==5.42.0
+langchain==0.3.27
+langchain-aws==0.2.30
+langgraph==0.6.5
+langgraph-prebuilt==0.6.4
 pydantic==2.10.6
 pyopenssl>22.0.0
 python-dotenv