Spaces · Runtime error
Commit 1ce95c4 · Duplicate from hlydecker/Augmented-Retrieval-qa-ChatGPT
- .gitattributes +34 -0
- .gitignore +7 -0
- README.md +15 -0
- __init__.py +0 -0
- requirements.txt +0 -0
- static/__init__.py +0 -0
- static/mini_nttdata.jpg +0 -0
- streamlit_langchain_chat/__init__.py +1 -0
- streamlit_langchain_chat/__version__.py +1 -0
- streamlit_langchain_chat/constants.py +53 -0
- streamlit_langchain_chat/customized_langchain/__init__.py +10 -0
- streamlit_langchain_chat/customized_langchain/docstore/__init__.py +7 -0
- streamlit_langchain_chat/customized_langchain/docstore/in_memory.py +27 -0
- streamlit_langchain_chat/customized_langchain/indexes/__init__.py +7 -0
- streamlit_langchain_chat/customized_langchain/indexes/graph.py +20 -0
- streamlit_langchain_chat/customized_langchain/llms/__init__.py +1 -0
- streamlit_langchain_chat/customized_langchain/llms/openai.py +708 -0
- streamlit_langchain_chat/customized_langchain/vectorstores/__init__.py +8 -0
- streamlit_langchain_chat/customized_langchain/vectorstores/faiss.py +100 -0
- streamlit_langchain_chat/customized_langchain/vectorstores/pinecone.py +79 -0
- streamlit_langchain_chat/dataset.py +740 -0
- streamlit_langchain_chat/inputs/__init__.py +0 -0
- streamlit_langchain_chat/prompts.py +91 -0
- streamlit_langchain_chat/streamlit_app.py +561 -0
- streamlit_langchain_chat/utils.py +52 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,7 @@
venv*/
tempDir/
.idea/
*.env
*.pkl
*.pickle
*testing*.py
README.md
ADDED
@@ -0,0 +1,15 @@
---
title: Augmented Retrieval Qa ChatGPT
emoji: 🚀
colorFrom: blue
colorTo: blue
sdk: streamlit
sdk_version: 1.19.0
app_file: streamlit_langchain_chat/streamlit_app.py
pinned: false
python_version: 3.10.4
license: cc-by-nc-sa-4.0
duplicated_from: hlydecker/Augmented-Retrieval-qa-ChatGPT
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py
ADDED
File without changes
requirements.txt
ADDED
Binary file (5.03 kB)
static/__init__.py
ADDED
File without changes
static/mini_nttdata.jpg
ADDED
streamlit_langchain_chat/__init__.py
ADDED
@@ -0,0 +1 @@
streamlit_langchain_chat/__version__.py
ADDED
@@ -0,0 +1 @@
__VERSION__ = "1.0.4"
streamlit_langchain_chat/constants.py
ADDED
@@ -0,0 +1,53 @@
from pathlib import Path
from PIL import Image

# from dotenv import load_dotenv, find_dotenv  # pip install python-dotenv==1.0.0

from __version__ import __VERSION__ as APP_VERSION

_SCRIPT_PATH = Path(__file__).absolute()
PARENT_APP_DIR = _SCRIPT_PATH.parent
TEMP_DIR = PARENT_APP_DIR / 'tempDir'
ROOT_DIR = PARENT_APP_DIR.parent
STATIC_DIR = ROOT_DIR / 'static'

# _env_file_path = find_dotenv(str(CODE_DIR / '.env'))  # Check if this path is correct
# if _env_file_path:
#     load_dotenv(_env_file_path)

ST_CONFIG = {
    "page_title": "NTT Data - Chat Q&A",
    # "page_icon": Image.open(STATIC_DIR / "mini_nttdata.jpg"),
}

OPERATING_MODE = "debug"  # debug, preproduction, production

REUSE_ANSWERS = False

LOAD_INDEX_LOCALLY = False
SAVE_INDEX_LOCALLY = False

# x$ per 1000 tokens
PRICES = {
    'text-embedding-ada-002': 0.0004,
    'text-davinci-003': 0.02,
    'gpt-3': 0.002,
    'gpt-4': 0.06,  # 8K context
}

SOURCES_IDS = {
    # "Without source. Only chat": 4,
    "local files": 1,
    "urls": 3
}

TYPE_IDS = {
    "MSF Azure OpenAI Service": 1,
    "OpenAI": 2,
}


INDEX_IDS = {
    "FAISS": 1,
    "Pinecode": 2,
}
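The PRICES table above is expressed in dollars per 1,000 tokens. As a worked illustration only (the helper below is hypothetical and appears nowhere in this commit), a call that consumed 1,500 tokens of gpt-4 would cost 1.5 * 0.06 = $0.09:

from streamlit_langchain_chat.constants import PRICES

# Hypothetical helper for illustration; not defined anywhere in this commit.
def estimate_cost(model_name: str, total_tokens: int) -> float:
    rate = PRICES.get(model_name, 0.0)  # dollars per 1,000 tokens
    return total_tokens / 1000 * rate

print(estimate_cost("gpt-4", 1500))  # -> 0.09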
streamlit_langchain_chat/customized_langchain/__init__.py
ADDED
@@ -0,0 +1,10 @@
from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore
from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS
from streamlit_langchain_chat.customized_langchain.vectorstores import Pinecone


__all__ = [
    "FAISS",
    "InMemoryDocstore",
    "Pinecone",
]
streamlit_langchain_chat/customized_langchain/docstore/__init__.py
ADDED
@@ -0,0 +1,7 @@
"""Wrappers on top of docstores."""
from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore


__all__ = [
    "InMemoryDocstore",
]
streamlit_langchain_chat/customized_langchain/docstore/in_memory.py
ADDED
@@ -0,0 +1,27 @@
"""Simple in memory docstore in the form of a dict."""
from typing import Dict, Union

from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document


class InMemoryDocstore(Docstore, AddableMixin):
    """Simple in memory docstore in the form of a dict."""

    def __init__(self, dict_: Dict[str, Document]):
        """Initialize with dict."""
        self.dict_ = dict_

    def add(self, texts: Dict[str, Document]) -> None:
        """Add texts to in memory dictionary."""
        overlapping = set(texts).intersection(self.dict_)
        if overlapping:
            raise ValueError(f"Tried to add ids that already exist: {overlapping}")
        self.dict_ = dict(self.dict_, **texts)

    def search(self, search: str) -> Union[str, Document]:
        """Search via direct lookup."""
        if search not in self.dict_:
            return f"ID {search} not found."
        else:
            return self.dict_[search]
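A short usage sketch of the docstore class above (illustrative only, not part of the commit; the document IDs and contents are made up):

from langchain.docstore.document import Document

from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore

store = InMemoryDocstore({"doc-1": Document(page_content="hello world")})
store.add({"doc-2": Document(page_content="second chunk")})
print(store.search("doc-2").page_content)  # -> "second chunk"
print(store.search("missing"))             # -> "ID missing not found."
# store.add({"doc-1": Document(page_content="dup")})  # would raise ValueError: id already exists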
streamlit_langchain_chat/customized_langchain/indexes/__init__.py
ADDED
@@ -0,0 +1,7 @@
from streamlit_langchain_chat.customized_langchain.indexes.graph import GraphIndexCreator
# from streamlit_langchain_chat.customized_langchain.vectorstore import VectorstoreIndexCreator

__all__ = [
    "GraphIndexCreator",
    # "VectorstoreIndexCreator"
]
streamlit_langchain_chat/customized_langchain/indexes/graph.py
ADDED
@@ -0,0 +1,20 @@
from typing import List

from langchain.indexes.graph import *
from langchain.indexes.graph import GraphIndexCreator as OriginalGraphIndexCreator


class GraphIndexCreator(OriginalGraphIndexCreator):
    def from_texts(self, texts: List[str]) -> NetworkxEntityGraph:
        """Create graph index from text."""
        if self.llm is None:
            raise ValueError("llm should not be None")
        graph = self.graph_type()
        chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)

        for text in texts:
            output = chain.predict(text=text)
            knowledge = parse_triples(output)
            for triple in knowledge:
                graph.add_triple(triple)
        return graph
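A usage sketch for the subclass above (illustrative only, not part of this commit; it assumes an OPENAI_API_KEY is configured and the example sentences are arbitrary). The override accepts a list of texts and folds every extracted triple into a single graph:

from langchain.llms import OpenAI

from streamlit_langchain_chat.customized_langchain.indexes import GraphIndexCreator

creator = GraphIndexCreator(llm=OpenAI(temperature=0))
graph = creator.from_texts([
    "Madrid is the capital of Spain.",
    "Spain is a country in Europe.",
])
print(graph.get_triples())  # e.g. [("Madrid", "Spain", "is the capital of"), ...]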
streamlit_langchain_chat/customized_langchain/llms/__init__.py
ADDED
@@ -0,0 +1 @@
from streamlit_langchain_chat.customized_langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat, AzureOpenAIChat
streamlit_langchain_chat/customized_langchain/llms/openai.py
ADDED
@@ -0,0 +1,708 @@
"""Wrapper around OpenAI APIs."""
from __future__ import annotations

import logging
import sys
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
)

from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


def update_token_usage(
    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
    """Update token usage."""
    _keys_to_use = keys.intersection(response["usage"])
    for _key in _keys_to_use:
        if _key not in token_usage:
            token_usage[_key] = response["usage"][_key]
        else:
            token_usage[_key] += response["usage"][_key]


def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
    """Update response from the stream response."""
    response["choices"][0]["text"] += stream_response["choices"][0]["text"]
    response["choices"][0]["finish_reason"] = stream_response["choices"][0][
        "finish_reason"
    ]
    response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"]


def _streaming_response_template() -> Dict[str, Any]:
    return {
        "choices": [
            {
                "text": "",
                "finish_reason": None,
                "logprobs": None,
            }
        ]
    }


def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
    import openai

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(openai.error.Timeout)
            | retry_if_exception_type(openai.error.APIError)
            | retry_if_exception_type(openai.error.APIConnectionError)
            | retry_if_exception_type(openai.error.RateLimitError)
            | retry_if_exception_type(openai.error.ServiceUnavailableError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return llm.client.create(**kwargs)

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # Use OpenAI's async api https://github.com/openai/openai-python#async-api
        return await llm.client.acreate(**kwargs)

    return await _completion_with_retry(**kwargs)


class BaseOpenAI(BaseLLM, BaseModel):
    """Wrapper around OpenAI large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            openai = OpenAI(model_name="text-davinci-003")
    """

    client: Any  #: :meta private:
    model_name: str = "text-davinci-003"
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_tokens: int = 256
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the models maximal context size."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    frequency_penalty: float = 0
    """Penalizes repeated tokens according to frequency."""
    presence_penalty: float = 0
    """Penalizes repeated tokens."""
    n: int = 1
    """How many completions to generate for each prompt."""
    best_of: int = 1
    """Generates best_of completions server-side and returns the "best"."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    openai_api_key: Optional[str] = None
    batch_size: int = 20
    """Batch size to use when passing multiple documents to generate."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
    logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
    """Adjust the probability of specific tokens being generated."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""

    def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]:  # type: ignore
        """Initialize the OpenAI object."""
        if data.get("model_name", "").startswith("gpt-3.5-turbo"):
            return OpenAIChat(**data)
        return super().__new__(cls)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True, allow_reuse=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transfered to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        try:
            import openai

            openai.api_key = openai_api_key
            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please it install it with `pip install openai`."
            )
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values.get("best_of") and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        normal_params = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            # "best_of": self.best_of,
            "request_timeout": self.request_timeout,
            "logit_bias": self.logit_bias,
        }
        return {**normal_params, **self.model_kwargs}

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to OpenAI's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.

        Example:
            .. code-block:: python

                response = openai.generate(["Tell me a joke."])
        """
        # TODO: write a unit test for this
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")
                params["stream"] = True
                response = _streaming_response_template()
                for stream_resp in completion_with_retry(
                    self, prompt=_prompts, **params
                ):
                    self.callback_manager.on_llm_new_token(
                        stream_resp["choices"][0]["text"],
                        verbose=self.verbose,
                        logprobs=stream_resp["choices"][0]["logprobs"],
                    )
                    _update_response(response, stream_resp)
                choices.extend(response["choices"])
            else:
                response = completion_with_retry(self, prompt=_prompts, **params)
                choices.extend(response["choices"])
            if not self.streaming:
                # Can't update token usage if streaming
                update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to OpenAI's endpoint async with k unique prompts."""
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")
                params["stream"] = True
                response = _streaming_response_template()
                async for stream_resp in await acompletion_with_retry(
                    self, prompt=_prompts, **params
                ):
                    if self.callback_manager.is_async:
                        await self.callback_manager.on_llm_new_token(
                            stream_resp["choices"][0]["text"],
                            verbose=self.verbose,
                            logprobs=stream_resp["choices"][0]["logprobs"],
                        )
                    else:
                        self.callback_manager.on_llm_new_token(
                            stream_resp["choices"][0]["text"],
                            verbose=self.verbose,
                            logprobs=stream_resp["choices"][0]["logprobs"],
                        )
                    _update_response(response, stream_resp)
                choices.extend(response["choices"])
            else:
                response = await acompletion_with_retry(self, prompt=_prompts, **params)
                choices.extend(response["choices"])
            if not self.streaming:
                # Can't update token usage if streaming
                update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)

    def get_sub_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
        stop: Optional[List[str]] = None,
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        if params["max_tokens"] == -1:
            if len(prompts) != 1:
                raise ValueError(
                    "max_tokens set to -1 not supported for multiple inputs."
                )
            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
        sub_prompts = [
            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
        return sub_prompts

    def create_llm_result(
        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []
        for i, _ in enumerate(prompts):
            sub_choices = choices[i * self.n : (i + 1) * self.n]
            generations.append(
                [
                    Generation(
                        text=choice["text"],
                        generation_info=dict(
                            finish_reason=choice.get("finish_reason"),
                            logprobs=choice.get("logprobs"),
                        ),
                    )
                    for choice in sub_choices
                ]
            )
        return LLMResult(
            generations=generations, llm_output={"token_usage": token_usage}
        )

    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
        """Call OpenAI with streaming flag and return the resulting generator.

        BETA: this is a beta feature while we figure out the right abstraction.
        Once that happens, this interface could change.

        Args:
            prompt: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from OpenAI.

        Example:
            .. code-block:: python

                generator = openai.stream("Tell me a joke.")
                for token in generator:
                    yield token
        """
        params = self.prep_streaming_params(stop)
        generator = self.client.create(prompt=prompt, **params)

        return generator

    def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """Prepare the params for streaming."""
        params = self._invocation_params
        if params.get('best_of') and params["best_of"] != 1:
            raise ValueError("OpenAI only supports best_of == 1 for streaming")
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        params["stream"] = True
        return params

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return self._default_params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "openai"

    def get_num_tokens(self, text: str) -> int:
        """Calculate num tokens with tiktoken package."""
        # tiktoken NOT supported for Python 3.8 or below
        if sys.version_info[1] <= 8:
            return super().get_num_tokens(text)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please it install it with `pip install tiktoken`."
            )
        encoder = "gpt2"
        if self.model_name in ("text-davinci-003", "text-davinci-002"):
            encoder = "p50k_base"
        if self.model_name.startswith("code"):
            encoder = "p50k_base"
        # create a GPT-3 encoder instance
        enc = tiktoken.get_encoding(encoder)

        # encode the text using the GPT-3 encoder
        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def modelname_to_contextsize(self, modelname: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a model.

        text-davinci-003: 4,097 tokens
        text-curie-001: 2,048 tokens
        text-babbage-001: 2,048 tokens
        text-ada-001: 2,048 tokens
        code-davinci-002: 8,000 tokens
        code-cushman-001: 2,048 tokens

        Args:
            modelname: The modelname we want to know the context size for.

        Returns:
            The maximum context size

        Example:
            .. code-block:: python

                max_tokens = openai.modelname_to_contextsize("text-davinci-003")
        """
        if modelname == "text-davinci-003":
            return 4097
        elif modelname == "text-curie-001":
            return 2048
        elif modelname == "text-babbage-001":
            return 2048
        elif modelname == "text-ada-001":
            return 2048
        elif modelname == "code-davinci-002":
            return 8000
        elif modelname == "code-cushman-001":
            return 2048
        else:
            return 4097

    def max_tokens_for_prompt(self, prompt: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The maximum number of tokens to generate for a prompt.

        Example:
            .. code-block:: python

                max_tokens = openai.max_token_for_prompt("Tell me a joke.")
        """
        num_tokens = self.get_num_tokens(prompt)

        # get max context size for model by name
        max_size = self.modelname_to_contextsize(self.model_name)
        return max_size - num_tokens


class OpenAI(BaseOpenAI):
    """Generic OpenAI class that uses model name."""

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**{"model": self.model_name}, **super()._invocation_params}


class AzureOpenAI(BaseOpenAI):
    """Azure specific OpenAI class that uses deployment name."""

    deployment_name: str = ""
    """Deployment name to use."""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {
            **{"deployment_name": self.deployment_name},
            **super()._identifying_params,
        }

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**{"engine": self.deployment_name}, **super()._invocation_params}


class OpenAIChat(BaseLLM, BaseModel):
    """Wrapper around OpenAI Chat large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAIChat
            openaichat = OpenAIChat(model_name="gpt-3.5-turbo")
    """

    client: Any  #: :meta private:
    model_name: str = "gpt-3.5-turbo"
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    openai_api_key: Optional[str] = None
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    prefix_messages: List = Field(default_factory=list)
    """Series of messages for Chat input."""
    streaming: bool = False
    """Whether to stream the results or not."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True, allow_reuse=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        try:
            import openai

            openai.api_key = openai_api_key
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please it install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return self.model_kwargs

    def _get_chat_params(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> Tuple:
        if len(prompts) > 1:
            raise ValueError(
                f"OpenAIChat currently only supports single prompt, got {prompts}"
            )
        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        return messages, params

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        if self.streaming:
            response = ""
            params["stream"] = True
            for stream_resp in completion_with_retry(self, messages=messages, **params):
                token = stream_resp["choices"][0]["delta"].get("content", "")
                response += token
                self.callback_manager.on_llm_new_token(
                    token,
                    verbose=self.verbose,
                )
            return LLMResult(
                generations=[[Generation(text=response)]],
            )
        else:
            full_response = completion_with_retry(self, messages=messages, **params)
            return LLMResult(
                generations=[
                    [Generation(text=full_response["choices"][0]["message"]["content"])]
                ],
                llm_output={"token_usage": full_response["usage"]},
            )

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        if self.streaming:
            response = ""
            params["stream"] = True
            async for stream_resp in await acompletion_with_retry(
                self, messages=messages, **params
            ):
                token = stream_resp["choices"][0]["delta"].get("content", "")
                response += token
                if self.callback_manager.is_async:
                    await self.callback_manager.on_llm_new_token(
                        token,
                        verbose=self.verbose,
                    )
                else:
                    self.callback_manager.on_llm_new_token(
                        token,
                        verbose=self.verbose,
                    )
            return LLMResult(
                generations=[[Generation(text=response)]],
            )
        else:
            full_response = await acompletion_with_retry(
                self, messages=messages, **params
            )
            return LLMResult(
                generations=[
                    [Generation(text=full_response["choices"][0]["message"]["content"])]
                ],
                llm_output={"token_usage": full_response["usage"]},
            )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "openai-chat"


class AzureOpenAIChat(OpenAIChat):
    """Azure specific OpenAI class that uses deployment name."""

    deployment_name: str = ""
    """Deployment name to use."""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {
            **{"deployment_name": self.deployment_name},
            **super()._identifying_params,
        }
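A short instantiation sketch for the wrappers above (illustrative only, not part of this commit; the API key and deployment name are placeholders, not values from the repository):

from streamlit_langchain_chat.customized_langchain.llms import AzureOpenAI, OpenAI

# BaseOpenAI.__new__ reroutes gpt-3.5-turbo model names to the chat wrapper.
chat_llm = OpenAI(model_name="gpt-3.5-turbo", openai_api_key="<placeholder>")
print(chat_llm._llm_type)  # -> "openai-chat"

# The Azure completion wrapper sends its deployment name as the `engine` parameter.
azure_llm = AzureOpenAI(deployment_name="<deployment>", openai_api_key="<placeholder>")
print(azure_llm._invocation_params["engine"])  # -> "<deployment>"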
streamlit_langchain_chat/customized_langchain/vectorstores/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""Wrappers on top of vector stores."""
from streamlit_langchain_chat.customized_langchain.vectorstores.faiss import FAISS
from streamlit_langchain_chat.customized_langchain.vectorstores.pinecone import Pinecone

__all__ = [
    "FAISS",
    "Pinecone",
]
streamlit_langchain_chat/customized_langchain/vectorstores/faiss.py
ADDED
@@ -0,0 +1,100 @@
# import hashlib

from langchain.vectorstores.faiss import *
from langchain.vectorstores.faiss import FAISS as OriginalFAISS

from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore


class FAISS(OriginalFAISS):
    def __add(
        self,
        texts: Iterable[str],
        embeddings: Iterable[List[float]],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        if not isinstance(self.docstore, AddableMixin):
            raise ValueError(
                "If trying to add texts, the underlying docstore should support "
                f"adding items, which {self.docstore} does not"
            )
        documents = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            documents.append(Document(page_content=text, metadata=metadata))
        # Add to the index, the index_to_id mapping, and the docstore.
        starting_len = len(self.index_to_docstore_id)
        self.index.add(np.array(embeddings, dtype=np.float32))
        # Get list of index, id, and docs.
        full_info = [
            (starting_len + i, str(uuid.uuid4()), doc)
            for i, doc in enumerate(documents)
        ]
        # Add information to docstore and index.
        self.docstore.add({_id: doc for _, _id, doc in full_info})
        index_to_id = {index: _id for index, _id, _ in full_info}
        self.index_to_docstore_id.update(index_to_id)
        return [_id for _, _id, _ in full_info]

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> FAISS:
        faiss = dependable_faiss_import()
        index = faiss.IndexFlatL2(len(embeddings[0]))
        index.add(np.array(embeddings, dtype=np.float32))
        documents = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            documents.append(Document(page_content=text, metadata=metadata))
        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}

        # # TODO: switch to using the hash. And see where it would go so that the chunk is not loaded into the dataset
        # index_to_id_2 = dict()
        # for i in range(len(documents)):
        #     h = hashlib.new('sha256')
        #     text_ = documents[i].page_content
        #     h.update(text_.encode())
        #     index_to_id_2[i] = str(h.hexdigest())
        # #
        docstore = InMemoryDocstore(
            {index_to_id[i]: doc for i, doc in enumerate(documents)}
        )
        return cls(embedding.embed_query, index, docstore, index_to_id)

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> FAISS:
        """Construct FAISS wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the FAISS database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import FAISS
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                faiss = FAISS.from_texts(texts, embeddings)
        """
        # embeddings = embedding.embed_documents(texts)
        print(f"len(texts): {len(texts)}")  # TODO: remove
        embeddings = [embedding.embed_documents([text])[0] for text in texts]
        print(f"len(embeddings): {len(embeddings)}")  # TODO: remove
        return cls.__from(texts, embeddings, embedding, metadatas, **kwargs)
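An end-to-end sketch of the customized FAISS store above (illustrative only, not part of this commit). A tiny fake embedder is used so the snippet runs without an API key; in the application OpenAIEmbeddings would be passed instead:

from typing import List

from langchain.embeddings.base import Embeddings

from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS


class FakeEmbeddings(Embeddings):
    """Deterministic stand-in embedder, for illustration only."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t)), 1.0] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), 1.0]


store = FAISS.from_texts(["first chunk", "second chunk text"], FakeEmbeddings())
print(store.similarity_search("first chunk", k=1)[0].page_content)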
streamlit_langchain_chat/customized_langchain/vectorstores/pinecone.py
ADDED
@@ -0,0 +1,79 @@
from langchain.vectorstores.pinecone import *
from langchain.vectorstores.pinecone import Pinecone as OriginalPinecone


class Pinecone(OriginalPinecone):
    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        text_key: str = "text",
        index_name: Optional[str] = None,
        namespace: Optional[str] = None,
        **kwargs: Any,
    ) -> Pinecone:
        """Construct Pinecone wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Adds the documents to a provided Pinecone index

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import Pinecone
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                pinecone = Pinecone.from_texts(
                    texts,
                    embeddings,
                    index_name="langchain-demo"
                )
        """
        try:
            import pinecone
        except ImportError:
            raise ValueError(
                "Could not import pinecone python package. "
                "Please install it with `pip install pinecone-client`."
            )
        _index_name = index_name or str(uuid.uuid4())
        indexes = pinecone.list_indexes()  # checks if provided index exists
        if _index_name in indexes:
            index = pinecone.Index(_index_name)
        else:
            index = None
        for i in range(0, len(texts), batch_size):
            # set end position of batch
            i_end = min(i + batch_size, len(texts))
            # get batch of texts and ids
            lines_batch = texts[i:i_end]
            # create ids if not provided
            if ids:
                ids_batch = ids[i:i_end]
            else:
                ids_batch = [str(uuid.uuid4()) for n in range(i, i_end)]
            # create embeddings
            # embeds = embedding.embed_documents(lines_batch)
            embeds = [embedding.embed_documents([line_batch])[0] for line_batch in lines_batch]
            # prep metadata and upsert batch
            if metadatas:
                metadata = metadatas[i:i_end]
            else:
                metadata = [{} for _ in range(i, i_end)]
            for j, line in enumerate(lines_batch):
                metadata[j][text_key] = line
            to_upsert = zip(ids_batch, embeds, metadata)
            # Create index if it does not exist
            if index is None:
                pinecone.create_index(_index_name, dimension=len(embeds[0]))
                index = pinecone.Index(_index_name)
            # upsert to Pinecone
            index.upsert(vectors=list(to_upsert), namespace=namespace)
        return cls(index, embedding.embed_query, text_key, namespace)
streamlit_langchain_chat/dataset.py
ADDED
@@ -0,0 +1,740 @@
1 |
+
import time
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from datetime import datetime
|
4 |
+
from functools import reduce
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import re
|
9 |
+
import requests
|
10 |
+
from requests.models import MissingSchema
|
11 |
+
import sys
|
12 |
+
from typing import List, Optional, Tuple, Dict, Callable, Any
|
13 |
+
|
14 |
+
from bs4 import BeautifulSoup
|
15 |
+
import docx
|
16 |
+
from html2text import html2text
|
17 |
+
import langchain
|
18 |
+
from langchain.callbacks import get_openai_callback
|
19 |
+
from langchain.cache import SQLiteCache
|
20 |
+
from langchain.chains import LLMChain
|
21 |
+
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
|
22 |
+
from langchain.chat_models import ChatOpenAI
|
23 |
+
from langchain.chat_models.base import BaseChatModel
|
24 |
+
from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
|
25 |
+
from langchain.embeddings.base import Embeddings
|
26 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
27 |
+
from langchain.llms import OpenAI
|
28 |
+
from langchain.llms.base import LLM, BaseLLM
|
29 |
+
from langchain.prompts.chat import AIMessagePromptTemplate
|
30 |
+
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
|
31 |
+
from langchain.vectorstores import Pinecone as OriginalPinecone
|
32 |
+
import numpy as np
|
33 |
+
import openai
|
34 |
+
import pinecone
|
35 |
+
from pptx import Presentation
|
36 |
+
from pypdf import PdfReader
|
37 |
+
import trafilatura
|
38 |
+
|
39 |
+
from streamlit_langchain_chat.constants import *
|
40 |
+
from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS
|
41 |
+
from streamlit_langchain_chat.customized_langchain.vectorstores import Pinecone
|
42 |
+
from streamlit_langchain_chat.utils import maybe_is_text, maybe_is_truncated
|
43 |
+
from streamlit_langchain_chat.prompts import *
|
44 |
+
|
45 |
+
|
46 |
+
if REUSE_ANSWERS:
|
47 |
+
CACHE_PATH = TEMP_DIR / "llm_cache.db"
|
48 |
+
os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
|
49 |
+
langchain.llm_cache = SQLiteCache(str(CACHE_PATH))
|
50 |
+
|
51 |
+
# option 1
|
52 |
+
TextSplitter = TokenTextSplitter
|
53 |
+
# option 2
|
54 |
+
# TextSplitter = RecursiveCharacterTextSplitter # usado por gpt4_pdf_chatbot_langchain (aka GPCL)
|
55 |
+
|
56 |
+
|
57 |
+
@dataclass
|
58 |
+
class Answer:
|
59 |
+
"""A class to hold the answer to a question."""
|
60 |
+
question: str = ""
|
61 |
+
answer: str = ""
|
62 |
+
context: str = ""
|
63 |
+
chunks: str = ""
|
64 |
+
packages: List[Any] = None
|
65 |
+
references: str = ""
|
66 |
+
cost_str: str = ""
|
67 |
+
passages: Dict[str, str] = None
|
68 |
+
tokens: List[Dict] = None
|
69 |
+
|
70 |
+
def __post_init__(self):
|
71 |
+
"""Initialize the answer."""
|
72 |
+
if self.packages is None:
|
73 |
+
self.packages = []
|
74 |
+
if self.passages is None:
|
75 |
+
self.passages = {}
|
76 |
+
|
77 |
+
def __str__(self) -> str:
|
78 |
+
"""Return the answer as a string."""
|
79 |
+
return self.answer
|
80 |
+
|
81 |
+
|
82 |
+
def parse_docx(path, citation, key, chunk_chars=2000, overlap=50):
|
83 |
+
try:
|
84 |
+
document = docx.Document(path)
|
85 |
+
fullText = []
|
86 |
+
for paragraph in document.paragraphs:
|
87 |
+
fullText.append(paragraph.text)
|
88 |
+
doc = '\n'.join(fullText) + '\n'
|
89 |
+
except Exception as e:
|
90 |
+
print(f"code_error: {e}")
|
91 |
+
sys.exit(1)
|
92 |
+
|
93 |
+
if doc:
|
94 |
+
text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
95 |
+
texts = text_splitter.split_text(doc)
|
96 |
+
return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
|
97 |
+
else:
|
98 |
+
return [], []
|
99 |
+
|
100 |
+
|
101 |
+
# TODO: si pones un conector con el formato loader = ... ; data = loader.load();
|
102 |
+
# podrás poner todos los conectores de langchain
|
103 |
+
# https://langchain.readthedocs.io/en/stable/modules/document_loaders/examples/pdf.html
|
104 |
+
def parse_pdf(path, citation, key, chunk_chars=2000, overlap=50):
|
105 |
+
pdfFileObj = open(path, "rb")
|
106 |
+
pdfReader = PdfReader(pdfFileObj)
|
107 |
+
splits = []
|
108 |
+
split = ""
|
109 |
+
pages = []
|
110 |
+
metadatas = []
|
111 |
+
for i, page in enumerate(pdfReader.pages):
|
112 |
+
split += page.extract_text()
|
113 |
+
pages.append(str(i + 1))
|
114 |
+
# split could be so long it needs to be split
|
115 |
+
# into multiple chunks. Or it could be so short
|
116 |
+
# that it needs to be combined with the next chunk.
|
117 |
+
while len(split) > chunk_chars:
|
118 |
+
splits.append(split[:chunk_chars])
|
119 |
+
# pretty formatting of pages (e.g. 1-3, 4, 5-7)
|
120 |
+
pg = "-".join([pages[0], pages[-1]])
|
121 |
+
metadatas.append(
|
122 |
+
dict(
|
123 |
+
citation=citation,
|
124 |
+
dockey=key,
|
125 |
+
key=f"{key} pages {pg}",
|
126 |
+
)
|
127 |
+
)
|
128 |
+
split = split[chunk_chars - overlap:]
|
129 |
+
pages = [str(i + 1)]
|
130 |
+
if len(split) > overlap:
|
131 |
+
splits.append(split[:chunk_chars])
|
132 |
+
pg = "-".join([pages[0], pages[-1]])
|
133 |
+
metadatas.append(
|
134 |
+
dict(
|
135 |
+
citation=citation,
|
136 |
+
dockey=key,
|
137 |
+
key=f"{key} pages {pg}",
|
138 |
+
)
|
139 |
+
)
|
140 |
+
pdfFileObj.close()
|
141 |
+
|
142 |
+
# # ### option 2. PyPDFLoader
|
143 |
+
# loader = PyPDFLoader(path)
|
144 |
+
# data = loader.load_and_split()
|
145 |
+
# # ### option 2.1. PyPDFLoader as used by GPCL, which then applies the text splitter below
|
146 |
+
# loader = PyPDFLoader(path)
|
147 |
+
# rawDocs = loader.load()
|
148 |
+
# text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
149 |
+
# texts = text_splitter.split_documents(rawDocs)
|
150 |
+
# # ### option 3. PyMuPDFLoader. This seems like the best option
|
151 |
+
# loader = PyMuPDFLoader(path)
|
152 |
+
# data = loader.load()
|
153 |
+
return splits, metadatas
|
154 |
+
|
155 |
+
|
156 |
+
def parse_pptx(path, citation, key, chunk_chars=2000, overlap=50):
|
157 |
+
try:
|
158 |
+
presentation = Presentation(path)
|
159 |
+
fullText = []
|
160 |
+
for slide in presentation.slides:
|
161 |
+
for shape in slide.shapes:
|
162 |
+
if hasattr(shape, "text"):
|
163 |
+
fullText.append(shape.text)
|
164 |
+
doc = ''.join(fullText)
|
165 |
+
|
166 |
+
if doc:
|
167 |
+
text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
168 |
+
texts = text_splitter.split_text(doc)
|
169 |
+
return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
|
170 |
+
else:
|
171 |
+
return [], []
|
172 |
+
|
173 |
+
except Exception as e:
|
174 |
+
print(f"code_error: {e}")
|
175 |
+
sys.exit(1)
|
176 |
+
|
177 |
+
|
178 |
+
def parse_txt(path, citation, key, chunk_chars=2000, overlap=50, html=False):
|
179 |
+
try:
|
180 |
+
with open(path) as f:
|
181 |
+
doc = f.read()
|
182 |
+
except UnicodeDecodeError as e:
|
183 |
+
with open(path, encoding="utf-8", errors="ignore") as f:
|
184 |
+
doc = f.read()
|
185 |
+
if html:
|
186 |
+
doc = html2text(doc)
|
187 |
+
# no idea why, but the texts are not split correctly
|
188 |
+
text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
189 |
+
texts = text_splitter.split_text(doc)
|
190 |
+
return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
|
191 |
+
|
192 |
+
|
193 |
+
def parse_url(url: str, citation, key, chunk_chars=2000, overlap=50):
|
194 |
+
def beautifulsoup_extract_text_fallback(response_content):
|
195 |
+
"""
|
196 |
+
This is a fallback function, so that we can always return a value for text content.
|
197 |
+
Even for when both Trafilatura and BeautifulSoup are unable to extract the text from a
|
198 |
+
single URL.
|
199 |
+
"""
|
200 |
+
|
201 |
+
# Create the beautifulsoup object:
|
202 |
+
soup = BeautifulSoup(response_content, 'html.parser')
|
203 |
+
|
204 |
+
# Finding the text:
|
205 |
+
text = soup.find_all(text=True)
|
206 |
+
|
207 |
+
# Remove unwanted tag elements:
|
208 |
+
cleaned_text = ''
|
209 |
+
blacklist = [
|
210 |
+
'[document]',
|
211 |
+
'noscript',
|
212 |
+
'header',
|
213 |
+
'html',
|
214 |
+
'meta',
|
215 |
+
'head',
|
216 |
+
'input',
|
217 |
+
'script',
|
218 |
+
'style', ]
|
219 |
+
|
220 |
+
# Then we will loop over every item in the extract text and make sure that the beautifulsoup4 tag
|
221 |
+
# is NOT in the blacklist
|
222 |
+
for item in text:
|
223 |
+
if item.parent.name not in blacklist:
|
224 |
+
cleaned_text += f'{item} ' # cleaned_text += '{} '.format(item)
|
225 |
+
|
226 |
+
# Remove any tab separation and strip the text:
|
227 |
+
cleaned_text = cleaned_text.replace('\t', '')
|
228 |
+
return cleaned_text.strip()
|
229 |
+
|
230 |
+
def extract_text_from_single_web_page(url):
|
231 |
+
print(f"\n===========\n{url=}\n===========\n")
|
232 |
+
downloaded_url = trafilatura.fetch_url(url)
|
233 |
+
a = None
|
234 |
+
try:
|
235 |
+
a = trafilatura.extract(downloaded_url,
|
236 |
+
output_format='json',
|
237 |
+
with_metadata=True,
|
238 |
+
include_comments=False,
|
239 |
+
date_extraction_params={'extensive_search': True,
|
240 |
+
'original_date': True})
|
241 |
+
except AttributeError:
|
242 |
+
a = trafilatura.extract(downloaded_url,
|
243 |
+
output_format='json',
|
244 |
+
with_metadata=True,
|
245 |
+
date_extraction_params={'extensive_search': True,
|
246 |
+
'original_date': True})
|
247 |
+
except Exception as e:
|
248 |
+
print(f"code_error: {e}")
|
249 |
+
|
250 |
+
if a:
|
251 |
+
json_output = json.loads(a)
|
252 |
+
return json_output['text']
|
253 |
+
else:
|
254 |
+
try:
|
255 |
+
headers = {'User-Agent': 'Chrome/83.0.4103.106'}
|
256 |
+
resp = requests.get(url, headers=headers)
|
257 |
+
print(f"{resp=}\n")
|
258 |
+
# We will only extract the text from successful requests:
|
259 |
+
if resp.status_code == 200:
|
260 |
+
return beautifulsoup_extract_text_fallback(resp.content)
|
261 |
+
else:
|
262 |
+
# This line will handle for any failures in both the Trafilature and BeautifulSoup4 functions:
|
263 |
+
return np.nan
|
264 |
+
# Handling for any URLs that don't have the correct protocol
|
265 |
+
except MissingSchema:
|
266 |
+
return np.nan
|
267 |
+
|
268 |
+
text_to_split = extract_text_from_single_web_page(url)
|
269 |
+
text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
270 |
+
texts = text_splitter.split_text(text_to_split)
|
271 |
+
return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
|
272 |
+
|
273 |
+
|
274 |
+
def read_source(path: str = None,
|
275 |
+
citation: str = None,
|
276 |
+
key: str = None,
|
277 |
+
chunk_chars: int = 3000,
|
278 |
+
overlap: int = 100,
|
279 |
+
disable_check: bool = False):
|
280 |
+
if path.endswith(".pdf"):
|
281 |
+
return parse_pdf(path, citation, key, chunk_chars, overlap)
|
282 |
+
elif path.endswith(".txt"):
|
283 |
+
return parse_txt(path, citation, key, chunk_chars, overlap)
|
284 |
+
elif path.endswith(".html"):
|
285 |
+
return parse_txt(path, citation, key, chunk_chars, overlap, html=True)
|
286 |
+
elif path.endswith(".docx"):
|
287 |
+
return parse_docx(path, citation, key, chunk_chars, overlap)
|
288 |
+
elif path.endswith(".pptx"):
|
289 |
+
return parse_pptx(path, citation, key, chunk_chars, overlap)
|
290 |
+
elif path.startswith("http://") or path.startswith("https://"):
|
291 |
+
return parse_url(path, citation, key, chunk_chars, overlap)
|
292 |
+
# TODO: add more connectors
|
293 |
+
# else:
|
294 |
+
# return parse_code_txt(path, citation, key, chunk_chars, overlap)
|
295 |
+
else:
|
296 |
+
raise ValueError("unknown extension")
|
297 |
+
|
298 |
+
|
299 |
+
class Dataset:
|
300 |
+
"""A collection of documents to be used for answering questions."""
|
301 |
+
def __init__(
|
302 |
+
self,
|
303 |
+
chunk_size_limit: int = 3000,
|
304 |
+
llm: Optional[BaseLLM] | Optional[BaseChatModel] = None,
|
305 |
+
summary_llm: Optional[BaseLLM] = None,
|
306 |
+
name: str = "default",
|
307 |
+
index_path: Optional[Path] = None,
|
308 |
+
) -> None:
|
309 |
+
"""Initialize the collection of documents.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
chunk_size_limit: The maximum number of characters to use for a single chunk of text.
|
313 |
+
llm: The language model to use for answering questions. Default: ChatOpenAI (OpenAI's gpt-3.5-turbo).
|
314 |
+
summary_llm: The language model to use for summarizing documents. If None, llm is used.
|
315 |
+
name: The name of the collection.
|
316 |
+
index_path: The path where the index is persisted. If None, defaults to TEMP_DIR / name.
|
317 |
+
"""
|
318 |
+
self.docs = dict()
|
319 |
+
self.keys = set()
|
320 |
+
self.chunk_size_limit = chunk_size_limit
|
321 |
+
|
322 |
+
self.index_docstore = None
|
323 |
+
|
324 |
+
if llm is None:
|
325 |
+
llm = ChatOpenAI(temperature=0.1, max_tokens=512)
|
326 |
+
if summary_llm is None:
|
327 |
+
summary_llm = llm
|
328 |
+
self.update_llm(llm, summary_llm)
|
329 |
+
|
330 |
+
if index_path is None:
|
331 |
+
index_path = TEMP_DIR / name
|
332 |
+
self.index_path = index_path
|
333 |
+
self.name = name
|
334 |
+
|
335 |
+
def update_llm(self, llm: BaseLLM | ChatOpenAI, summary_llm: Optional[BaseLLM] = None) -> None:
|
336 |
+
"""Update the LLM for answering questions."""
|
337 |
+
self.llm = llm
|
338 |
+
if summary_llm is None:
|
339 |
+
summary_llm = llm
|
340 |
+
self.summary_llm = summary_llm
|
341 |
+
self.summary_chain = LLMChain(prompt=chat_summary_prompt, llm=summary_llm)
|
342 |
+
self.search_chain = LLMChain(prompt=search_prompt, llm=llm)
|
343 |
+
self.cite_chain = LLMChain(prompt=citation_prompt, llm=llm)
|
344 |
+
|
345 |
+
def add(
|
346 |
+
self,
|
347 |
+
path: str,
|
348 |
+
citation: Optional[str] = None,
|
349 |
+
key: Optional[str] = None,
|
350 |
+
disable_check: bool = False,
|
351 |
+
chunk_chars: Optional[int] = 3000,
|
352 |
+
) -> None:
|
353 |
+
"""Add a document to the collection."""
|
354 |
+
|
355 |
+
if path in self.docs:
|
356 |
+
print(f"Document {path} already in collection.")
|
357 |
+
return None
|
358 |
+
|
359 |
+
if citation is None:
|
360 |
+
# peek at the first chunk
|
361 |
+
texts, _ = read_source(path, "", "", chunk_chars=chunk_chars)
|
362 |
+
with get_openai_callback() as cb:
|
363 |
+
citation = self.cite_chain.run(texts[0])
|
364 |
+
if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
|
365 |
+
citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
|
366 |
+
|
367 |
+
if key is None:
|
368 |
+
# get first name and year from citation
|
369 |
+
try:
|
370 |
+
author = re.search(r"([A-Z][a-z]+)", citation).group(1)
|
371 |
+
except AttributeError:
|
372 |
+
# panicking - no word??
|
373 |
+
raise ValueError(
|
374 |
+
f"Could not parse key from citation {citation}. Consider passing the key explicitly - e.g. docs.add(path, citation, key='mykey')"
|
375 |
+
)
|
376 |
+
try:
|
377 |
+
year = re.search(r"(\d{4})", citation).group(1)
|
378 |
+
except AttributeError:
|
379 |
+
year = ""
|
380 |
+
key = f"{author}{year}"
|
381 |
+
suffix = ""
|
382 |
+
while key + suffix in self.keys:
|
383 |
+
# move suffix to next letter
|
384 |
+
if suffix == "":
|
385 |
+
suffix = "a"
|
386 |
+
else:
|
387 |
+
suffix = chr(ord(suffix) + 1)
|
388 |
+
key += suffix
|
389 |
+
self.keys.add(key)
|
390 |
+
|
391 |
+
texts, metadata = read_source(path, citation, key, chunk_chars=chunk_chars)
|
392 |
+
# loose check to see if document was loaded
|
393 |
+
#
|
394 |
+
if len("".join(texts)) < 10 or (
|
395 |
+
not disable_check and not maybe_is_text("".join(texts))
|
396 |
+
):
|
397 |
+
raise ValueError(
|
398 |
+
f"This does not look like a text document: {path}. Pass disable_check=True to ignore this error."
|
399 |
+
)
|
400 |
+
|
401 |
+
self.docs[path] = dict(texts=texts, metadata=metadata, key=key)
|
402 |
+
if self.index_docstore is not None:
|
403 |
+
self.index_docstore.add_texts(texts, metadatas=metadata)
|
404 |
+
|
405 |
+
def clear(self) -> None:
|
406 |
+
"""Clear the collection of documents."""
|
407 |
+
self.docs = dict()
|
408 |
+
self.keys = set()
|
409 |
+
self.index_docstore = None
|
410 |
+
# delete index file
|
411 |
+
pkl = self.index_path / "index.pkl"
|
412 |
+
if pkl.exists():
|
413 |
+
pkl.unlink()
|
414 |
+
fs = self.index_path / "index.faiss"
|
415 |
+
if fs.exists():
|
416 |
+
fs.unlink()
|
417 |
+
|
418 |
+
@property
|
419 |
+
def doc_previews(self) -> List[Tuple[int, str, str]]:
|
420 |
+
"""Return a list of tuples of (key, citation) for each document."""
|
421 |
+
return [
|
422 |
+
(
|
423 |
+
len(doc["texts"]),
|
424 |
+
doc["metadata"][0]["dockey"],
|
425 |
+
doc["metadata"][0]["citation"],
|
426 |
+
)
|
427 |
+
for doc in self.docs.values()
|
428 |
+
]
|
429 |
+
|
430 |
+
# to pickle, we have to save the index as a file
|
431 |
+
def __getstate__(self, embedding: Embeddings):
|
432 |
+
if embedding is None:
|
433 |
+
embedding = OpenAIEmbeddings()
|
434 |
+
if self.index_docstore is None and len(self.docs) > 0:
|
435 |
+
self._build_faiss_index(embedding)
|
436 |
+
state = self.__dict__.copy()
|
437 |
+
if self.index_docstore is not None:
|
438 |
+
state["index_docstore"].save_local(self.index_path)
|
439 |
+
del state["index_docstore"]
|
440 |
+
# remove LLMs (they can have callbacks, which can't be pickled)
|
441 |
+
del state["summary_chain"]
|
442 |
+
del state["qa_chain"]
|
443 |
+
del state["cite_chain"]
|
444 |
+
del state["search_chain"]
|
445 |
+
return state
|
446 |
+
|
447 |
+
def __setstate__(self, state):
|
448 |
+
self.__dict__.update(state)
|
449 |
+
try:
|
450 |
+
self.index_docstore = FAISS.load_local(self.index_path, OpenAIEmbeddings())
|
451 |
+
except Exception:
|
452 |
+
# they use some special exception type, but I don't want to import it
|
453 |
+
self.index_docstore = None
|
454 |
+
self.update_llm(
|
455 |
+
ChatOpenAI(temperature=0.1, max_tokens=512)
|
456 |
+
)
|
457 |
+
|
458 |
+
def _build_faiss_index(self, embedding: Embeddings = None):
|
459 |
+
if embedding is None:
|
460 |
+
embedding = OpenAIEmbeddings()
|
461 |
+
if self.index_docstore is None:
|
462 |
+
texts = reduce(
|
463 |
+
lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
|
464 |
+
)
|
465 |
+
metadatas = reduce(
|
466 |
+
lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
|
467 |
+
)
|
468 |
+
|
469 |
+
# if the index exists, load it
|
470 |
+
if LOAD_INDEX_LOCALLY and (self.index_path / "index.faiss").exists():
|
471 |
+
self.index_docstore = FAISS.load_local(self.index_path, embedding)
|
472 |
+
|
473 |
+
# check whether the text and metadata already exist in the index
|
474 |
+
for i in reversed(range(len(texts))):
|
475 |
+
text = texts[i]
|
476 |
+
metadata = metadatas[i]
|
477 |
+
for key, value in self.index_docstore.docstore.dict_.items():
|
478 |
+
if value.page_content == text:
|
479 |
+
if value.metadata.get('citation').split(os.sep)[-1] != metadata.get('citation').split(os.sep)[-1]:
|
480 |
+
self.index_docstore.docstore.dict_[key].metadata['citation'] = metadata.get('citation').split(os.sep)[-1]
|
481 |
+
self.index_docstore.docstore.dict_[key].metadata['dockey'] = metadata.get('citation').split(os.sep)[-1]
|
482 |
+
self.index_docstore.docstore.dict_[key].metadata['key'] = metadata.get('citation').split(os.sep)[-1]
|
483 |
+
texts.pop(i)
|
484 |
+
metadatas.pop(i)
|
485 |
+
|
486 |
+
# add remaining texts
|
487 |
+
if texts:
|
488 |
+
self.index_docstore.add_texts(texts=texts, metadatas=metadatas)
|
489 |
+
else:
|
490 |
+
# create a new index
|
491 |
+
self.index_docstore = FAISS.from_texts(texts, embedding, metadatas=metadatas)
|
492 |
+
#
|
493 |
+
|
494 |
+
if SAVE_INDEX_LOCALLY:
|
495 |
+
# save index.
|
496 |
+
self.index_docstore.save_local(self.index_path)
|
497 |
+
|
498 |
+
def _build_pinecone_index(self, embedding: Embeddings = None):
|
499 |
+
if embedding is None:
|
500 |
+
embedding = OpenAIEmbeddings()
|
501 |
+
if self.index_docstore is None:
|
502 |
+
pinecone.init(
|
503 |
+
api_key=os.environ['PINECONE_API_KEY'], # find at app.pinecone.io
|
504 |
+
environment=os.environ['PINECONE_ENVIRONMENT'] # next to api key in console
|
505 |
+
)
|
506 |
+
texts = reduce(
|
507 |
+
lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
|
508 |
+
)
|
509 |
+
metadatas = reduce(
|
510 |
+
lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
|
511 |
+
)
|
512 |
+
|
513 |
+
# TODO: when the index already exists, update it instead of deleting it
|
514 |
+
# index_name = "langchain-demo1"
|
515 |
+
# if index_name in pinecone.list_indexes():
|
516 |
+
# self.index_docstore = pinecone.Index(index_name)
|
517 |
+
# vectors = []
|
518 |
+
# for text, metadata in zip(texts, metadatas):
|
519 |
+
# # embed = <we would need to know which embedding was used to build the pre-existing index>
|
520 |
+
# self.index_docstore.upsert(vectors=vectors)
|
521 |
+
# else:
|
522 |
+
# if openai.api_type == 'azure':
|
523 |
+
# self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
|
524 |
+
# else:
|
525 |
+
# self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
|
526 |
+
|
527 |
+
index_name = "langchain-demo1"
|
528 |
+
|
529 |
+
# if the index exists, delete it
|
530 |
+
if index_name in pinecone.list_indexes():
|
531 |
+
pinecone.delete_index(index_name)
|
532 |
+
|
533 |
+
# create new index
|
534 |
+
if openai.api_type == 'azure':
|
535 |
+
self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
|
536 |
+
else:
|
537 |
+
self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
|
538 |
+
|
539 |
+
def get_evidence(
|
540 |
+
self,
|
541 |
+
answer: Answer,
|
542 |
+
embedding: Embeddings,
|
543 |
+
k: int = 3,
|
544 |
+
max_sources: int = 5,
|
545 |
+
marginal_relevance: bool = True,
|
546 |
+
) -> str:
|
547 |
+
if self.index_docstore is None:
|
548 |
+
self._build_faiss_index(embedding)
|
549 |
+
|
550 |
+
init_search_time = time.time()
|
551 |
+
|
552 |
+
# work through the index, fetching more candidates than the final k
|
553 |
+
if marginal_relevance:
|
554 |
+
docs = self.index_docstore.max_marginal_relevance_search(
|
555 |
+
answer.question, k=k, fetch_k=5 * k
|
556 |
+
)
|
557 |
+
else:
|
558 |
+
docs = self.index_docstore.similarity_search(
|
559 |
+
answer.question, k=k, fetch_k=5 * k
|
560 |
+
)
|
561 |
+
if OPERATING_MODE == "debug":
|
562 |
+
print(f"time to search docs to build context: {time.time() - init_search_time:.2f} [s]")
|
563 |
+
init_summary_time = time.time()
|
564 |
+
partial_summary_time = ""
|
565 |
+
for i, doc in enumerate(docs):
|
566 |
+
with get_openai_callback() as cb:
|
567 |
+
init__partial_summary_time = time.time()
|
568 |
+
summary_of_chunked_text = self.summary_chain.run(
|
569 |
+
question=answer.question, context_str=doc.page_content
|
570 |
+
)
|
571 |
+
if OPERATING_MODE == "debug":
|
572 |
+
partial_summary_time += f"- time to make relevant summary of doc '{i}': {time.time() - init__partial_summary_time:.2f} [s]\n"
|
573 |
+
engine = self.summary_chain.llm.model_kwargs.get('deployment_id') or self.summary_chain.llm.model_name
|
574 |
+
if not answer.tokens:
|
575 |
+
answer.tokens = [{
|
576 |
+
'engine': engine,
|
577 |
+
'total_tokens': cb.total_tokens}]
|
578 |
+
else:
|
579 |
+
answer.tokens.append({
|
580 |
+
'engine': engine,
|
581 |
+
'total_tokens': cb.total_tokens
|
582 |
+
})
|
583 |
+
summarized_package = (
|
584 |
+
doc.metadata["key"],
|
585 |
+
doc.metadata["citation"],
|
586 |
+
summary_of_chunked_text,
|
587 |
+
doc.page_content,
|
588 |
+
)
|
589 |
+
if "Not applicable" not in summary_of_chunked_text and summarized_package not in answer.packages:
|
590 |
+
answer.packages.append(summarized_package)
|
591 |
+
yield answer
|
592 |
+
if len(answer.packages) == max_sources:
|
593 |
+
break
|
594 |
+
if OPERATING_MODE == "debug":
|
595 |
+
print(f"time to make all relevant summaries: {time.time() - init_summary_time:.2f} [s]")
|
596 |
+
# the last character is not printed because it is a \n
|
597 |
+
print(partial_summary_time[:-1])
|
598 |
+
context_str = "\n\n".join(
|
599 |
+
[f"{citation}: {summary_of_chunked_text}"
|
600 |
+
for key, citation, summary_of_chunked_text, chunked_text in answer.packages
|
601 |
+
if "Not applicable" not in summary_of_chunked_text]
|
602 |
+
)
|
603 |
+
chunks_str = "\n\n".join(
|
604 |
+
[f"{citation}: {chunked_text}"
|
605 |
+
for key, citation, summary_of_chunked_text, chunked_text in answer.packages
|
606 |
+
if "Not applicable" not in summary_of_chunked_text]
|
607 |
+
)
|
608 |
+
valid_keys = [key
|
609 |
+
for key, citation, summary_of_chunked_text, chunked_text in answer.packages
|
610 |
+
if "Not applicable" not in summary_of_chunked_text]
|
611 |
+
if len(valid_keys) > 0:
|
612 |
+
context_str += "\n\nValid keys: " + ", ".join(valid_keys)
|
613 |
+
chunks_str += "\n\nValid keys: " + ", ".join(valid_keys)
|
614 |
+
answer.context = context_str
|
615 |
+
answer.chunks = chunks_str
|
616 |
+
yield answer
|
617 |
+
|
618 |
+
def query(
|
619 |
+
self,
|
620 |
+
query: str,
|
621 |
+
embedding: Embeddings,
|
622 |
+
chat_history: list[tuple[str, str]],
|
623 |
+
k: int = 10,
|
624 |
+
max_sources: int = 5,
|
625 |
+
length_prompt: str = "about 100 words",
|
626 |
+
marginal_relevance: bool = True,
|
627 |
+
):
|
628 |
+
for answer in self._query(
|
629 |
+
query,
|
630 |
+
embedding,
|
631 |
+
chat_history,
|
632 |
+
k=k,
|
633 |
+
max_sources=max_sources,
|
634 |
+
length_prompt=length_prompt,
|
635 |
+
marginal_relevance=marginal_relevance,
|
636 |
+
):
|
637 |
+
pass
|
638 |
+
return answer
|
639 |
+
|
640 |
+
def _query(
|
641 |
+
self,
|
642 |
+
query: str,
|
643 |
+
embedding: Embeddings,
|
644 |
+
chat_history: list[tuple[str, str]],
|
645 |
+
k: int,
|
646 |
+
max_sources: int,
|
647 |
+
length_prompt: str,
|
648 |
+
marginal_relevance: bool,
|
649 |
+
):
|
650 |
+
if k < max_sources:
|
651 |
+
k = max_sources + 1
|
652 |
+
|
653 |
+
answer = Answer(question=query)
|
654 |
+
|
655 |
+
messages_qa = [system_message_prompt]
|
656 |
+
if len(chat_history) != 0:
|
657 |
+
for conversation in chat_history:
|
658 |
+
messages_qa.append(HumanMessagePromptTemplate.from_template(conversation[0]))
|
659 |
+
messages_qa.append(AIMessagePromptTemplate.from_template(conversation[1]))
|
660 |
+
messages_qa.append(human_qa_message_prompt)
|
661 |
+
chat_qa_prompt = ChatPromptTemplate.from_messages(messages_qa)
|
662 |
+
self.qa_chain = LLMChain(prompt=chat_qa_prompt, llm=self.llm)
|
663 |
+
|
664 |
+
for answer in self.get_evidence(
|
665 |
+
answer,
|
666 |
+
embedding,
|
667 |
+
k=k,
|
668 |
+
max_sources=max_sources,
|
669 |
+
marginal_relevance=marginal_relevance,
|
670 |
+
):
|
671 |
+
yield answer
|
672 |
+
|
673 |
+
references_dict = dict()
|
674 |
+
passages = dict()
|
675 |
+
if len(answer.context) < 10:
|
676 |
+
answer_text = "I cannot answer this question due to insufficient information."
|
677 |
+
else:
|
678 |
+
with get_openai_callback() as cb:
|
679 |
+
init_qa_time = time.time()
|
680 |
+
answer_text = self.qa_chain.run(
|
681 |
+
question=answer.question, context_str=answer.context, length=length_prompt
|
682 |
+
)
|
683 |
+
if OPERATING_MODE == "debug":
|
684 |
+
print(f"time to make the Q&A answer: {time.time() - init_qa_time:.2f} [s]")
|
685 |
+
engine = self.qa_chain.llm.model_kwargs.get('deployment_id') or self.qa_chain.llm.model_name
|
686 |
+
if not answer.tokens:
|
687 |
+
answer.tokens = [{
|
688 |
+
'engine': engine,
|
689 |
+
'total_tokens': cb.total_tokens}]
|
690 |
+
else:
|
691 |
+
answer.tokens.append({
|
692 |
+
'engine': engine,
|
693 |
+
'total_tokens': cb.total_tokens
|
694 |
+
})
|
695 |
+
|
696 |
+
# it still happens lol
|
697 |
+
if "(Foo2012)" in answer_text:
|
698 |
+
answer_text = answer_text.replace("(Foo2012)", "")
|
699 |
+
for key, citation, summary, text in answer.packages:
|
700 |
+
# do check for whole key (so we don't catch Callahan2019a with Callahan2019)
|
701 |
+
skey = key.split(" ")[0]
|
702 |
+
if skey + " " in answer_text or skey + ")" in answer_text:
|
703 |
+
references_dict[skey] = citation
|
704 |
+
passages[key] = text
|
705 |
+
references_str = "\n\n".join(
|
706 |
+
[f"{i+1}. ({k}): {c}" for i, (k, c) in enumerate(references_dict.items())]
|
707 |
+
)
|
708 |
+
|
709 |
+
# cost_str = f"{answer_text}\n\n"
|
710 |
+
cost_str = ""
|
711 |
+
itemized_cost = ""
|
712 |
+
total_amount = 0
|
713 |
+
for d in answer.tokens:
|
714 |
+
total_tokens = d.get('total_tokens')
|
715 |
+
if total_tokens:
|
716 |
+
engine = d.get('engine')
|
717 |
+
key_price = None
|
718 |
+
for key in PRICES.keys():
|
719 |
+
if re.match(f"{key}", engine):
|
720 |
+
key_price = key
|
721 |
+
break
|
722 |
+
if PRICES.get(key_price):
|
723 |
+
partial_amount = total_tokens / 1000 * PRICES.get(key_price)
|
724 |
+
total_amount += partial_amount
|
725 |
+
itemized_cost += f"- {engine}: {total_tokens} tokens\t ---> ${partial_amount:.4f},\n"
|
726 |
+
else:
|
727 |
+
itemized_cost += f"- {engine}: {total_tokens} tokens,\n"
|
728 |
+
# delete ,\n
|
729 |
+
itemized_cost = itemized_cost[:-2]
|
730 |
+
|
731 |
+
# add tokens to formatted answer
|
732 |
+
cost_str += f"Total cost: ${total_amount:.4f}\nItemized cost:\n{itemized_cost}"
|
733 |
+
|
734 |
+
answer.answer = answer_text
|
735 |
+
answer.cost_str = cost_str
|
736 |
+
answer.references = references_str
|
737 |
+
answer.passages = passages
|
738 |
+
yield answer
|
739 |
+
|
740 |
+
|
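For orientation, here is a minimal usage sketch of the Dataset class added above. It is a hedged example rather than part of the commit: the file name, citation, key and question are placeholder values, and it assumes OPENAI_API_KEY is set in the environment.

    from langchain.embeddings.openai import OpenAIEmbeddings
    from streamlit_langchain_chat.dataset import Dataset

    dataset = Dataset()  # defaults to ChatOpenAI(temperature=0.1, max_tokens=512)
    # "policy.pdf", the citation and the key are placeholder values
    dataset.add("policy.pdf", citation="Example Policy, 2023", key="Example2023")
    dataset._build_faiss_index(OpenAIEmbeddings())
    answer = dataset.query("What does the policy say about annual leave?",
                           OpenAIEmbeddings(),
                           chat_history=[])
    print(answer.answer)
    print(answer.references)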
streamlit_langchain_chat/inputs/__init__.py
ADDED
File without changes
|
streamlit_langchain_chat/prompts.py
ADDED
@@ -0,0 +1,91 @@
1 |
+
import langchain.prompts as prompts
|
2 |
+
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
3 |
+
from datetime import datetime
|
4 |
+
|
5 |
+
summary_template = """Summarize and provide direct quotes from the text below to help answer a question.
|
6 |
+
Do not directly answer the question, instead provide a summary and quotes with the context of the user's question.
|
7 |
+
Do not use outside sources.
|
8 |
+
Reply with "Not applicable" if the text is unrelated to the question.
|
9 |
+
Use 75 words or fewer.
|
10 |
+
Remember, if the user does not specify a language, reply in the language of the user's question.
|
11 |
+
|
12 |
+
{context_str}
|
13 |
+
|
14 |
+
User's question: {question}
|
15 |
+
Relevant Information Summary:"""
|
16 |
+
summary_prompt = prompts.PromptTemplate(
|
17 |
+
input_variables=["question", "context_str"],
|
18 |
+
template=summary_template,
|
19 |
+
)
|
20 |
+
|
21 |
+
qa_template = """Write an answer for the user's question below solely based on the provided context.
|
22 |
+
If the user does not specify how many words the answer should be, the length of the answer should be {length}.
|
23 |
+
If the context is irrelevant, reply "Your question falls outside the scope of University of Sydney policy, so I cannot answer".
|
24 |
+
For each sentence in your answer, indicate which sources most support it via valid citation markers at the end of sentences, like (Example2012).
|
25 |
+
Answer in an unbiased and professional tone.
|
26 |
+
Make clear what is your opinion.
|
27 |
+
Use Markdown for formatting code or text, and try to use direct quotes to support arguments.
|
28 |
+
Remember, if the user does not specify a language, answer in the language of the user's question.
|
29 |
+
|
30 |
+
Context:
|
31 |
+
{context_str}
|
32 |
+
|
33 |
+
|
34 |
+
User's question: {question}
|
35 |
+
Answer:
|
36 |
+
"""
|
37 |
+
qa_prompt = prompts.PromptTemplate(
|
38 |
+
input_variables=["question", "context_str", "length"],
|
39 |
+
template=qa_template,
|
40 |
+
)
|
41 |
+
|
42 |
+
# used by GPCL
|
43 |
+
qa_prompt_GPCL = prompts.PromptTemplate(
|
44 |
+
input_variables=["question", "context_str"],
|
45 |
+
template="You are an AI assistant providing helpful advice about University of Sydney policy. You are given the following extracted parts of a long document and a question. Provide a conversational answer based on the context provided."
|
46 |
+
"You should only provide hyperlinks that reference the context below. Do NOT make up hyperlinks."
|
47 |
+
'If you can not find the answer in the context below, just say "Hmm, I am not sure. Could you please rephrase your question?" Do not try to make up an answer.'
|
48 |
+
"If the question is not related to the context, politely respond that you are tuned to only answer questions that are related to the context.\n\n"
|
49 |
+
"Question: {question}\n"
|
50 |
+
"=========\n"
|
51 |
+
"{context_str}\n"
|
52 |
+
"=========\n"
|
53 |
+
"Answer in Markdown:",
|
54 |
+
)
|
55 |
+
|
56 |
+
search_prompt = prompts.PromptTemplate(
|
57 |
+
input_variables=["question"],
|
58 |
+
template="We want to answer the following question: {question} \n"
|
59 |
+
"Provide three different targeted keyword searches (one search per line) "
|
60 |
+
"that will find papers that help answer the question. Do not use boolean operators. "
|
61 |
+
"Recent years are 2021, 2022, 2023.\n\n"
|
62 |
+
"1.",
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def _get_datetime():
|
67 |
+
now = datetime.now()
|
68 |
+
return now.strftime("%m/%d/%Y")
|
69 |
+
|
70 |
+
|
71 |
+
citation_prompt = prompts.PromptTemplate(
|
72 |
+
input_variables=["text"],
|
73 |
+
template="Provide a possible citation for the following text in MLA Format. Today's date is {date}\n"
|
74 |
+
"{text}\n\n"
|
75 |
+
"Citation:",
|
76 |
+
partial_variables={"date": _get_datetime},
|
77 |
+
)
|
78 |
+
|
79 |
+
system_template = """You are an AI chatbot with knowledge of the University of Sydney's legal policies that answers in an unbiased, professional tone.
|
80 |
+
You sometimes refuse to answer if there is insufficient information.
|
81 |
+
If the user does not specify a language, answer in the language of the user's question. """
|
82 |
+
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
|
83 |
+
|
84 |
+
human_summary_message_prompt = HumanMessagePromptTemplate.from_template(summary_template)
|
85 |
+
chat_summary_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_summary_message_prompt])
|
86 |
+
|
87 |
+
human_qa_message_prompt = HumanMessagePromptTemplate.from_template(qa_template)
|
88 |
+
# chat_qa_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_qa_message_prompt])  # TODO: remove
|
89 |
+
|
90 |
+
# human_condense_message_prompt = HumanMessagePromptTemplate.from_template(condense_template)
|
91 |
+
# chat_condense_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_condense_message_prompt])
|
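For reference, the prompts above are wired to a model through LLMChain in dataset.py (see Dataset.update_llm). A small hedged sketch of that wiring; the question and context strings are made-up inputs:

    from langchain.chains import LLMChain
    from langchain.chat_models import ChatOpenAI
    from streamlit_langchain_chat.prompts import chat_summary_prompt

    llm = ChatOpenAI(temperature=0.1, max_tokens=512)
    summary_chain = LLMChain(prompt=chat_summary_prompt, llm=llm)
    # both input variables come from summary_template: question and context_str
    summary = summary_chain.run(question="What is the leave policy?",
                                context_str="...text of one retrieved chunk...")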
streamlit_langchain_chat/streamlit_app.py
ADDED
@@ -0,0 +1,561 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
To run:
|
4 |
+
- activate the virtual environment
|
5 |
+
- streamlit run path\to\streamlit_app.py
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import re
|
10 |
+
import sys
|
11 |
+
import time
|
12 |
+
import warnings
|
13 |
+
import shutil
|
14 |
+
|
15 |
+
from langchain.chat_models import ChatOpenAI
|
16 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
17 |
+
import openai
|
18 |
+
import pandas as pd
|
19 |
+
import streamlit as st
|
20 |
+
from st_aggrid import GridOptionsBuilder, AgGrid, GridUpdateMode, ColumnsAutoSizeMode
|
21 |
+
from streamlit_chat import message
|
22 |
+
|
23 |
+
from streamlit_langchain_chat.constants import *
|
24 |
+
from streamlit_langchain_chat.customized_langchain.llms import OpenAI, AzureOpenAI, AzureOpenAIChat
|
25 |
+
from streamlit_langchain_chat.dataset import Dataset
|
26 |
+
|
27 |
+
# Configure logger
|
28 |
+
logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)
|
29 |
+
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
30 |
+
|
31 |
+
warnings.filterwarnings('ignore')
|
32 |
+
|
33 |
+
if 'generated' not in st.session_state:
|
34 |
+
st.session_state['generated'] = []
|
35 |
+
if 'past' not in st.session_state:
|
36 |
+
st.session_state['past'] = []
|
37 |
+
if 'costs' not in st.session_state:
|
38 |
+
st.session_state['costs'] = []
|
39 |
+
if 'contexts' not in st.session_state:
|
40 |
+
st.session_state['contexts'] = []
|
41 |
+
if 'chunks' not in st.session_state:
|
42 |
+
st.session_state['chunks'] = []
|
43 |
+
if 'user_input' not in st.session_state:
|
44 |
+
st.session_state['user_input'] = ""
|
45 |
+
if 'dataset' not in st.session_state:
|
46 |
+
st.session_state['dataset'] = None
|
47 |
+
|
48 |
+
|
49 |
+
def check_api_keys() -> bool:
|
50 |
+
source_id = app.params['source_id']
|
51 |
+
index_id = app.params['index_id']
|
52 |
+
|
53 |
+
open_api_key = os.getenv('OPENAI_API_KEY', '')
|
54 |
+
openapi_api_key_ready = type(open_api_key) is str and len(open_api_key) > 0
|
55 |
+
|
56 |
+
pinecone_api_key = os.getenv('PINECONE_API_KEY', '')
|
57 |
+
pinecone_api_key_ready = type(pinecone_api_key) is str and len(pinecone_api_key) > 0 if index_id == 2 else True
|
58 |
+
|
59 |
+
is_ready = True if openapi_api_key_ready and pinecone_api_key_ready else False
|
60 |
+
return is_ready
|
61 |
+
|
62 |
+
|
63 |
+
def check_combination_point() -> bool:
|
64 |
+
type_id = app.params['type_id']
|
65 |
+
open_api_key = os.getenv('OPENAI_API_KEY', '')
|
66 |
+
openapi_api_key_ready = type(open_api_key) is str and len(open_api_key) > 0
|
67 |
+
api_base = app.params['api_base']
|
68 |
+
|
69 |
+
if type_id == 1:
|
70 |
+
deployment_id = app.params['deployment_id']
|
71 |
+
return True if openapi_api_key_ready and api_base and deployment_id else False
|
72 |
+
elif type_id == 2:
|
73 |
+
return True if openapi_api_key_ready and api_base else False
|
74 |
+
else:
|
75 |
+
return False
|
76 |
+
|
77 |
+
|
78 |
+
def check_index() -> bool:
|
79 |
+
dataset = st.session_state['dataset']
|
80 |
+
|
81 |
+
index_built = dataset.index_docstore if hasattr(dataset, "index_docstore") else False
|
82 |
+
without_source = app.params['source_id'] == 4
|
83 |
+
is_ready = True if index_built or without_source else False
|
84 |
+
return is_ready
|
85 |
+
|
86 |
+
|
87 |
+
def check_index_point() -> bool:
|
88 |
+
index_id = app.params['index_id']
|
89 |
+
|
90 |
+
pinecone_api_key = os.getenv('PINECONE_API_KEY', '')
|
91 |
+
pinecone_api_key_ready = type(pinecone_api_key) is str and len(pinecone_api_key) > 0 if index_id == 2 else True
|
92 |
+
pinecone_environment = os.getenv('PINECONE_ENVIRONMENT', False) if index_id == 2 else True
|
93 |
+
|
94 |
+
is_ready = True if index_id and pinecone_api_key_ready and pinecone_environment else False
|
95 |
+
return is_ready
|
96 |
+
|
97 |
+
|
98 |
+
def check_params_point() -> bool:
|
99 |
+
max_sources = app.params['max_sources']
|
100 |
+
temperature = app.params['temperature']
|
101 |
+
|
102 |
+
is_ready = True if max_sources and isinstance(temperature, float) else False
|
103 |
+
return is_ready
|
104 |
+
|
105 |
+
|
106 |
+
def check_source_point() -> bool:
|
107 |
+
return True
|
108 |
+
|
109 |
+
|
110 |
+
def clear_chat_history():
|
111 |
+
if st.session_state['past'] or st.session_state['generated'] or st.session_state['contexts'] or st.session_state['chunks'] or st.session_state['costs']:
|
112 |
+
st.session_state['past'] = []
|
113 |
+
st.session_state['generated'] = []
|
114 |
+
st.session_state['contexts'] = []
|
115 |
+
st.session_state['chunks'] = []
|
116 |
+
st.session_state['costs'] = []
|
117 |
+
|
118 |
+
|
119 |
+
def clear_index():
|
120 |
+
if dataset := st.session_state['dataset']:
|
121 |
+
# delete directory (with files)
|
122 |
+
index_path = dataset.index_path
|
123 |
+
if index_path.exists():
|
124 |
+
shutil.rmtree(str(index_path))
|
125 |
+
|
126 |
+
# update variable
|
127 |
+
st.session_state['dataset'] = None
|
128 |
+
|
129 |
+
elif (TEMP_DIR / "default").exists():
|
130 |
+
shutil.rmtree(str(TEMP_DIR / "default"))
|
131 |
+
|
132 |
+
|
133 |
+
def check_sources() -> bool:
|
134 |
+
uploaded_files_rows = app.params['uploaded_files_rows']
|
135 |
+
urls_df = app.params['urls_df']
|
136 |
+
source_id = app.params['source_id']
|
137 |
+
|
138 |
+
some_files = True if uploaded_files_rows and uploaded_files_rows[-1].get('filepath') != "" else False
|
139 |
+
some_urls = bool([True for url, citation in urls_df.to_numpy() if url])
|
140 |
+
|
141 |
+
only_local_files = some_files and not some_urls
|
142 |
+
only_urls = not some_files and some_urls
|
143 |
+
is_ready = only_local_files or only_urls or (source_id == 4)
|
144 |
+
return is_ready
|
145 |
+
|
146 |
+
|
147 |
+
def collect_dataset_and_built_index():
|
148 |
+
start = time.time()
|
149 |
+
uploaded_files_rows = app.params['uploaded_files_rows']
|
150 |
+
urls_df = app.params['urls_df']
|
151 |
+
type_id = app.params['type_id']
|
152 |
+
temperature = app.params['temperature']
|
153 |
+
index_id = app.params['index_id']
|
154 |
+
api_base = app.params['api_base']
|
155 |
+
deployment_id = app.params['deployment_id']
|
156 |
+
|
157 |
+
some_files = True if uploaded_files_rows and uploaded_files_rows[-1].get('filepath') != "" else False
|
158 |
+
some_urls = bool([True for url, citation in urls_df.to_numpy() if url])
|
159 |
+
|
160 |
+
openai.api_type = "azure" if type_id == 1 else "open_ai"
|
161 |
+
openai.api_base = api_base
|
162 |
+
openai.api_version = "2023-03-15-preview" if type_id == 1 else None
|
163 |
+
|
164 |
+
if deployment_id != "text-davinci-003":
|
165 |
+
dataset = Dataset(
|
166 |
+
llm=ChatOpenAI(
|
167 |
+
temperature=temperature,
|
168 |
+
max_tokens=512,
|
169 |
+
deployment_id=deployment_id,
|
170 |
+
)
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
dataset = Dataset(
|
174 |
+
llm=OpenAI(
|
175 |
+
temperature=temperature,
|
176 |
+
max_tokens=512,
|
177 |
+
deployment_id=COMBINATIONS_OPTIONS.get(combination_id).get('deployment_name'),
|
178 |
+
)
|
179 |
+
)
|
180 |
+
|
181 |
+
# get url documents
|
182 |
+
if some_urls:
|
183 |
+
urls_df = urls_df.reset_index()
|
184 |
+
for url_index, url_row in urls_df.iterrows():
|
185 |
+
url = url_row.get('urls', '')
|
186 |
+
citation = url_row.get('citation string', '')
|
187 |
+
if url:
|
188 |
+
try:
|
189 |
+
dataset.add(
|
190 |
+
url,
|
191 |
+
citation,
|
192 |
+
citation,
|
193 |
+
disable_check=True # True to accept Japanese letters
|
194 |
+
)
|
195 |
+
except Exception as e:
|
196 |
+
print(e)
|
197 |
+
pass
|
198 |
+
|
199 |
+
# dataset is pandas dataframe
|
200 |
+
if some_files:
|
201 |
+
for uploaded_files_row in uploaded_files_rows:
|
202 |
+
key = uploaded_files_row.get('citation string') if ',' not in uploaded_files_row.get('citation string') else None
|
203 |
+
dataset.add(
|
204 |
+
uploaded_files_row.get('filepath'),
|
205 |
+
uploaded_files_row.get('citation string'),
|
206 |
+
key=key,
|
207 |
+
disable_check=True # True to accept Japanese letters
|
208 |
+
)
|
209 |
+
|
210 |
+
openai_embeddings = OpenAIEmbeddings(
|
211 |
+
document_model_name="text-embedding-ada-002",
|
212 |
+
query_model_name="text-embedding-ada-002",
|
213 |
+
)
|
214 |
+
if index_id == 1:
|
215 |
+
dataset._build_faiss_index(openai_embeddings)
|
216 |
+
else:
|
217 |
+
dataset._build_pinecone_index(openai_embeddings)
|
218 |
+
st.session_state['dataset'] = dataset
|
219 |
+
|
220 |
+
if OPERATING_MODE == "debug":
|
221 |
+
print(f"time to collect dataset: {time.time() - start:.2f} [s]")
|
222 |
+
|
223 |
+
|
224 |
+
def configure_streamlit_and_page():
|
225 |
+
# Configure Streamlit page and state
|
226 |
+
st.set_page_config(**ST_CONFIG)
|
227 |
+
|
228 |
+
# Force responsive layout for columns also on mobile
|
229 |
+
st.write(
|
230 |
+
"""<style>
|
231 |
+
[data-testid="column"] {
|
232 |
+
width: calc(50% - 1rem);
|
233 |
+
flex: 1 1 calc(50% - 1rem);
|
234 |
+
min-width: calc(50% - 1rem);
|
235 |
+
}
|
236 |
+
</style>""",
|
237 |
+
unsafe_allow_html=True,
|
238 |
+
)
|
239 |
+
|
240 |
+
|
241 |
+
def get_answer():
|
242 |
+
query = st.session_state['user_input']
|
243 |
+
dataset = st.session_state['dataset']
|
244 |
+
type_id = app.params['type_id']
|
245 |
+
index_id = app.params['index_id']
|
246 |
+
max_sources = app.params['max_sources']
|
247 |
+
|
248 |
+
if query and dataset and type_id and index_id:
|
249 |
+
chat_history = [(past, generated)
|
250 |
+
for (past, generated) in zip(st.session_state['past'], st.session_state['generated'])]
|
251 |
+
marginal_relevance = index_id == 1
|
252 |
+
start = time.time()
|
253 |
+
openai_embeddings = OpenAIEmbeddings(
|
254 |
+
document_model_name="text-embedding-ada-002",
|
255 |
+
query_model_name="text-embedding-ada-002",
|
256 |
+
)
|
257 |
+
result = dataset.query(
|
258 |
+
query,
|
259 |
+
openai_embeddings,
|
260 |
+
chat_history,
|
261 |
+
marginal_relevance=marginal_relevance, # if pinecone is used it must be False
|
262 |
+
)
|
263 |
+
if OPERATING_MODE == "debug":
|
264 |
+
print(f"time to get answer: {time.time() - start:.2f} [s]")
|
265 |
+
print("-" * 10)
|
266 |
+
# response = {'generated_text': result.formatted_answer}
|
267 |
+
# response = {'generated_text': f"test_{len(st.session_state['generated'])} by {query}"} # @debug
|
268 |
+
return result
|
269 |
+
else:
|
270 |
+
return None
|
271 |
+
|
272 |
+
|
273 |
+
def load_main_page():
|
274 |
+
"""
|
275 |
+
Load the body of web.
|
276 |
+
"""
|
277 |
+
# Streamlit HTML Markdown
|
278 |
+
# st.title <h1> #
|
279 |
+
# st.header <h2> ##
|
280 |
+
# st.subheader <h3> ###
|
281 |
+
st.markdown(f"## Augmented-Retrieval Q&A ChatGPT ({APP_VERSION})")
|
282 |
+
validate_status()
|
283 |
+
st.markdown(f"#### **Status**: {app.params['status']}")
|
284 |
+
|
285 |
+
# hidden div with anchor
|
286 |
+
st.markdown("<div id='linkto_top'></div>", unsafe_allow_html=True)
|
287 |
+
col1, col2, col3 = st.columns(3)
|
288 |
+
col1.button(label="clear index", type="primary", on_click=clear_index)
|
289 |
+
col2.button(label="clear conversation", type="primary", on_click=clear_chat_history)
|
290 |
+
col3.markdown("<a href='#linkto_bottom'>Link to bottom</a>", unsafe_allow_html=True)
|
291 |
+
|
292 |
+
if st.session_state["generated"]:
|
293 |
+
for i in range(len(st.session_state["generated"])):
|
294 |
+
message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
|
295 |
+
message(st.session_state['generated'][i], key=str(i))
|
296 |
+
with st.expander("See context"):
|
297 |
+
st.write(st.session_state['contexts'][i])
|
298 |
+
with st.expander("See chunks"):
|
299 |
+
st.write(st.session_state['chunks'][i])
|
300 |
+
with st.expander("See costs"):
|
301 |
+
st.write(st.session_state['costs'][i])
|
302 |
+
dataset = st.session_state['dataset']
|
303 |
+
index_built = dataset.index_docstore if hasattr(dataset, "index_docstore") else False
|
304 |
+
without_source = app.params['source_id'] == 4
|
305 |
+
enable_chat_button = index_built or without_source
|
306 |
+
st.text_input("You:",
|
307 |
+
key='user_input',
|
308 |
+
on_change=on_enter,
|
309 |
+
disabled=not enable_chat_button
|
310 |
+
)
|
311 |
+
|
312 |
+
st.markdown("<a href='#linkto_top'>Link to top</a>", unsafe_allow_html=True)
|
313 |
+
# hidden div with anchor
|
314 |
+
st.markdown("<div id='linkto_bottom'></div>", unsafe_allow_html=True)
|
315 |
+
|
316 |
+
|
317 |
+
def load_sidebar_page():
|
318 |
+
st.sidebar.markdown("## Instructions")
|
319 |
+
|
320 |
+
# ############ #
|
321 |
+
# SOURCES TYPE #
|
322 |
+
# ############ #
|
323 |
+
st.sidebar.markdown("1. Select a source:")
|
324 |
+
source_selected = st.sidebar.selectbox(
|
325 |
+
"Choose the location of your info to give context to chatgpt",
|
326 |
+
[key for key, value in SOURCES_IDS.items()])
|
327 |
+
app.params['source_id'] = SOURCES_IDS.get(source_selected, None)
|
328 |
+
|
329 |
+
# ##### #
|
330 |
+
# MODEL #
|
331 |
+
# ##### #
|
332 |
+
st.sidebar.markdown("2. Select a model (LLM):")
|
333 |
+
combination_selected = st.sidebar.selectbox(
|
334 |
+
"Choose type: MSF Azure OpenAI and model / OpenAI",
|
335 |
+
[key for key, value in TYPE_IDS.items()])
|
336 |
+
app.params['type_id'] = TYPE_IDS.get(combination_selected, None)
|
337 |
+
|
338 |
+
if app.params['type_id'] == 1: # with AzureOpenAI endpoint
|
339 |
+
# https://docs.streamlit.io/library/api-reference/widgets/st.text_input
|
340 |
+
os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
|
341 |
+
label="Enter Azure OpenAI API Key",
|
342 |
+
type="password"
|
343 |
+
).strip()
|
344 |
+
app.params['api_base'] = st.sidebar.text_input(
|
345 |
+
label="Enter Azure API base",
|
346 |
+
placeholder="https://<api_base_endpoint>.openai.azure.com/",
|
347 |
+
).strip()
|
348 |
+
app.params['deployment_id'] = st.sidebar.text_input(
|
349 |
+
label="Enter Azure deployment_id",
|
350 |
+
).strip()
|
351 |
+
elif app.params['type_id'] == 2: # with OpenAI endpoint
|
352 |
+
os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
|
353 |
+
label="Enter OpenAI API Key",
|
354 |
+
placeholder="sk-...",
|
355 |
+
type="password"
|
356 |
+
).strip()
|
357 |
+
app.params['api_base'] = "https://api.openai.com/v1"
|
358 |
+
app.params['deployment_id'] = None
|
359 |
+
|
360 |
+
# ####### #
|
361 |
+
# INDEXES #
|
362 |
+
# ####### #
|
363 |
+
st.sidebar.markdown("3. Select an index store:")
|
364 |
+
index_selected = st.sidebar.selectbox(
|
365 |
+
"Type of Index",
|
366 |
+
[key for key, value in INDEX_IDS.items()])
|
367 |
+
app.params['index_id'] = INDEX_IDS.get(index_selected, None)
|
368 |
+
if app.params['index_id'] == 2: # with pinecone
|
369 |
+
os.environ['PINECONE_API_KEY'] = st.sidebar.text_input(
|
370 |
+
label="Enter pinecone API Key",
|
371 |
+
type="password"
|
372 |
+
).strip()
|
373 |
+
|
374 |
+
os.environ['PINECONE_ENVIRONMENT'] = st.sidebar.text_input(
|
375 |
+
label="Enter pinecone environment",
|
376 |
+
placeholder="eu-west1-gcp",
|
377 |
+
).strip()
|
378 |
+
|
379 |
+
# ############## #
|
380 |
+
# CONFIGURATIONS #
|
381 |
+
# ############## #
|
382 |
+
st.sidebar.markdown("4. Choose configuration:")
|
383 |
+
# https://docs.streamlit.io/library/api-reference/widgets/st.number_input
|
384 |
+
max_sources = st.sidebar.number_input(
|
385 |
+
label="Top-k: Number of chunks/sections (1-5)",
|
386 |
+
step=1,
|
387 |
+
format="%d",
|
388 |
+
value=5
|
389 |
+
)
|
390 |
+
app.params['max_sources'] = max_sources
|
391 |
+
temperature = st.sidebar.number_input(
|
392 |
+
label="Temperature (0.0 – 1.0)",
|
393 |
+
step=0.1,
|
394 |
+
format="%f",
|
395 |
+
value=0.0,
|
396 |
+
min_value=0.0,
|
397 |
+
max_value=1.0
|
398 |
+
)
|
399 |
+
app.params['temperature'] = round(temperature, 1)
|
400 |
+
|
401 |
+
# ############## #
|
402 |
+
# UPLOAD SOURCES #
|
403 |
+
# ############## #
|
404 |
+
app.params['uploaded_files_rows'] = []
|
405 |
+
if app.params['source_id'] == 1:
|
406 |
+
# https://docs.streamlit.io/library/api-reference/widgets/st.file_uploader
|
407 |
+
# https://towardsdatascience.com/make-dataframes-interactive-in-streamlit-c3d0c4f84ccb
|
408 |
+
st.sidebar.markdown("""5. Upload your local documents and modify citation strings (optional)""")
|
409 |
+
uploaded_files = st.sidebar.file_uploader(
|
410 |
+
"Choose files",
|
411 |
+
accept_multiple_files=True,
|
412 |
+
type=['pdf', 'PDF',
|
413 |
+
'txt', 'TXT',
|
414 |
+
'html',
|
415 |
+
'docx', 'DOCX',
|
416 |
+
'pptx', 'PPTX',
|
417 |
+
],
|
418 |
+
)
|
419 |
+
uploaded_files_dataset = request_pathname(uploaded_files)
|
420 |
+
uploaded_files_df = pd.DataFrame(
|
421 |
+
uploaded_files_dataset,
|
422 |
+
columns=['filepath', 'citation string'])
|
423 |
+
uploaded_files_grid_options_builder = GridOptionsBuilder.from_dataframe(uploaded_files_df)
|
424 |
+
uploaded_files_grid_options_builder.configure_selection(
|
425 |
+
selection_mode='multiple',
|
426 |
+
pre_selected_rows=list(range(uploaded_files_df.shape[0])) if uploaded_files_df.iloc[-1, 0] != "" else [],
|
427 |
+
use_checkbox=True,
|
428 |
+
)
|
429 |
+
uploaded_files_grid_options_builder.configure_column("citation string", editable=True)
|
430 |
+
uploaded_files_grid_options_builder.configure_auto_height()
|
431 |
+
uploaded_files_grid_options = uploaded_files_grid_options_builder.build()
|
432 |
+
with st.sidebar:
|
433 |
+
uploaded_files_ag_grid = AgGrid(
|
434 |
+
uploaded_files_df,
|
435 |
+
gridOptions=uploaded_files_grid_options,
|
436 |
+
update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
|
437 |
+
)
|
438 |
+
app.params['uploaded_files_rows'] = uploaded_files_ag_grid["selected_rows"]
|
439 |
+
|
440 |
+
app.params['urls_df'] = pd.DataFrame()
|
441 |
+
if app.params['source_id'] == 3:
|
442 |
+
st.sidebar.markdown("""5. Enter some URLs and optionally edit the citation strings (for nicer display)""")
|
443 |
+
# option 1: with streamlit version 1.20.0+
|
444 |
+
# app.params['urls_df'] = st.sidebar.experimental_data_editor(
|
445 |
+
# pd.DataFrame([["", ""]], columns=['urls', 'citation string']),
|
446 |
+
# use_container_width=True,
|
447 |
+
# num_rows="dynamic",
|
448 |
+
# )
|
449 |
+
|
450 |
+
# option 2: with streamlit version 1.19.0
|
451 |
+
urls_dataset = [["", ""],
|
452 |
+
["", ""],
|
453 |
+
["", ""],
|
454 |
+
["", ""],
|
455 |
+
["", ""]]
|
456 |
+
urls_df = pd.DataFrame(
|
457 |
+
urls_dataset,
|
458 |
+
columns=['urls', 'citation string'])
|
459 |
+
|
460 |
+
urls_grid_options_builder = GridOptionsBuilder.from_dataframe(urls_df)
|
461 |
+
urls_grid_options_builder.configure_columns(['urls', 'citation string'], editable=True)
|
462 |
+
urls_grid_options_builder.configure_auto_height()
|
463 |
+
urls_grid_options = urls_grid_options_builder.build()
|
464 |
+
with st.sidebar:
|
465 |
+
urls_ag_grid = AgGrid(
|
466 |
+
urls_df,
|
467 |
+
gridOptions=urls_grid_options,
|
468 |
+
update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
|
469 |
+
)
|
470 |
+
df = urls_ag_grid.data
|
471 |
+
df = df[df.urls != ""]
|
472 |
+
app.params['urls_df'] = df
|
473 |
+
|
474 |
+
if app.params['source_id'] in (1, 2, 3):
|
475 |
+
st.sidebar.markdown("""6. Build an index you can query""")
|
476 |
+
api_keys_ready = check_api_keys()
|
477 |
+
source_ready = check_sources()
|
478 |
+
enable_index_button = api_keys_ready and source_ready
|
479 |
+
if st.sidebar.button("Build index", disabled=not enable_index_button):
|
480 |
+
collect_dataset_and_built_index()
|
481 |
+
|
482 |
+
|
483 |
+
def main():
|
484 |
+
configure_streamlit_and_page()
|
485 |
+
load_sidebar_page()
|
486 |
+
load_main_page()
|
487 |
+
|
488 |
+
|
489 |
+
def on_enter():
|
490 |
+
output = get_answer()
|
491 |
+
if output:
|
492 |
+
st.session_state.past.append(st.session_state['user_input'])
|
493 |
+
st.session_state.generated.append(output.answer)
|
494 |
+
st.session_state.contexts.append(output.context)
|
495 |
+
st.session_state.chunks.append(output.chunks)
|
496 |
+
st.session_state.costs.append(output.cost_str)
|
497 |
+
st.session_state['user_input'] = ""
|
498 |
+
|
499 |
+
|
500 |
+
def request_pathname(files):
|
501 |
+
if not files:
|
502 |
+
return [["", ""]]
|
503 |
+
|
504 |
+
# check if the temporary directory exists; if not, create it
|
505 |
+
if not Path.exists(TEMP_DIR):
|
506 |
+
TEMP_DIR.mkdir(
|
507 |
+
parents=True,
|
508 |
+
exist_ok=True,
|
509 |
+
)
|
510 |
+
|
511 |
+
file_paths = []
|
512 |
+
for file in files:
|
513 |
+
# # absolute path
|
514 |
+
# file_path = str(TEMP_DIR / file.name)
|
515 |
+
# relative path
|
516 |
+
file_path = str((TEMP_DIR / file.name).relative_to(ROOT_DIR))
|
517 |
+
file_paths.append(file_path)
|
518 |
+
with open(file_path, "wb") as f:
|
519 |
+
f.write(file.getbuffer())
|
520 |
+
return [[filepath, filename.name] for filepath, filename in zip(file_paths, files)]
|
521 |
+
|
522 |
+
|
523 |
+
def validate_status():
|
524 |
+
source_point_ready = check_source_point()
|
525 |
+
combination_point_ready = check_combination_point()
|
526 |
+
index_point_ready = check_index_point()
|
527 |
+
params_point_ready = check_params_point()
|
528 |
+
sources_ready = check_sources()
|
529 |
+
index_ready = check_index()
|
530 |
+
|
531 |
+
if source_point_ready and combination_point_ready and index_point_ready and params_point_ready and sources_ready and index_ready:
|
532 |
+
app.params['status'] = "✨Ready✨"
|
533 |
+
elif not source_point_ready:
|
534 |
+
app.params['status'] = "⚠️Review step 1 on the sidebar."
|
535 |
+
elif not combination_point_ready:
|
536 |
+
app.params['status'] = "⚠️Review step 2 on the sidebar. API Keys or endpoint, ..."
|
537 |
+
elif not index_point_ready:
|
538 |
+
app.params['status'] = "⚠️Review step 3 on the sidebar. Index API Key or environment."
|
539 |
+
elif not params_point_ready:
|
540 |
+
app.params['status'] = "⚠️Review step 4 on the sidebar"
|
541 |
+
elif not sources_ready:
|
542 |
+
app.params['status'] = "⚠️Review step 5 on the sidebar. Waiting for some source..."
|
543 |
+
elif not index_ready:
|
544 |
+
app.params['status'] = "⚠️Review step 6 on the sidebar. Waiting for press button to create index ..."
|
545 |
+
else:
|
546 |
+
app.params['status'] = "⚠️Something is not ready..."
|
547 |
+
|
548 |
+
|
549 |
+
class StreamlitLangchainChatApp():
|
550 |
+
def __init__(self) -> None:
|
551 |
+
"""Use __init__ to define instance variables. It cannot have any arguments."""
|
552 |
+
self.params = dict()
|
553 |
+
|
554 |
+
def run(self, **state) -> None:
|
555 |
+
"""Define here all logic required by your application."""
|
556 |
+
main()
|
557 |
+
|
558 |
+
|
559 |
+
if __name__ == "__main__":
|
560 |
+
app = StreamlitLangchainChatApp()
|
561 |
+
app.run()
|
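As an illustrative aside (not part of the commit), this is the chat-history shape that get_answer() rebuilds from session state and passes to Dataset.query; the strings are invented examples:

    # each element pairs a past user message with the corresponding generated answer
    past = ["What is the travel policy?"]
    generated = ["Travel must be pre-approved (ExamplePolicy2023)."]
    chat_history = [(p, g) for p, g in zip(past, generated)]  # list[tuple[str, str]]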
streamlit_langchain_chat/utils.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
import math
|
2 |
+
import string
|
3 |
+
|
4 |
+
|
5 |
+
def maybe_is_text(s, thresh=2.5):
|
6 |
+
if len(s) == 0:
|
7 |
+
return False
|
8 |
+
# Calculate the entropy of the string
|
9 |
+
entropy = 0
|
10 |
+
for c in string.printable:
|
11 |
+
p = s.count(c) / len(s)
|
12 |
+
if p > 0:
|
13 |
+
entropy += -p * math.log2(p)
|
14 |
+
|
15 |
+
# Check if the entropy is within a reasonable range for text
|
16 |
+
if entropy > thresh:
|
17 |
+
return True
|
18 |
+
return False
|
19 |
+
|
20 |
+
|
21 |
+
def maybe_is_code(s):
|
22 |
+
if len(s) == 0:
|
23 |
+
return False
|
24 |
+
# Check if the string contains a lot of non-ascii characters
|
25 |
+
if len([c for c in s if ord(c) > 128]) / len(s) > 0.1:
|
26 |
+
return True
|
27 |
+
return False
|
28 |
+
|
29 |
+
|
30 |
+
def strings_similarity(s1, s2):
|
31 |
+
if len(s1) == 0 or len(s2) == 0:
|
32 |
+
return 0
|
33 |
+
# break the strings into words
|
34 |
+
s1 = set(s1.split())
|
35 |
+
s2 = set(s2.split())
|
36 |
+
# return the similarity ratio
|
37 |
+
return len(s1.intersection(s2)) / len(s1.union(s2))
|
38 |
+
|
39 |
+
|
40 |
+
def maybe_is_truncated(s):
|
41 |
+
punct = [".", "!", "?", '"']
|
42 |
+
if s[-1] in punct:
|
43 |
+
return False
|
44 |
+
return True
|
45 |
+
|
46 |
+
|
47 |
+
def maybe_is_html(s):
|
48 |
+
if len(s) == 0:
|
49 |
+
return False
|
50 |
+
# check for html tags
|
51 |
+
if "<body" in s or "<html" in s or "<div" in s:
|
52 |
+
return True
|