Spaces · Runtime error
Commit 1ce95c4
Parent(s): Duplicate from hlydecker/Augmented-Retrieval-qa-ChatGPT
Browse files
- .gitattributes +34 -0
- .gitignore +7 -0
- README.md +15 -0
- __init__.py +0 -0
- requirements.txt +0 -0
- static/__init__.py +0 -0
- static/mini_nttdata.jpg +0 -0
- streamlit_langchain_chat/__init__.py +1 -0
- streamlit_langchain_chat/__version__.py +1 -0
- streamlit_langchain_chat/constants.py +53 -0
- streamlit_langchain_chat/customized_langchain/__init__.py +10 -0
- streamlit_langchain_chat/customized_langchain/docstore/__init__.py +7 -0
- streamlit_langchain_chat/customized_langchain/docstore/in_memory.py +27 -0
- streamlit_langchain_chat/customized_langchain/indexes/__init__.py +7 -0
- streamlit_langchain_chat/customized_langchain/indexes/graph.py +20 -0
- streamlit_langchain_chat/customized_langchain/llms/__init__.py +1 -0
- streamlit_langchain_chat/customized_langchain/llms/openai.py +708 -0
- streamlit_langchain_chat/customized_langchain/vectorstores/__init__.py +8 -0
- streamlit_langchain_chat/customized_langchain/vectorstores/faiss.py +100 -0
- streamlit_langchain_chat/customized_langchain/vectorstores/pinecone.py +79 -0
- streamlit_langchain_chat/dataset.py +740 -0
- streamlit_langchain_chat/inputs/__init__.py +0 -0
- streamlit_langchain_chat/prompts.py +91 -0
- streamlit_langchain_chat/streamlit_app.py +561 -0
- streamlit_langchain_chat/utils.py +52 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,7 @@
+venv*/
+tempDir/
+.idea/
+*.env
+*.pkl
+*.pickle
+*testing*.py
README.md
ADDED
@@ -0,0 +1,15 @@
+---
+title: Augmented Retrieval Qa ChatGPT
+emoji: 🚀
+colorFrom: blue
+colorTo: blue
+sdk: streamlit
+sdk_version: 1.19.0
+app_file: streamlit_langchain_chat/streamlit_app.py
+pinned: false
+python_version: 3.10.4
+license: cc-by-nc-sa-4.0
+duplicated_from: hlydecker/Augmented-Retrieval-qa-ChatGPT
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py
ADDED
File without changes
requirements.txt
ADDED
Binary file (5.03 kB)
static/__init__.py
ADDED
File without changes
static/mini_nttdata.jpg
ADDED
streamlit_langchain_chat/__init__.py
ADDED
@@ -0,0 +1 @@
+
streamlit_langchain_chat/__version__.py
ADDED
@@ -0,0 +1 @@
+__VERSION__ = "1.0.4"
streamlit_langchain_chat/constants.py
ADDED
@@ -0,0 +1,53 @@
+from pathlib import Path
+from PIL import Image
+
+# from dotenv import load_dotenv, find_dotenv  # pip install python-dotenv==1.0.0
+
+from __version__ import __VERSION__ as APP_VERSION
+
+_SCRIPT_PATH = Path(__file__).absolute()
+PARENT_APP_DIR = _SCRIPT_PATH.parent
+TEMP_DIR = PARENT_APP_DIR / 'tempDir'
+ROOT_DIR = PARENT_APP_DIR.parent
+STATIC_DIR = ROOT_DIR / 'static'
+
+# _env_file_path = find_dotenv(str(CODE_DIR / '.env'))  # Check if this path is correct
+# if _env_file_path:
+#     load_dotenv(_env_file_path)
+
+ST_CONFIG = {
+    "page_title": "NTT Data - Chat Q&A",
+    # "page_icon": Image.open(STATIC_DIR / "mini_nttdata.jpg"),
+}
+
+OPERATING_MODE = "debug"  # debug, preproduction, production
+
+REUSE_ANSWERS = False
+
+LOAD_INDEX_LOCALLY = False
+SAVE_INDEX_LOCALLY = False
+
+# x$ per 1000 tokens
+PRICES = {
+    'text-embedding-ada-002': 0.0004,
+    'text-davinci-003': 0.02,
+    'gpt-3': 0.002,
+    'gpt-4': 0.06,  # 8K context
+}
+
+SOURCES_IDS = {
+    # "Without source. Only chat": 4,
+    "local files": 1,
+    "urls": 3
+}
+
+TYPE_IDS = {
+    "MSF Azure OpenAI Service": 1,
+    "OpenAI": 2,
+}
+
+
+INDEX_IDS = {
+    "FAISS": 1,
+    "Pinecode": 2,
+}
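Since PRICES is expressed in dollars per 1,000 tokens, a rough cost estimate can be derived from a token count. A minimal sketch, assuming the constants module above is importable as a package; the model name and token count below are illustrative:

from streamlit_langchain_chat.constants import PRICES

def estimate_cost(model_name: str, total_tokens: int) -> float:
    """Approximate dollar cost for `total_tokens` tokens of `model_name`."""
    price_per_1k = PRICES.get(model_name, 0.0)
    return price_per_1k * total_tokens / 1000

# 12,345 tokens of gpt-4 (8K context) at $0.06 per 1K tokens -> ~$0.74
print(f"${estimate_cost('gpt-4', 12_345):.2f}")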
streamlit_langchain_chat/customized_langchain/__init__.py
ADDED
@@ -0,0 +1,10 @@
+from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore
+from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS
+from streamlit_langchain_chat.customized_langchain.vectorstores import Pinecone
+
+
+__all__ = [
+    "FAISS",
+    "InMemoryDocstore",
+    "Pinecone",
+]
streamlit_langchain_chat/customized_langchain/docstore/__init__.py
ADDED
@@ -0,0 +1,7 @@
+"""Wrappers on top of docstores."""
+from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore
+
+
+__all__ = [
+    "InMemoryDocstore",
+]
streamlit_langchain_chat/customized_langchain/docstore/in_memory.py
ADDED
@@ -0,0 +1,27 @@
+"""Simple in memory docstore in the form of a dict."""
+from typing import Dict, Union
+
+from langchain.docstore.base import AddableMixin, Docstore
+from langchain.docstore.document import Document
+
+
+class InMemoryDocstore(Docstore, AddableMixin):
+    """Simple in memory docstore in the form of a dict."""
+
+    def __init__(self, dict_: Dict[str, Document]):
+        """Initialize with dict."""
+        self.dict_ = dict_
+
+    def add(self, texts: Dict[str, Document]) -> None:
+        """Add texts to in memory dictionary."""
+        overlapping = set(texts).intersection(self.dict_)
+        if overlapping:
+            raise ValueError(f"Tried to add ids that already exist: {overlapping}")
+        self.dict_ = dict(self.dict_, **texts)
+
+    def search(self, search: str) -> Union[str, Document]:
+        """Search via direct lookup."""
+        if search not in self.dict_:
+            return f"ID {search} not found."
+        else:
+            return self.dict_[search]
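A minimal usage sketch of this docstore, assuming langchain is installed so the Document class above is importable; the IDs and contents are illustrative:

from langchain.docstore.document import Document

# Build the store with an initial mapping from ID to Document.
store = InMemoryDocstore({"doc-1": Document(page_content="first chunk")})

# Adding a new, non-overlapping ID succeeds; reusing "doc-1" would raise ValueError.
store.add({"doc-2": Document(page_content="second chunk")})

print(store.search("doc-2").page_content)  # -> "second chunk"
print(store.search("missing"))             # -> "ID missing not found."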
streamlit_langchain_chat/customized_langchain/indexes/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from streamlit_langchain_chat.customized_langchain.indexes.graph import GraphIndexCreator
+# from streamlit_langchain_chat.customized_langchain.vectorstore import VectorstoreIndexCreator
+
+__all__ = [
+    "GraphIndexCreator",
+    # "VectorstoreIndexCreator"
+]
streamlit_langchain_chat/customized_langchain/indexes/graph.py
ADDED
@@ -0,0 +1,20 @@
+from typing import List
+
+from langchain.indexes.graph import *
+from langchain.indexes.graph import GraphIndexCreator as OriginalGraphIndexCreator
+
+
+class GraphIndexCreator(OriginalGraphIndexCreator):
+    def from_texts(self, texts: List[str]) -> NetworkxEntityGraph:
+        """Create graph index from text."""
+        if self.llm is None:
+            raise ValueError("llm should not be None")
+        graph = self.graph_type()
+        chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
+
+        for text in texts:
+            output = chain.predict(text=text)
+            knowledge = parse_triples(output)
+            for triple in knowledge:
+                graph.add_triple(triple)
+        return graph
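Compared with the upstream GraphIndexCreator.from_text, this override accepts a list of texts and folds every extracted triple into a single graph. A rough usage sketch, assuming an OpenAI API key is configured and any langchain LLM can be substituted for OpenAI; the texts are illustrative:

from langchain.llms import OpenAI

creator = GraphIndexCreator(llm=OpenAI(temperature=0))
graph = creator.from_texts([
    "Madrid is the capital of Spain.",
    "Spain is a country in Europe.",
])
# Triples extracted from all texts end up in one NetworkxEntityGraph.
print(graph.get_triples())  # e.g. [("Madrid", "Spain", "is the capital of"), ...]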
streamlit_langchain_chat/customized_langchain/llms/__init__.py
ADDED
@@ -0,0 +1 @@
+from streamlit_langchain_chat.customized_langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat, AzureOpenAIChat
streamlit_langchain_chat/customized_langchain/llms/openai.py
ADDED
@@ -0,0 +1,708 @@
+"""Wrapper around OpenAI APIs."""
+from __future__ import annotations
+
+import logging
+import sys
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
+
+from pydantic import BaseModel, Extra, Field, root_validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+from langchain.llms.base import BaseLLM
+from langchain.schema import Generation, LLMResult
+from langchain.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+def update_token_usage(
+    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
+) -> None:
+    """Update token usage."""
+    _keys_to_use = keys.intersection(response["usage"])
+    for _key in _keys_to_use:
+        if _key not in token_usage:
+            token_usage[_key] = response["usage"][_key]
+        else:
+            token_usage[_key] += response["usage"][_key]
+
+
+def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
+    """Update response from the stream response."""
+    response["choices"][0]["text"] += stream_response["choices"][0]["text"]
+    response["choices"][0]["finish_reason"] = stream_response["choices"][0][
+        "finish_reason"
+    ]
+    response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"]
+
+
+def _streaming_response_template() -> Dict[str, Any]:
+    return {
+        "choices": [
+            {
+                "text": "",
+                "finish_reason": None,
+                "logprobs": None,
+            }
+        ]
+    }
+
+
+def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
+    import openai
+
+    min_seconds = 4
+    max_seconds = 10
+    # Wait 2^x * 1 second between each retry starting with
+    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(llm.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(openai.error.Timeout)
+            | retry_if_exception_type(openai.error.APIError)
+            | retry_if_exception_type(openai.error.APIConnectionError)
+            | retry_if_exception_type(openai.error.RateLimitError)
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    def _completion_with_retry(**kwargs: Any) -> Any:
+        return llm.client.create(**kwargs)
+
+    return _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry(
+    llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
+) -> Any:
+    """Use tenacity to retry the async completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    async def _completion_with_retry(**kwargs: Any) -> Any:
+        # Use OpenAI's async api https://github.com/openai/openai-python#async-api
+        return await llm.client.acreate(**kwargs)
+
+    return await _completion_with_retry(**kwargs)
+
+
+class BaseOpenAI(BaseLLM, BaseModel):
+    """Wrapper around OpenAI large language models.
+
+    To use, you should have the ``openai`` python package installed, and the
+    environment variable ``OPENAI_API_KEY`` set with your API key.
+
+    Any parameters that are valid to be passed to the openai.create call can be passed
+    in, even if not explicitly saved on this class.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import OpenAI
+            openai = OpenAI(model_name="text-davinci-003")
+    """
+
+    client: Any  #: :meta private:
+    model_name: str = "text-davinci-003"
+    """Model name to use."""
+    temperature: float = 0.7
+    """What sampling temperature to use."""
+    max_tokens: int = 256
+    """The maximum number of tokens to generate in the completion.
+    -1 returns as many tokens as possible given the prompt and
+    the model's maximal context size."""
+    top_p: float = 1
+    """Total probability mass of tokens to consider at each step."""
+    frequency_penalty: float = 0
+    """Penalizes repeated tokens according to frequency."""
+    presence_penalty: float = 0
+    """Penalizes repeated tokens."""
+    n: int = 1
+    """How many completions to generate for each prompt."""
+    best_of: int = 1
+    """Generates best_of completions server-side and returns the "best"."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `create` call not explicitly specified."""
+    openai_api_key: Optional[str] = None
+    batch_size: int = 20
+    """Batch size to use when passing multiple documents to generate."""
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
+    logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
+    """Adjust the probability of specific tokens being generated."""
+    max_retries: int = 6
+    """Maximum number of retries to make when generating."""
+    streaming: bool = False
+    """Whether to stream the results or not."""
+
+    def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]:  # type: ignore
+        """Initialize the OpenAI object."""
+        if data.get("model_name", "").startswith("gpt-3.5-turbo"):
+            return OpenAIChat(**data)
+        return super().__new__(cls)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.ignore
+
+    @root_validator(pre=True, allow_reuse=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name not in all_required_field_names:
+                if field_name in extra:
+                    raise ValueError(f"Found {field_name} supplied twice.")
+                logger.warning(
+                    f"""WARNING! {field_name} is not default parameter.
+                    {field_name} was transferred to model_kwargs.
+                    Please confirm that {field_name} is what you intended."""
+                )
+                extra[field_name] = values.pop(field_name)
+        values["model_kwargs"] = extra
+        return values
+
+    @root_validator(allow_reuse=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        openai_api_key = get_from_dict_or_env(
+            values, "openai_api_key", "OPENAI_API_KEY"
+        )
+        try:
+            import openai
+
+            openai.api_key = openai_api_key
+            values["client"] = openai.Completion
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Cannot stream results when n > 1.")
+        if values["streaming"] and values.get("best_of") and values["best_of"] > 1:
+            raise ValueError("Cannot stream results when best_of > 1.")
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        normal_params = {
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "n": self.n,
+            # "best_of": self.best_of,
+            "request_timeout": self.request_timeout,
+            "logit_bias": self.logit_bias,
+        }
+        return {**normal_params, **self.model_kwargs}
+
+    def _generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Call out to OpenAI's endpoint with k unique prompts.
+
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The full LLM output.
+
+        Example:
+            .. code-block:: python
+
+                response = openai.generate(["Tell me a joke."])
+        """
+        # TODO: write a unit test for this
+        params = self._invocation_params
+        sub_prompts = self.get_sub_prompts(params, prompts, stop)
+        choices = []
+        token_usage: Dict[str, int] = {}
+        # Get the token usage from the response.
+        # Includes prompt, completion, and total tokens used.
+        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+        for _prompts in sub_prompts:
+            if self.streaming:
+                if len(_prompts) > 1:
+                    raise ValueError("Cannot stream results with multiple prompts.")
+                params["stream"] = True
+                response = _streaming_response_template()
+                for stream_resp in completion_with_retry(
+                    self, prompt=_prompts, **params
+                ):
+                    self.callback_manager.on_llm_new_token(
+                        stream_resp["choices"][0]["text"],
+                        verbose=self.verbose,
+                        logprobs=stream_resp["choices"][0]["logprobs"],
+                    )
+                    _update_response(response, stream_resp)
+                choices.extend(response["choices"])
+            else:
+                response = completion_with_retry(self, prompt=_prompts, **params)
+                choices.extend(response["choices"])
+            if not self.streaming:
+                # Can't update token usage if streaming
+                update_token_usage(_keys, response, token_usage)
+        return self.create_llm_result(choices, prompts, token_usage)
+
+    async def _agenerate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Call out to OpenAI's endpoint async with k unique prompts."""
+        params = self._invocation_params
+        sub_prompts = self.get_sub_prompts(params, prompts, stop)
+        choices = []
+        token_usage: Dict[str, int] = {}
+        # Get the token usage from the response.
+        # Includes prompt, completion, and total tokens used.
+        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+        for _prompts in sub_prompts:
+            if self.streaming:
+                if len(_prompts) > 1:
+                    raise ValueError("Cannot stream results with multiple prompts.")
+                params["stream"] = True
+                response = _streaming_response_template()
+                async for stream_resp in await acompletion_with_retry(
+                    self, prompt=_prompts, **params
+                ):
+                    if self.callback_manager.is_async:
+                        await self.callback_manager.on_llm_new_token(
+                            stream_resp["choices"][0]["text"],
+                            verbose=self.verbose,
+                            logprobs=stream_resp["choices"][0]["logprobs"],
+                        )
+                    else:
+                        self.callback_manager.on_llm_new_token(
+                            stream_resp["choices"][0]["text"],
+                            verbose=self.verbose,
+                            logprobs=stream_resp["choices"][0]["logprobs"],
+                        )
+                    _update_response(response, stream_resp)
+                choices.extend(response["choices"])
+            else:
+                response = await acompletion_with_retry(self, prompt=_prompts, **params)
+                choices.extend(response["choices"])
+            if not self.streaming:
+                # Can't update token usage if streaming
+                update_token_usage(_keys, response, token_usage)
+        return self.create_llm_result(choices, prompts, token_usage)
+
+    def get_sub_prompts(
+        self,
+        params: Dict[str, Any],
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+    ) -> List[List[str]]:
+        """Get the sub prompts for llm call."""
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        if params["max_tokens"] == -1:
+            if len(prompts) != 1:
+                raise ValueError(
+                    "max_tokens set to -1 not supported for multiple inputs."
+                )
+            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
+        sub_prompts = [
+            prompts[i : i + self.batch_size]
+            for i in range(0, len(prompts), self.batch_size)
+        ]
+        return sub_prompts
+
+    def create_llm_result(
+        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
+    ) -> LLMResult:
+        """Create the LLMResult from the choices and prompts."""
+        generations = []
+        for i, _ in enumerate(prompts):
+            sub_choices = choices[i * self.n : (i + 1) * self.n]
+            generations.append(
+                [
+                    Generation(
+                        text=choice["text"],
+                        generation_info=dict(
+                            finish_reason=choice.get("finish_reason"),
+                            logprobs=choice.get("logprobs"),
+                        ),
+                    )
+                    for choice in sub_choices
+                ]
+            )
+        return LLMResult(
+            generations=generations, llm_output={"token_usage": token_usage}
+        )
+
+    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
+        """Call OpenAI with streaming flag and return the resulting generator.
+
+        BETA: this is a beta feature while we figure out the right abstraction.
+        Once that happens, this interface could change.
+
+        Args:
+            prompt: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            A generator representing the stream of tokens from OpenAI.
+
+        Example:
+            .. code-block:: python
+
+                generator = openai.stream("Tell me a joke.")
+                for token in generator:
+                    yield token
+        """
+        params = self.prep_streaming_params(stop)
+        generator = self.client.create(prompt=prompt, **params)
+
+        return generator
+
+    def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        """Prepare the params for streaming."""
+        params = self._invocation_params
+        if params.get('best_of') and params["best_of"] != 1:
+            raise ValueError("OpenAI only supports best_of == 1 for streaming")
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        params["stream"] = True
+        return params
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model."""
+        return self._default_params
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_name": self.model_name}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "openai"
+
+    def get_num_tokens(self, text: str) -> int:
+        """Calculate num tokens with tiktoken package."""
+        # tiktoken NOT supported for Python 3.8 or below
+        if sys.version_info[1] <= 8:
+            return super().get_num_tokens(text)
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate get_num_tokens. "
+                "Please install it with `pip install tiktoken`."
+            )
+        encoder = "gpt2"
+        if self.model_name in ("text-davinci-003", "text-davinci-002"):
+            encoder = "p50k_base"
+        if self.model_name.startswith("code"):
+            encoder = "p50k_base"
+        # create a GPT-3 encoder instance
+        enc = tiktoken.get_encoding(encoder)
+
+        # encode the text using the GPT-3 encoder
+        tokenized_text = enc.encode(text)
+
+        # calculate the number of tokens in the encoded text
+        return len(tokenized_text)
+
+    def modelname_to_contextsize(self, modelname: str) -> int:
+        """Calculate the maximum number of tokens possible to generate for a model.
+
+        text-davinci-003: 4,097 tokens
+        text-curie-001: 2,048 tokens
+        text-babbage-001: 2,048 tokens
+        text-ada-001: 2,048 tokens
+        code-davinci-002: 8,000 tokens
+        code-cushman-001: 2,048 tokens
+
+        Args:
+            modelname: The modelname we want to know the context size for.
+
+        Returns:
+            The maximum context size
+
+        Example:
+            .. code-block:: python
+
+                max_tokens = openai.modelname_to_contextsize("text-davinci-003")
+        """
+        if modelname == "text-davinci-003":
+            return 4097
+        elif modelname == "text-curie-001":
+            return 2048
+        elif modelname == "text-babbage-001":
+            return 2048
+        elif modelname == "text-ada-001":
+            return 2048
+        elif modelname == "code-davinci-002":
+            return 8000
+        elif modelname == "code-cushman-001":
+            return 2048
+        else:
+            return 4097
+
+    def max_tokens_for_prompt(self, prompt: str) -> int:
+        """Calculate the maximum number of tokens possible to generate for a prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+
+        Returns:
+            The maximum number of tokens to generate for a prompt.
+
+        Example:
+            .. code-block:: python
+
+                max_tokens = openai.max_token_for_prompt("Tell me a joke.")
+        """
+        num_tokens = self.get_num_tokens(prompt)
+
+        # get max context size for model by name
+        max_size = self.modelname_to_contextsize(self.model_name)
+        return max_size - num_tokens
+
+
+class OpenAI(BaseOpenAI):
+    """Generic OpenAI class that uses model name."""
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        return {**{"model": self.model_name}, **super()._invocation_params}
+
+
+class AzureOpenAI(BaseOpenAI):
+    """Azure specific OpenAI class that uses deployment name."""
+
+    deployment_name: str = ""
+    """Deployment name to use."""
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {
+            **{"deployment_name": self.deployment_name},
+            **super()._identifying_params,
+        }
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        return {**{"engine": self.deployment_name}, **super()._invocation_params}
+
+
+class OpenAIChat(BaseLLM, BaseModel):
+    """Wrapper around OpenAI Chat large language models.
+
+    To use, you should have the ``openai`` python package installed, and the
+    environment variable ``OPENAI_API_KEY`` set with your API key.
+
+    Any parameters that are valid to be passed to the openai.create call can be passed
+    in, even if not explicitly saved on this class.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import OpenAIChat
+            openaichat = OpenAIChat(model_name="gpt-3.5-turbo")
+    """
+
+    client: Any  #: :meta private:
+    model_name: str = "gpt-3.5-turbo"
+    """Model name to use."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `create` call not explicitly specified."""
+    openai_api_key: Optional[str] = None
+    max_retries: int = 6
+    """Maximum number of retries to make when generating."""
+    prefix_messages: List = Field(default_factory=list)
+    """Series of messages for Chat input."""
+    streaming: bool = False
+    """Whether to stream the results or not."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.ignore
+
+    @root_validator(pre=True, allow_reuse=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name not in all_required_field_names:
+                if field_name in extra:
+                    raise ValueError(f"Found {field_name} supplied twice.")
+                extra[field_name] = values.pop(field_name)
+        values["model_kwargs"] = extra
+        return values
+
+    @root_validator(allow_reuse=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        openai_api_key = get_from_dict_or_env(
+            values, "openai_api_key", "OPENAI_API_KEY"
+        )
+        try:
+            import openai
+
+            openai.api_key = openai_api_key
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        try:
+            values["client"] = openai.ChatCompletion
+        except AttributeError:
+            raise ValueError(
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
+            )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return self.model_kwargs
+
+    def _get_chat_params(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> Tuple:
+        if len(prompts) > 1:
+            raise ValueError(
+                f"OpenAIChat currently only supports single prompt, got {prompts}"
+            )
+        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
+        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        return messages, params
+
+    def _generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        messages, params = self._get_chat_params(prompts, stop)
+        if self.streaming:
+            response = ""
+            params["stream"] = True
+            for stream_resp in completion_with_retry(self, messages=messages, **params):
+                token = stream_resp["choices"][0]["delta"].get("content", "")
+                response += token
+                self.callback_manager.on_llm_new_token(
+                    token,
+                    verbose=self.verbose,
+                )
+            return LLMResult(
+                generations=[[Generation(text=response)]],
+            )
+        else:
+            full_response = completion_with_retry(self, messages=messages, **params)
+            return LLMResult(
+                generations=[
+                    [Generation(text=full_response["choices"][0]["message"]["content"])]
+                ],
+                llm_output={"token_usage": full_response["usage"]},
+            )
+
+    async def _agenerate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        messages, params = self._get_chat_params(prompts, stop)
+        if self.streaming:
+            response = ""
+            params["stream"] = True
+            async for stream_resp in await acompletion_with_retry(
+                self, messages=messages, **params
+            ):
+                token = stream_resp["choices"][0]["delta"].get("content", "")
+                response += token
+                if self.callback_manager.is_async:
+                    await self.callback_manager.on_llm_new_token(
+                        token,
+                        verbose=self.verbose,
+                    )
+                else:
+                    self.callback_manager.on_llm_new_token(
+                        token,
+                        verbose=self.verbose,
+                    )
+            return LLMResult(
+                generations=[[Generation(text=response)]],
+            )
+        else:
+            full_response = await acompletion_with_retry(
+                self, messages=messages, **params
+            )
+            return LLMResult(
+                generations=[
+                    [Generation(text=full_response["choices"][0]["message"]["content"])]
+                ],
+                llm_output={"token_usage": full_response["usage"]},
+            )
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_name": self.model_name}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "openai-chat"
+
+
+class AzureOpenAIChat(OpenAIChat):
+    """Azure specific OpenAI class that uses deployment name."""
+
+    deployment_name: str = ""
+    """Deployment name to use."""
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {
+            **{"deployment_name": self.deployment_name},
+            **super()._identifying_params,
+        }
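The practical difference between the completion subclasses is how the request is addressed: `OpenAI._invocation_params` adds `model=model_name`, while `AzureOpenAI._invocation_params` adds `engine=deployment_name` (the chat-oriented `AzureOpenAIChat` as written only records the deployment name in its identifying params). A rough instantiation sketch, assuming valid credentials; the resource URL, deployment name, and API version below are illustrative, not values from this repository:

import openai

# Public OpenAI endpoint: requests carry model=<model_name>.
davinci = OpenAI(model_name="text-davinci-003", openai_api_key="sk-...")

# Azure OpenAI Service: point the SDK at the Azure endpoint first;
# requests then carry engine=<deployment_name>.
openai.api_type = "azure"
openai.api_base = "https://my-resource.openai.azure.com/"  # illustrative
openai.api_version = "2022-12-01"                          # illustrative
azure_davinci = AzureOpenAI(deployment_name="my-davinci-deployment",
                            model_name="text-davinci-003",
                            openai_api_key="...")

print(davinci("Say hello."))        # BaseLLM.__call__ -> _generate -> openai.Completion.create
print(azure_davinci("Say hello."))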
streamlit_langchain_chat/customized_langchain/vectorstores/__init__.py
ADDED
@@ -0,0 +1,8 @@
+"""Wrappers on top of vector stores."""
+from streamlit_langchain_chat.customized_langchain.vectorstores.faiss import FAISS
+from streamlit_langchain_chat.customized_langchain.vectorstores.pinecone import Pinecone
+
+__all__ = [
+    "FAISS",
+    "Pinecone",
+]
streamlit_langchain_chat/customized_langchain/vectorstores/faiss.py
ADDED
@@ -0,0 +1,100 @@
+# import hashlib
+
+from langchain.vectorstores.faiss import *
+from langchain.vectorstores.faiss import FAISS as OriginalFAISS
+
+from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore
+
+
+class FAISS(OriginalFAISS):
+    def __add(
+        self,
+        texts: Iterable[str],
+        embeddings: Iterable[List[float]],
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        if not isinstance(self.docstore, AddableMixin):
+            raise ValueError(
+                "If trying to add texts, the underlying docstore should support "
+                f"adding items, which {self.docstore} does not"
+            )
+        documents = []
+        for i, text in enumerate(texts):
+            metadata = metadatas[i] if metadatas else {}
+            documents.append(Document(page_content=text, metadata=metadata))
+        # Add to the index, the index_to_id mapping, and the docstore.
+        starting_len = len(self.index_to_docstore_id)
+        self.index.add(np.array(embeddings, dtype=np.float32))
+        # Get list of index, id, and docs.
+        full_info = [
+            (starting_len + i, str(uuid.uuid4()), doc)
+            for i, doc in enumerate(documents)
+        ]
+        # Add information to docstore and index.
+        self.docstore.add({_id: doc for _, _id, doc in full_info})
+        index_to_id = {index: _id for index, _id, _ in full_info}
+        self.index_to_docstore_id.update(index_to_id)
+        return [_id for _, _id, _ in full_info]
+
+    @classmethod
+    def __from(
+        cls,
+        texts: List[str],
+        embeddings: List[List[float]],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> FAISS:
+        faiss = dependable_faiss_import()
+        index = faiss.IndexFlatL2(len(embeddings[0]))
+        index.add(np.array(embeddings, dtype=np.float32))
+        documents = []
+        for i, text in enumerate(texts):
+            metadata = metadatas[i] if metadatas else {}
+            documents.append(Document(page_content=text, metadata=metadata))
+        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
+
+        # # TODO: switch to using the hash, and decide where to place it so the chunk is not loaded into the dataset
+        # index_to_id_2 = dict()
+        # for i in range(len(documents)):
+        #     h = hashlib.new('sha256')
+        #     text_ = documents[i].page_content
+        #     h.update(text_.encode())
+        #     index_to_id_2[i] = str(h.hexdigest())
+        # #
+        docstore = InMemoryDocstore(
+            {index_to_id[i]: doc for i, doc in enumerate(documents)}
+        )
+        return cls(embedding.embed_query, index, docstore, index_to_id)
+
+    @classmethod
+    def from_texts(
+        cls,
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> FAISS:
+        """Construct FAISS wrapper from raw documents.
+
+        This is a user friendly interface that:
+            1. Embeds documents.
+            2. Creates an in memory docstore
+            3. Initializes the FAISS database
+
+        This is intended to be a quick way to get started.
+
+        Example:
+            .. code-block:: python
+
+                from langchain import FAISS
+                from langchain.embeddings import OpenAIEmbeddings
+                embeddings = OpenAIEmbeddings()
+                faiss = FAISS.from_texts(texts, embeddings)
+        """
+        # embeddings = embedding.embed_documents(texts)
+        print(f"len(texts): {len(texts)}")  # TODO: remove
+        embeddings = [embedding.embed_documents([text])[0] for text in texts]
+        print(f"len(embeddings): {len(embeddings)}")  # TODO: remove
+        return cls.__from(texts, embeddings, embedding, metadatas, **kwargs)
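This override differs from the upstream FAISS.from_texts mainly in that it embeds texts one call at a time (sidestepping embedding-batch limits) and backs the store with the customized InMemoryDocstore. A minimal usage sketch, assuming faiss-cpu, langchain, and an OpenAI key are available; the texts, metadata, and query are illustrative:

from langchain.embeddings.openai import OpenAIEmbeddings

texts = ["FAISS keeps vectors in local memory.", "Pinecone is a managed vector database."]
store = FAISS.from_texts(
    texts,
    OpenAIEmbeddings(),
    metadatas=[{"source": "note-1"}, {"source": "note-2"}],
)

# Inherited from the upstream FAISS wrapper: nearest-neighbour lookup over the chunks.
docs = store.similarity_search("Which option is a managed service?", k=1)
print(docs[0].page_content, docs[0].metadata)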
streamlit_langchain_chat/customized_langchain/vectorstores/pinecone.py
ADDED
@@ -0,0 +1,79 @@
+from langchain.vectorstores.pinecone import *
+from langchain.vectorstores.pinecone import Pinecone as OriginalPinecone
+
+
+class Pinecone(OriginalPinecone):
+    @classmethod
+    def from_texts(
+        cls,
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        batch_size: int = 32,
+        text_key: str = "text",
+        index_name: Optional[str] = None,
+        namespace: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Pinecone:
+        """Construct Pinecone wrapper from raw documents.
+
+        This is a user friendly interface that:
+            1. Embeds documents.
+            2. Adds the documents to a provided Pinecone index
+
+        This is intended to be a quick way to get started.
+
+        Example:
+            .. code-block:: python
+
+                from langchain import Pinecone
+                from langchain.embeddings import OpenAIEmbeddings
+                embeddings = OpenAIEmbeddings()
+                pinecone = Pinecone.from_texts(
+                    texts,
+                    embeddings,
+                    index_name="langchain-demo"
+                )
+        """
+        try:
+            import pinecone
+        except ImportError:
+            raise ValueError(
+                "Could not import pinecone python package. "
+                "Please install it with `pip install pinecone-client`."
+            )
+        _index_name = index_name or str(uuid.uuid4())
+        indexes = pinecone.list_indexes()  # checks if provided index exists
+        if _index_name in indexes:
+            index = pinecone.Index(_index_name)
+        else:
+            index = None
+        for i in range(0, len(texts), batch_size):
+            # set end position of batch
+            i_end = min(i + batch_size, len(texts))
+            # get batch of texts and ids
+            lines_batch = texts[i:i_end]
+            # create ids if not provided
+            if ids:
+                ids_batch = ids[i:i_end]
+            else:
+                ids_batch = [str(uuid.uuid4()) for n in range(i, i_end)]
+            # create embeddings
+            # embeds = embedding.embed_documents(lines_batch)
+            embeds = [embedding.embed_documents([line_batch])[0] for line_batch in lines_batch]
+            # prep metadata and upsert batch
+            if metadatas:
+                metadata = metadatas[i:i_end]
+            else:
+                metadata = [{} for _ in range(i, i_end)]
+            for j, line in enumerate(lines_batch):
+                metadata[j][text_key] = line
+            to_upsert = zip(ids_batch, embeds, metadata)
+            # Create index if it does not exist
+            if index is None:
+                pinecone.create_index(_index_name, dimension=len(embeds[0]))
+                index = pinecone.Index(_index_name)
+            # upsert to Pinecone
+            index.upsert(vectors=list(to_upsert), namespace=namespace)
+        return cls(index, embedding.embed_query, text_key, namespace)
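Relative to the upstream implementation, this override creates the Pinecone index on the fly (dimensioned from the first embedding batch) when the given index_name does not exist yet, and it also embeds one text per call. A rough usage sketch, assuming pinecone-client and an OpenAI key are configured; the API key, environment, and index name are illustrative:

import pinecone
from langchain.embeddings.openai import OpenAIEmbeddings

pinecone.init(api_key="...", environment="us-west1-gcp")  # illustrative credentials

store = Pinecone.from_texts(
    ["chunk one", "chunk two"],
    OpenAIEmbeddings(),
    index_name="streamlit-chat-demo",  # created automatically if it does not exist yet
)
print(store.similarity_search("one", k=1))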
streamlit_langchain_chat/dataset.py
ADDED
|
@@ -0,0 +1,740 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from functools import reduce
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import re
|
| 9 |
+
import requests
|
| 10 |
+
from requests.models import MissingSchema
|
| 11 |
+
import sys
|
| 12 |
+
from typing import List, Optional, Tuple, Dict, Callable, Any
|
| 13 |
+
|
| 14 |
+
from bs4 import BeautifulSoup
|
| 15 |
+
import docx
|
| 16 |
+
from html2text import html2text
|
| 17 |
+
import langchain
|
| 18 |
+
from langchain.callbacks import get_openai_callback
|
| 19 |
+
from langchain.cache import SQLiteCache
|
| 20 |
+
from langchain.chains import LLMChain
|
| 21 |
+
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
|
| 22 |
+
from langchain.chat_models import ChatOpenAI
|
| 23 |
+
from langchain.chat_models.base import BaseChatModel
|
| 24 |
+
from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
|
| 25 |
+
from langchain.embeddings.base import Embeddings
|
| 26 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
| 27 |
+
from langchain.llms import OpenAI
|
| 28 |
+
from langchain.llms.base import LLM, BaseLLM
|
| 29 |
+
from langchain.prompts.chat import AIMessagePromptTemplate
|
| 30 |
+
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
|
| 31 |
+
from langchain.vectorstores import Pinecone as OriginalPinecone
|
| 32 |
+
import numpy as np
|
| 33 |
+
import openai
|
| 34 |
+
import pinecone
|
| 35 |
+
from pptx import Presentation
|
| 36 |
+
from pypdf import PdfReader
|
| 37 |
+
import trafilatura
|
| 38 |
+
|
| 39 |
+
from streamlit_langchain_chat.constants import *
|
| 40 |
+
from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS
|
| 41 |
+
from streamlit_langchain_chat.customized_langchain.vectorstores import Pinecone
|
| 42 |
+
from streamlit_langchain_chat.utils import maybe_is_text, maybe_is_truncated
|
| 43 |
+
from streamlit_langchain_chat.prompts import *
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if REUSE_ANSWERS:
|
| 47 |
+
CACHE_PATH = TEMP_DIR / "llm_cache.db"
|
| 48 |
+
os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
|
| 49 |
+
langchain.llm_cache = SQLiteCache(str(CACHE_PATH))
|
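Side note on the cache wired up above: SQLiteCache memoizes LLM calls by prompt and model parameters, so repeating an identical request is served from llm_cache.db instead of hitting the API again. A minimal sketch of the effect, assuming a valid OPENAI_API_KEY and the langchain version pinned by this repo (the file name here is illustrative):

    import langchain
    from langchain.cache import SQLiteCache
    from langchain.llms import OpenAI

    langchain.llm_cache = SQLiteCache("llm_cache.db")  # same mechanism as the REUSE_ANSWERS block above
    llm = OpenAI(temperature=0)
    first = llm("What is 2 + 2?")   # goes to the API, result stored in llm_cache.db
    second = llm("What is 2 + 2?")  # identical prompt and params: answered from the cache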
| 50 |
+
|
| 51 |
+
# option 1
|
| 52 |
+
TextSplitter = TokenTextSplitter
|
| 53 |
+
# option 2
|
| 54 |
+
# TextSplitter = RecursiveCharacterTextSplitter  # used by gpt4_pdf_chatbot_langchain (aka GPCL)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
|
| 58 |
+
class Answer:
|
| 59 |
+
"""A class to hold the answer to a question."""
|
| 60 |
+
question: str = ""
|
| 61 |
+
answer: str = ""
|
| 62 |
+
context: str = ""
|
| 63 |
+
chunks: str = ""
|
| 64 |
+
packages: List[Any] = None
|
| 65 |
+
references: str = ""
|
| 66 |
+
cost_str: str = ""
|
| 67 |
+
passages: Dict[str, str] = None
|
| 68 |
+
tokens: List[Dict] = None
|
| 69 |
+
|
| 70 |
+
def __post_init__(self):
|
| 71 |
+
"""Initialize the answer."""
|
| 72 |
+
if self.packages is None:
|
| 73 |
+
self.packages = []
|
| 74 |
+
if self.passages is None:
|
| 75 |
+
self.passages = {}
|
| 76 |
+
|
| 77 |
+
def __str__(self) -> str:
|
| 78 |
+
"""Return the answer as a string."""
|
| 79 |
+
return self.answer
|
| 80 |
+
|
| 81 |
+
|
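For orientation, a small sketch of how the Answer container above gets filled during a query (all field values here are made up):

    from streamlit_langchain_chat.dataset import Answer

    a = Answer(question="How many days of annual leave do staff receive?")
    # get_evidence() appends (key, citation, summary, chunk) tuples as evidence is found
    a.packages.append(("Foo2023", "Foo. 'Leave Policy.' 2023.", "a short summary ...", "the raw chunk ..."))
    a.answer = "Staff receive 20 days of annual leave (Foo2023)."
    print(a)  # __str__ returns just the answer text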
| 82 |
+
def parse_docx(path, citation, key, chunk_chars=2000, overlap=50):
|
| 83 |
+
try:
|
| 84 |
+
document = docx.Document(path)
|
| 85 |
+
fullText = []
|
| 86 |
+
for paragraph in document.paragraphs:
|
| 87 |
+
fullText.append(paragraph.text)
|
| 88 |
+
doc = '\n'.join(fullText) + '\n'
|
| 89 |
+
except Exception as e:
|
| 90 |
+
print(f"code_error: {e}")
|
| 91 |
+
sys.exit(1)
|
| 92 |
+
|
| 93 |
+
if doc:
|
| 94 |
+
text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
| 95 |
+
texts = text_splitter.split_text(doc)
|
| 96 |
+
return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
|
| 97 |
+
else:
|
| 98 |
+
return [], []
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# TODO: with a connector in the form loader = ...; data = loader.load(),
|
| 102 |
+
# you could plug in every langchain document loader
|
| 103 |
+
# https://langchain.readthedocs.io/en/stable/modules/document_loaders/examples/pdf.html
|
| 104 |
+
def parse_pdf(path, citation, key, chunk_chars=2000, overlap=50):
|
| 105 |
+
pdfFileObj = open(path, "rb")
|
| 106 |
+
pdfReader = PdfReader(pdfFileObj)
|
| 107 |
+
splits = []
|
| 108 |
+
split = ""
|
| 109 |
+
pages = []
|
| 110 |
+
metadatas = []
|
| 111 |
+
for i, page in enumerate(pdfReader.pages):
|
| 112 |
+
split += page.extract_text()
|
| 113 |
+
pages.append(str(i + 1))
|
| 114 |
+
# split could be so long it needs to be split
|
| 115 |
+
# into multiple chunks. Or it could be so short
|
| 116 |
+
# that it needs to be combined with the next chunk.
|
| 117 |
+
while len(split) > chunk_chars:
|
| 118 |
+
splits.append(split[:chunk_chars])
|
| 119 |
+
# pretty formatting of pages (e.g. 1-3, 4, 5-7)
|
| 120 |
+
pg = "-".join([pages[0], pages[-1]])
|
| 121 |
+
metadatas.append(
|
| 122 |
+
dict(
|
| 123 |
+
citation=citation,
|
| 124 |
+
dockey=key,
|
| 125 |
+
key=f"{key} pages {pg}",
|
| 126 |
+
)
|
| 127 |
+
)
|
| 128 |
+
split = split[chunk_chars - overlap:]
|
| 129 |
+
pages = [str(i + 1)]
|
| 130 |
+
if len(split) > overlap:
|
| 131 |
+
splits.append(split[:chunk_chars])
|
| 132 |
+
pg = "-".join([pages[0], pages[-1]])
|
| 133 |
+
metadatas.append(
|
| 134 |
+
dict(
|
| 135 |
+
citation=citation,
|
| 136 |
+
dockey=key,
|
| 137 |
+
key=f"{key} pages {pg}",
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
pdfFileObj.close()
|
| 141 |
+
|
| 142 |
+
# # ### option 2. PyPDFLoader
|
| 143 |
+
# loader = PyPDFLoader(path)
|
| 144 |
+
# data = loader.load_and_split()
|
| 145 |
+
# # ### option 2.1. PyPDFLoader used by GPCL, although it then uses the
|
| 146 |
+
# loader = PyPDFLoader(path)
|
| 147 |
+
# rawDocs = loader.load()
|
| 148 |
+
# text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
| 149 |
+
# texts = text_splitter.split_documents(rawDocs)
|
| 150 |
+
# # ### option 3. PDFMiner. This looks like the best option
|
| 151 |
+
# loader = PyMuPDFLoader(path)
|
| 152 |
+
# data = loader.load()
|
| 153 |
+
return splits, metadatas
|
| 154 |
+
|
| 155 |
+
|
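The sliding window in parse_pdf advances by chunk_chars - overlap characters, so consecutive chunks share overlap characters of text. A quick self-contained check of that arithmetic with illustrative numbers:

    chunk_chars, overlap = 2000, 50
    text = "x" * 4100
    chunks, buf = [], text
    while len(buf) > chunk_chars:
        chunks.append(buf[:chunk_chars])
        buf = buf[chunk_chars - overlap:]
    if len(buf) > overlap:
        chunks.append(buf[:chunk_chars])
    print([len(c) for c in chunks])  # [2000, 2000, 200]: spans 0-2000, 1950-3950, 3900-4100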
| 156 |
+
def parse_pptx(path, citation, key, chunk_chars=2000, overlap=50):
|
| 157 |
+
try:
|
| 158 |
+
presentation = Presentation(path)
|
| 159 |
+
fullText = []
|
| 160 |
+
for slide in presentation.slides:
|
| 161 |
+
for shape in slide.shapes:
|
| 162 |
+
if hasattr(shape, "text"):
|
| 163 |
+
fullText.append(shape.text)
|
| 164 |
+
doc = '\n'.join(fullText)
|
| 165 |
+
|
| 166 |
+
if doc:
|
| 167 |
+
text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
| 168 |
+
texts = text_splitter.split_text(doc)
|
| 169 |
+
return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
|
| 170 |
+
else:
|
| 171 |
+
return [], []
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
print(f"code_error: {e}")
|
| 175 |
+
sys.exit(1)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def parse_txt(path, citation, key, chunk_chars=2000, overlap=50, html=False):
|
| 179 |
+
try:
|
| 180 |
+
with open(path) as f:
|
| 181 |
+
doc = f.read()
|
| 182 |
+
except UnicodeDecodeError as e:
|
| 183 |
+
with open(path, encoding="utf-8", errors="ignore") as f:
|
| 184 |
+
doc = f.read()
|
| 185 |
+
if html:
|
| 186 |
+
doc = html2text(doc)
|
| 187 |
+
# yo, no idea why but the texts are not split correctly
|
| 188 |
+
text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
| 189 |
+
texts = text_splitter.split_text(doc)
|
| 190 |
+
return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def parse_url(url: str, citation, key, chunk_chars=2000, overlap=50):
|
| 194 |
+
def beautifulsoup_extract_text_fallback(response_content):
|
| 195 |
+
"""
|
| 196 |
+
This is a fallback function, so that we can always return a value for text content.
|
| 197 |
+
Even for when both Trafilatura and BeautifulSoup are unable to extract the text from a
|
| 198 |
+
single URL.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
# Create the beautifulsoup object:
|
| 202 |
+
soup = BeautifulSoup(response_content, 'html.parser')
|
| 203 |
+
|
| 204 |
+
# Finding the text:
|
| 205 |
+
text = soup.find_all(text=True)
|
| 206 |
+
|
| 207 |
+
# Remove unwanted tag elements:
|
| 208 |
+
cleaned_text = ''
|
| 209 |
+
blacklist = [
|
| 210 |
+
'[document]',
|
| 211 |
+
'noscript',
|
| 212 |
+
'header',
|
| 213 |
+
'html',
|
| 214 |
+
'meta',
|
| 215 |
+
'head',
|
| 216 |
+
'input',
|
| 217 |
+
'script',
|
| 218 |
+
'style', ]
|
| 219 |
+
|
| 220 |
+
# Then we will loop over every item in the extract text and make sure that the beautifulsoup4 tag
|
| 221 |
+
# is NOT in the blacklist
|
| 222 |
+
for item in text:
|
| 223 |
+
if item.parent.name not in blacklist:
|
| 224 |
+
cleaned_text += f'{item} ' # cleaned_text += '{} '.format(item)
|
| 225 |
+
|
| 226 |
+
# Remove any tab separation and strip the text:
|
| 227 |
+
cleaned_text = cleaned_text.replace('\t', '')
|
| 228 |
+
return cleaned_text.strip()
|
| 229 |
+
|
| 230 |
+
def extract_text_from_single_web_page(url):
|
| 231 |
+
print(f"\n===========\n{url=}\n===========\n")
|
| 232 |
+
downloaded_url = trafilatura.fetch_url(url)
|
| 233 |
+
a = None
|
| 234 |
+
try:
|
| 235 |
+
a = trafilatura.extract(downloaded_url,
|
| 236 |
+
output_format='json',
|
| 237 |
+
with_metadata=True,
|
| 238 |
+
include_comments=False,
|
| 239 |
+
date_extraction_params={'extensive_search': True,
|
| 240 |
+
'original_date': True})
|
| 241 |
+
except AttributeError:
|
| 242 |
+
a = trafilatura.extract(downloaded_url,
|
| 243 |
+
output_format='json',
|
| 244 |
+
with_metadata=True,
|
| 245 |
+
date_extraction_params={'extensive_search': True,
|
| 246 |
+
'original_date': True})
|
| 247 |
+
except Exception as e:
|
| 248 |
+
print(f"code_error: {e}")
|
| 249 |
+
|
| 250 |
+
if a:
|
| 251 |
+
json_output = json.loads(a)
|
| 252 |
+
return json_output['text']
|
| 253 |
+
else:
|
| 254 |
+
try:
|
| 255 |
+
headers = {'User-Agent': 'Chrome/83.0.4103.106'}
|
| 256 |
+
resp = requests.get(url, headers=headers)
|
| 257 |
+
print(f"{resp=}\n")
|
| 258 |
+
# We will only extract the text from successful requests:
|
| 259 |
+
if resp.status_code == 200:
|
| 260 |
+
return beautifulsoup_extract_text_fallback(resp.content)
|
| 261 |
+
else:
|
| 262 |
+
# This line will handle for any failures in both the Trafilature and BeautifulSoup4 functions:
|
| 263 |
+
return np.nan
|
| 264 |
+
# Handling for any URLs that don't have the correct protocol
|
| 265 |
+
except MissingSchema:
|
| 266 |
+
return np.nan
|
| 267 |
+
|
| 268 |
+
text_to_split = extract_text_from_single_web_page(url)
|
| 269 |
+
text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
|
| 270 |
+
texts = text_splitter.split_text(text_to_split)
|
| 271 |
+
return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def read_source(path: str = None,
|
| 275 |
+
citation: str = None,
|
| 276 |
+
key: str = None,
|
| 277 |
+
chunk_chars: int = 3000,
|
| 278 |
+
overlap: int = 100,
|
| 279 |
+
disable_check: bool = False):
|
| 280 |
+
if path.endswith(".pdf"):
|
| 281 |
+
return parse_pdf(path, citation, key, chunk_chars, overlap)
|
| 282 |
+
elif path.endswith(".txt"):
|
| 283 |
+
return parse_txt(path, citation, key, chunk_chars, overlap)
|
| 284 |
+
elif path.endswith(".html"):
|
| 285 |
+
return parse_txt(path, citation, key, chunk_chars, overlap, html=True)
|
| 286 |
+
elif path.endswith(".docx"):
|
| 287 |
+
return parse_docx(path, citation, key, chunk_chars, overlap)
|
| 288 |
+
elif path.endswith(".pptx"):
|
| 289 |
+
return parse_pptx(path, citation, key, chunk_chars, overlap)
|
| 290 |
+
elif path.startswith("http://") or path.startswith("https://"):
|
| 291 |
+
return parse_url(path, citation, key, chunk_chars, overlap)
|
| 292 |
+
# TODO: add more connectors
|
| 293 |
+
# else:
|
| 294 |
+
# return parse_code_txt(path, citation, key, chunk_chars, overlap)
|
| 295 |
+
else:
|
| 296 |
+
raise ValueError(f"unknown extension: {path}")
|
| 297 |
+
|
| 298 |
+
|
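read_source is the single dispatch point used by Dataset.add; a hedged usage sketch (the file name and citation are placeholders):

    from streamlit_langchain_chat.dataset import read_source

    texts, metadatas = read_source(
        path="policy.pdf",  # .pdf, .txt, .html, .docx, .pptx, or an http(s) URL
        citation="University policy handbook, 2023.",
        key="Policy2023",
        chunk_chars=3000,
        overlap=100,
    )
    # texts: list of chunk strings; metadatas: one dict(citation=..., dockey=..., key=...) per chunk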
| 299 |
+
class Dataset:
|
| 300 |
+
"""A collection of documents to be used for answering questions."""
|
| 301 |
+
def __init__(
|
| 302 |
+
self,
|
| 303 |
+
chunk_size_limit: int = 3000,
|
| 304 |
+
llm: Optional[BaseLLM] | Optional[BaseChatModel] = None,
|
| 305 |
+
summary_llm: Optional[BaseLLM] = None,
|
| 306 |
+
name: str = "default",
|
| 307 |
+
index_path: Optional[Path] = None,
|
| 308 |
+
) -> None:
|
| 309 |
+
"""Initialize the collection of documents.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
chunk_size_limit: The maximum number of characters to use for a single chunk of text.
|
| 313 |
+
llm: The language model to use for answering questions. Default - OpenAI gpt-3.5-turbo chat model
|
| 314 |
+
summary_llm: The language model to use for summarizing documents. If None, llm is used.
|
| 315 |
+
name: The name of the collection.
|
| 316 |
+
index_path: The path to the index file IF pickled. If None, defaults to TEMP_DIR / name
|
| 317 |
+
"""
|
| 318 |
+
self.docs = dict()
|
| 319 |
+
self.keys = set()
|
| 320 |
+
self.chunk_size_limit = chunk_size_limit
|
| 321 |
+
|
| 322 |
+
self.index_docstore = None
|
| 323 |
+
|
| 324 |
+
if llm is None:
|
| 325 |
+
llm = ChatOpenAI(temperature=0.1, max_tokens=512)
|
| 326 |
+
if summary_llm is None:
|
| 327 |
+
summary_llm = llm
|
| 328 |
+
self.update_llm(llm, summary_llm)
|
| 329 |
+
|
| 330 |
+
if index_path is None:
|
| 331 |
+
index_path = TEMP_DIR / name
|
| 332 |
+
self.index_path = index_path
|
| 333 |
+
self.name = name
|
| 334 |
+
|
| 335 |
+
def update_llm(self, llm: BaseLLM | ChatOpenAI, summary_llm: Optional[BaseLLM] = None) -> None:
|
| 336 |
+
"""Update the LLM for answering questions."""
|
| 337 |
+
self.llm = llm
|
| 338 |
+
if summary_llm is None:
|
| 339 |
+
summary_llm = llm
|
| 340 |
+
self.summary_llm = summary_llm
|
| 341 |
+
self.summary_chain = LLMChain(prompt=chat_summary_prompt, llm=summary_llm)
|
| 342 |
+
self.search_chain = LLMChain(prompt=search_prompt, llm=llm)
|
| 343 |
+
self.cite_chain = LLMChain(prompt=citation_prompt, llm=llm)
|
| 344 |
+
|
| 345 |
+
def add(
|
| 346 |
+
self,
|
| 347 |
+
path: str,
|
| 348 |
+
citation: Optional[str] = None,
|
| 349 |
+
key: Optional[str] = None,
|
| 350 |
+
disable_check: bool = False,
|
| 351 |
+
chunk_chars: Optional[int] = 3000,
|
| 352 |
+
) -> None:
|
| 353 |
+
"""Add a document to the collection."""
|
| 354 |
+
|
| 355 |
+
if path in self.docs:
|
| 356 |
+
print(f"Document {path} already in collection.")
|
| 357 |
+
return None
|
| 358 |
+
|
| 359 |
+
if citation is None:
|
| 360 |
+
# peek at the first chunk
|
| 361 |
+
texts, _ = read_source(path, "", "", chunk_chars=chunk_chars)
|
| 362 |
+
with get_openai_callback() as cb:
|
| 363 |
+
citation = self.cite_chain.run(texts[0])
|
| 364 |
+
if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
|
| 365 |
+
citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
|
| 366 |
+
|
| 367 |
+
if key is None:
|
| 368 |
+
# get first name and year from citation
|
| 369 |
+
try:
|
| 370 |
+
author = re.search(r"([A-Z][a-z]+)", citation).group(1)
|
| 371 |
+
except AttributeError:
|
| 372 |
+
# panicking - no word??
|
| 373 |
+
raise ValueError(
|
| 374 |
+
f"Could not parse key from citation {citation}. Consider passing the key explicitly - e.g. docs.add(path, citation, key='mykey')"
|
| 375 |
+
)
|
| 376 |
+
try:
|
| 377 |
+
year = re.search(r"(\d{4})", citation).group(1)
|
| 378 |
+
except AttributeError:
|
| 379 |
+
year = ""
|
| 380 |
+
key = f"{author}{year}"
|
| 381 |
+
suffix = ""
|
| 382 |
+
while key + suffix in self.keys:
|
| 383 |
+
# move suffix to next letter
|
| 384 |
+
if suffix == "":
|
| 385 |
+
suffix = "a"
|
| 386 |
+
else:
|
| 387 |
+
suffix = chr(ord(suffix) + 1)
|
| 388 |
+
key += suffix
|
| 389 |
+
self.keys.add(key)
|
| 390 |
+
|
| 391 |
+
texts, metadata = read_source(path, citation, key, chunk_chars=chunk_chars)
|
| 392 |
+
# loose check to see if document was loaded
|
| 393 |
+
#
|
| 394 |
+
if len("".join(texts)) < 10 or (
|
| 395 |
+
not disable_check and not maybe_is_text("".join(texts))
|
| 396 |
+
):
|
| 397 |
+
raise ValueError(
|
| 398 |
+
f"This does not look like a text document: {path}. Pass disable_check to ignore this error."
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
self.docs[path] = dict(texts=texts, metadata=metadata, key=key)
|
| 402 |
+
if self.index_docstore is not None:
|
| 403 |
+
self.index_docstore.add_texts(texts, metadatas=metadata)
|
| 404 |
+
|
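When citation or key is omitted, add() asks the cite_chain to guess an MLA-style citation from the first chunk and derives the key as <FirstAuthor><year>, appending a, b, ... on collisions. A hedged sketch, assuming OPENAI_API_KEY is set (paths and strings are illustrative):

    dataset = Dataset()
    dataset.add("smith_2021_report.pdf")  # citation and key inferred via cite_chain
    dataset.add("handbook.docx", citation="HR handbook, 2022.", key="HR2022")
    print(dataset.keys)  # e.g. {"Smith2021", "HR2022"}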
| 405 |
+
def clear(self) -> None:
|
| 406 |
+
"""Clear the collection of documents."""
|
| 407 |
+
self.docs = dict()
|
| 408 |
+
self.keys = set()
|
| 409 |
+
self.index_docstore = None
|
| 410 |
+
# delete index file
|
| 411 |
+
pkl = self.index_path / "index.pkl"
|
| 412 |
+
if pkl.exists():
|
| 413 |
+
pkl.unlink()
|
| 414 |
+
fs = self.index_path / "index.faiss"
|
| 415 |
+
if fs.exists():
|
| 416 |
+
fs.unlink()
|
| 417 |
+
|
| 418 |
+
@property
|
| 419 |
+
def doc_previews(self) -> List[Tuple[int, str, str]]:
|
| 420 |
+
"""Return a list of (number of chunks, key, citation) tuples, one per document."""
|
| 421 |
+
return [
|
| 422 |
+
(
|
| 423 |
+
len(doc["texts"]),
|
| 424 |
+
doc["metadata"][0]["dockey"],
|
| 425 |
+
doc["metadata"][0]["citation"],
|
| 426 |
+
)
|
| 427 |
+
for doc in self.docs.values()
|
| 428 |
+
]
|
| 429 |
+
|
| 430 |
+
# to pickle, we have to save the index as a file
|
| 431 |
+
def __getstate__(self, embedding: Embeddings):
|
| 432 |
+
if embedding is None:
|
| 433 |
+
embedding = OpenAIEmbeddings()
|
| 434 |
+
if self.index_docstore is None and len(self.docs) > 0:
|
| 435 |
+
self._build_faiss_index(embedding)
|
| 436 |
+
state = self.__dict__.copy()
|
| 437 |
+
if self.index_docstore is not None:
|
| 438 |
+
state["index_docstore"].save_local(self.index_path)
|
| 439 |
+
del state["index_docstore"]
|
| 440 |
+
# remove LLMs (they can have callbacks, which can't be pickled)
|
| 441 |
+
del state["summary_chain"]
|
| 442 |
+
del state["qa_chain"]
|
| 443 |
+
del state["cite_chain"]
|
| 444 |
+
del state["search_chain"]
|
| 445 |
+
return state
|
| 446 |
+
|
| 447 |
+
def __setstate__(self, state):
|
| 448 |
+
self.__dict__.update(state)
|
| 449 |
+
try:
|
| 450 |
+
self.index_docstore = FAISS.load_local(self.index_path, OpenAIEmbeddings())
|
| 451 |
+
except Exception:
|
| 452 |
+
# they use some special exception type, but I don't want to import it
|
| 453 |
+
self.index_docstore = None
|
| 454 |
+
self.update_llm(
|
| 455 |
+
ChatOpenAI(temperature=0.1, max_tokens=512)
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
def _build_faiss_index(self, embedding: Embeddings = None):
|
| 459 |
+
if embedding is None:
|
| 460 |
+
embedding = OpenAIEmbeddings()
|
| 461 |
+
if self.index_docstore is None:
|
| 462 |
+
texts = reduce(
|
| 463 |
+
lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
|
| 464 |
+
)
|
| 465 |
+
metadatas = reduce(
|
| 466 |
+
lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
# if the index exists, load it
|
| 470 |
+
if LOAD_INDEX_LOCALLY and (self.index_path / "index.faiss").exists():
|
| 471 |
+
self.index_docstore = FAISS.load_local(self.index_path, embedding)
|
| 472 |
+
|
| 473 |
+
# search if the text and metadata already existed in the index
|
| 474 |
+
for i in reversed(range(len(texts))):
|
| 475 |
+
text = texts[i]
|
| 476 |
+
metadata = metadatas[i]
|
| 477 |
+
for key, value in self.index_docstore.docstore.dict_.items():
|
| 478 |
+
if value.page_content == text:
|
| 479 |
+
if value.metadata.get('citation').split(os.sep)[-1] != metadata.get('citation').split(os.sep)[-1]:
|
| 480 |
+
self.index_docstore.docstore.dict_[key].metadata['citation'] = metadata.get('citation').split(os.sep)[-1]
|
| 481 |
+
self.index_docstore.docstore.dict_[key].metadata['dockey'] = metadata.get('citation').split(os.sep)[-1]
|
| 482 |
+
self.index_docstore.docstore.dict_[key].metadata['key'] = metadata.get('citation').split(os.sep)[-1]
|
| 483 |
+
texts.pop(i)
|
| 484 |
+
metadatas.pop(i)
|
| 485 |
+
|
| 486 |
+
# add remaining texts
|
| 487 |
+
if texts:
|
| 488 |
+
self.index_docstore.add_texts(texts=texts, metadatas=metadatas)
|
| 489 |
+
else:
|
| 490 |
+
# create a new index
|
| 491 |
+
self.index_docstore = FAISS.from_texts(texts, embedding, metadatas=metadatas)
|
| 492 |
+
#
|
| 493 |
+
|
| 494 |
+
if SAVE_INDEX_LOCALLY:
|
| 495 |
+
# save index.
|
| 496 |
+
self.index_docstore.save_local(self.index_path)
|
| 497 |
+
|
| 498 |
+
def _build_pinecone_index(self, embedding: Embeddings = None):
|
| 499 |
+
if embedding is None:
|
| 500 |
+
embedding = OpenAIEmbeddings()
|
| 501 |
+
if self.index_docstore is None:
|
| 502 |
+
pinecone.init(
|
| 503 |
+
api_key=os.environ['PINECONE_API_KEY'], # find at app.pinecone.io
|
| 504 |
+
environment=os.environ['PINECONE_ENVIRONMENT'] # next to api key in console
|
| 505 |
+
)
|
| 506 |
+
texts = reduce(
|
| 507 |
+
lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
|
| 508 |
+
)
|
| 509 |
+
metadatas = reduce(
|
| 510 |
+
lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# TODO: when the index already exists, update it instead of deleting it
|
| 514 |
+
# index_name = "langchain-demo1"
|
| 515 |
+
# if index_name in pinecone.list_indexes():
|
| 516 |
+
# self.index_docstore = pinecone.Index(index_name)
|
| 517 |
+
# vectors = []
|
| 518 |
+
# for text, metadata in zip(texts, metadatas):
|
| 519 |
+
# # embed = <we would need to know which embedding the pre-existing index was built with>
|
| 520 |
+
# self.index_docstore.upsert(vectors=vectors)
|
| 521 |
+
# else:
|
| 522 |
+
# if openai.api_type == 'azure':
|
| 523 |
+
# self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
|
| 524 |
+
# else:
|
| 525 |
+
# self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
|
| 526 |
+
|
| 527 |
+
index_name = "langchain-demo1"
|
| 528 |
+
|
| 529 |
+
# if the index exists, delete it
|
| 530 |
+
if index_name in pinecone.list_indexes():
|
| 531 |
+
pinecone.delete_index(index_name)
|
| 532 |
+
|
| 533 |
+
# create new index
|
| 534 |
+
if openai.api_type == 'azure':
|
| 535 |
+
self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
|
| 536 |
+
else:
|
| 537 |
+
self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
|
| 538 |
+
|
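Both index builders default to OpenAIEmbeddings when no Embeddings instance is supplied. A hedged sketch of the FAISS path (the Pinecone path additionally requires PINECONE_API_KEY and PINECONE_ENVIRONMENT in the environment):

    from langchain.embeddings.openai import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings()
    dataset._build_faiss_index(embeddings)       # builds, or reloads and extends, index_path/index.faiss
    # dataset._build_pinecone_index(embeddings)  # alternative backend; recreates the "langchain-demo1" index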
| 539 |
+
def get_evidence(
|
| 540 |
+
self,
|
| 541 |
+
answer: Answer,
|
| 542 |
+
embedding: Embeddings,
|
| 543 |
+
k: int = 3,
|
| 544 |
+
max_sources: int = 5,
|
| 545 |
+
marginal_relevance: bool = True,
|
| 546 |
+
) -> str:
|
| 547 |
+
if self.index_docstore is None:
|
| 548 |
+
self._build_faiss_index(embedding)
|
| 549 |
+
|
| 550 |
+
init_search_time = time.time()
|
| 551 |
+
|
| 552 |
+
# want to work through indices but less k
|
| 553 |
+
if marginal_relevance:
|
| 554 |
+
docs = self.index_docstore.max_marginal_relevance_search(
|
| 555 |
+
answer.question, k=k, fetch_k=5 * k
|
| 556 |
+
)
|
| 557 |
+
else:
|
| 558 |
+
docs = self.index_docstore.similarity_search(
|
| 559 |
+
answer.question, k=k, fetch_k=5 * k
|
| 560 |
+
)
|
| 561 |
+
if OPERATING_MODE == "debug":
|
| 562 |
+
print(f"time to search docs to build context: {time.time() - init_search_time:.2f} [s]")
|
| 563 |
+
init_summary_time = time.time()
|
| 564 |
+
partial_summary_time = ""
|
| 565 |
+
for i, doc in enumerate(docs):
|
| 566 |
+
with get_openai_callback() as cb:
|
| 567 |
+
init__partial_summary_time = time.time()
|
| 568 |
+
summary_of_chunked_text = self.summary_chain.run(
|
| 569 |
+
question=answer.question, context_str=doc.page_content
|
| 570 |
+
)
|
| 571 |
+
if OPERATING_MODE == "debug":
|
| 572 |
+
partial_summary_time += f"- time to make relevant summary of doc '{i}': {time.time() - init__partial_summary_time:.2f} [s]\n"
|
| 573 |
+
engine = self.summary_chain.llm.model_kwargs.get('deployment_id') or self.summary_chain.llm.model_name
|
| 574 |
+
if not answer.tokens:
|
| 575 |
+
answer.tokens = [{
|
| 576 |
+
'engine': engine,
|
| 577 |
+
'total_tokens': cb.total_tokens}]
|
| 578 |
+
else:
|
| 579 |
+
answer.tokens.append({
|
| 580 |
+
'engine': engine,
|
| 581 |
+
'total_tokens': cb.total_tokens
|
| 582 |
+
})
|
| 583 |
+
summarized_package = (
|
| 584 |
+
doc.metadata["key"],
|
| 585 |
+
doc.metadata["citation"],
|
| 586 |
+
summary_of_chunked_text,
|
| 587 |
+
doc.page_content,
|
| 588 |
+
)
|
| 589 |
+
if "Not applicable" not in summary_of_chunked_text and summarized_package not in answer.packages:
|
| 590 |
+
answer.packages.append(summarized_package)
|
| 591 |
+
yield answer
|
| 592 |
+
if len(answer.packages) == max_sources:
|
| 593 |
+
break
|
| 594 |
+
if OPERATING_MODE == "debug":
|
| 595 |
+
print(f"time to make all relevant summaries: {time.time() - init_summary_time:.2f} [s]")
|
| 596 |
+
# the last character is not printed because it is a \n
|
| 597 |
+
print(partial_summary_time[:-1])
|
| 598 |
+
context_str = "\n\n".join(
|
| 599 |
+
[f"{citation}: {summary_of_chunked_text}"
|
| 600 |
+
for key, citation, summary_of_chunked_text, chunked_text in answer.packages
|
| 601 |
+
if "Not applicable" not in summary_of_chunked_text]
|
| 602 |
+
)
|
| 603 |
+
chunks_str = "\n\n".join(
|
| 604 |
+
[f"{citation}: {chunked_text}"
|
| 605 |
+
for key, citation, summary_of_chunked_text, chunked_text in answer.packages
|
| 606 |
+
if "Not applicable" not in summary_of_chunked_text]
|
| 607 |
+
)
|
| 608 |
+
valid_keys = [key
|
| 609 |
+
for key, citation, summary_of_chunked_text, chunked_text in answer.packages
|
| 610 |
+
if "Not applicable" not in summary_of_chunked_text]
|
| 611 |
+
if len(valid_keys) > 0:
|
| 612 |
+
context_str += "\n\nValid keys: " + ", ".join(valid_keys)
|
| 613 |
+
chunks_str += "\n\nValid keys: " + ", ".join(valid_keys)
|
| 614 |
+
answer.context = context_str
|
| 615 |
+
answer.chunks = chunks_str
|
| 616 |
+
yield answer
|
| 617 |
+
|
| 618 |
+
def query(
|
| 619 |
+
self,
|
| 620 |
+
query: str,
|
| 621 |
+
embedding: Embeddings,
|
| 622 |
+
chat_history: list[tuple[str, str]],
|
| 623 |
+
k: int = 10,
|
| 624 |
+
max_sources: int = 5,
|
| 625 |
+
length_prompt: str = "about 100 words",
|
| 626 |
+
marginal_relevance: bool = True,
|
| 627 |
+
):
|
| 628 |
+
for answer in self._query(
|
| 629 |
+
query,
|
| 630 |
+
embedding,
|
| 631 |
+
chat_history,
|
| 632 |
+
k=k,
|
| 633 |
+
max_sources=max_sources,
|
| 634 |
+
length_prompt=length_prompt,
|
| 635 |
+
marginal_relevance=marginal_relevance,
|
| 636 |
+
):
|
| 637 |
+
pass
|
| 638 |
+
return answer
|
| 639 |
+
|
| 640 |
+
def _query(
|
| 641 |
+
self,
|
| 642 |
+
query: str,
|
| 643 |
+
embedding: Embeddings,
|
| 644 |
+
chat_history: list[tuple[str, str]],
|
| 645 |
+
k: int,
|
| 646 |
+
max_sources: int,
|
| 647 |
+
length_prompt: str,
|
| 648 |
+
marginal_relevance: bool,
|
| 649 |
+
):
|
| 650 |
+
if k < max_sources:
|
| 651 |
+
k = max_sources + 1
|
| 652 |
+
|
| 653 |
+
answer = Answer(question=query)
|
| 654 |
+
|
| 655 |
+
messages_qa = [system_message_prompt]
|
| 656 |
+
if len(chat_history) != 0:
|
| 657 |
+
for conversation in chat_history:
|
| 658 |
+
messages_qa.append(HumanMessagePromptTemplate.from_template(conversation[0]))
|
| 659 |
+
messages_qa.append(AIMessagePromptTemplate.from_template(conversation[1]))
|
| 660 |
+
messages_qa.append(human_qa_message_prompt)
|
| 661 |
+
chat_qa_prompt = ChatPromptTemplate.from_messages(messages_qa)
|
| 662 |
+
self.qa_chain = LLMChain(prompt=chat_qa_prompt, llm=self.llm)
|
| 663 |
+
|
| 664 |
+
for answer in self.get_evidence(
|
| 665 |
+
answer,
|
| 666 |
+
embedding,
|
| 667 |
+
k=k,
|
| 668 |
+
max_sources=max_sources,
|
| 669 |
+
marginal_relevance=marginal_relevance,
|
| 670 |
+
):
|
| 671 |
+
yield answer
|
| 672 |
+
|
| 673 |
+
references_dict = dict()
|
| 674 |
+
passages = dict()
|
| 675 |
+
if len(answer.context) < 10:
|
| 676 |
+
answer_text = "I cannot answer this question due to insufficient information."
|
| 677 |
+
else:
|
| 678 |
+
with get_openai_callback() as cb:
|
| 679 |
+
init_qa_time = time.time()
|
| 680 |
+
answer_text = self.qa_chain.run(
|
| 681 |
+
question=answer.question, context_str=answer.context, length=length_prompt
|
| 682 |
+
)
|
| 683 |
+
if OPERATING_MODE == "debug":
|
| 684 |
+
print(f"time to make the Q&A answer: {time.time() - init_qa_time:.2f} [s]")
|
| 685 |
+
engine = self.qa_chain.llm.model_kwargs.get('deployment_id') or self.qa_chain.llm.model_name
|
| 686 |
+
if not answer.tokens:
|
| 687 |
+
answer.tokens = [{
|
| 688 |
+
'engine': engine,
|
| 689 |
+
'total_tokens': cb.total_tokens}]
|
| 690 |
+
else:
|
| 691 |
+
answer.tokens.append({
|
| 692 |
+
'engine': engine,
|
| 693 |
+
'total_tokens': cb.total_tokens
|
| 694 |
+
})
|
| 695 |
+
|
| 696 |
+
# it still happens lol
|
| 697 |
+
if "(Foo2012)" in answer_text:
|
| 698 |
+
answer_text = answer_text.replace("(Foo2012)", "")
|
| 699 |
+
for key, citation, summary, text in answer.packages:
|
| 700 |
+
# do check for whole key (so we don't catch Callahan2019a with Callahan2019)
|
| 701 |
+
skey = key.split(" ")[0]
|
| 702 |
+
if skey + " " in answer_text or skey + ")" in answer_text:
|
| 703 |
+
references_dict[skey] = citation
|
| 704 |
+
passages[key] = text
|
| 705 |
+
references_str = "\n\n".join(
|
| 706 |
+
[f"{i+1}. ({k}): {c}" for i, (k, c) in enumerate(references_dict.items())]
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
# cost_str = f"{answer_text}\n\n"
|
| 710 |
+
cost_str = ""
|
| 711 |
+
itemized_cost = ""
|
| 712 |
+
total_amount = 0
|
| 713 |
+
for d in answer.tokens:
|
| 714 |
+
total_tokens = d.get('total_tokens')
|
| 715 |
+
if total_tokens:
|
| 716 |
+
engine = d.get('engine')
|
| 717 |
+
key_price = None
|
| 718 |
+
for key in PRICES.keys():
|
| 719 |
+
if re.match(f"{key}", engine):
|
| 720 |
+
key_price = key
|
| 721 |
+
break
|
| 722 |
+
if PRICES.get(key_price):
|
| 723 |
+
partial_amount = total_tokens / 1000 * PRICES.get(key_price)
|
| 724 |
+
total_amount += partial_amount
|
| 725 |
+
itemized_cost += f"- {engine}: {total_tokens} tokens\t ---> ${partial_amount:.4f},\n"
|
| 726 |
+
else:
|
| 727 |
+
itemized_cost += f"- {engine}: {total_tokens} tokens,\n"
|
| 728 |
+
# delete ,\n
|
| 729 |
+
itemized_cost = itemized_cost[:-2]
|
| 730 |
+
|
| 731 |
+
# add tokens to formatted answer
|
| 732 |
+
cost_str += f"Total cost: ${total_amount:.4f}\nItemized cost:\n{itemized_cost}"
|
| 733 |
+
|
| 734 |
+
answer.answer = answer_text
|
| 735 |
+
answer.cost_str = cost_str
|
| 736 |
+
answer.references = references_str
|
| 737 |
+
answer.passages = passages
|
| 738 |
+
yield answer
|
| 739 |
+
|
| 740 |
+
|
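Taken together, a minimal end-to-end sketch of how the Streamlit app drives this module (file name, citation, key and the question are placeholders; OPENAI_API_KEY must be set):

    from langchain.embeddings.openai import OpenAIEmbeddings
    from streamlit_langchain_chat.dataset import Dataset

    dataset = Dataset()
    dataset.add("policy.pdf", citation="University policy handbook, 2023.", key="Policy2023")
    embeddings = OpenAIEmbeddings()
    dataset._build_faiss_index(embeddings)

    answer = dataset.query(
        "How many days of annual leave do staff receive?",
        embeddings,
        chat_history=[],  # list of (user, assistant) tuples from earlier turns
    )
    print(answer.answer)      # answer text with (Key2023)-style citation markers
    print(answer.references)  # numbered list of the citations actually used
    print(answer.cost_str)    # token usage and estimated cost per engine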
streamlit_langchain_chat/inputs/__init__.py
ADDED
|
File without changes
|
streamlit_langchain_chat/prompts.py
ADDED
|
@@ -0,0 +1,91 @@
|
| 1 |
+
import langchain.prompts as prompts
|
| 2 |
+
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
summary_template = """Summarize and provide direct quotes from the text below to help answer a question.
|
| 6 |
+
Do not directly answer the question, instead provide a summary and quotes with the context of the user's question.
|
| 7 |
+
Do not use outside sources.
|
| 8 |
+
Reply with "Not applicable" if the text is unrelated to the question.
|
| 9 |
+
Use 75 words or fewer.
|
| 10 |
+
Remember, if the user does not specify a language, reply in the language of the user's question.
|
| 11 |
+
|
| 12 |
+
{context_str}
|
| 13 |
+
|
| 14 |
+
User's question: {question}
|
| 15 |
+
Relevant Information Summary:"""
|
| 16 |
+
summary_prompt = prompts.PromptTemplate(
|
| 17 |
+
input_variables=["question", "context_str"],
|
| 18 |
+
template=summary_template,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
qa_template = """Write an answer for the user's question below solely based on the provided context.
|
| 22 |
+
If the user does not specify how many words the answer should be, the length of the answer should be {length}.
|
| 23 |
+
If the context is irrelevant, reply "Your question falls outside the scope of University of Sydney policy, so I cannot answer".
|
| 24 |
+
For each sentence in your answer, indicate which sources most support it via valid citation markers at the end of sentences, like (Example2012).
|
| 25 |
+
Answer in an unbiased and professional tone.
|
| 26 |
+
Make clear what is your opinion.
|
| 27 |
+
Use Markdown for formatting code or text, and try to use direct quotes to support arguments.
|
| 28 |
+
Remember, if the user does not specify a language, answer in the language of the user's question.
|
| 29 |
+
|
| 30 |
+
Context:
|
| 31 |
+
{context_str}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
User's question: {question}
|
| 35 |
+
Answer:
|
| 36 |
+
"""
|
| 37 |
+
qa_prompt = prompts.PromptTemplate(
|
| 38 |
+
input_variables=["question", "context_str", "length"],
|
| 39 |
+
template=qa_template,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# used by GPCL
|
| 43 |
+
qa_prompt_GPCL = prompts.PromptTemplate(
|
| 44 |
+
input_variables=["question", "context_str"],
|
| 45 |
+
template="You are an AI assistant providing helpful advice about University of Sydney policy. You are given the following extracted parts of a long document and a question. Provide a conversational answer based on the context provided."
|
| 46 |
+
"You should only provide hyperlinks that reference the context below. Do NOT make up hyperlinks."
|
| 47 |
+
'If you can not find the answer in the context below, just say "Hmm, I am not sure. Could you please rephrase your question?" Do not try to make up an answer.'
|
| 48 |
+
"If the question is not related to the context, politely respond that you are tuned to only answer questions that are related to the context.\n\n"
|
| 49 |
+
"Question: {question}\n"
|
| 50 |
+
"=========\n"
|
| 51 |
+
"{context_str}\n"
|
| 52 |
+
"=========\n"
|
| 53 |
+
"Answer in Markdown:",
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
search_prompt = prompts.PromptTemplate(
|
| 57 |
+
input_variables=["question"],
|
| 58 |
+
template="We want to answer the following question: {question} \n"
|
| 59 |
+
"Provide three different targeted keyword searches (one search per line) "
|
| 60 |
+
"that will find papers that help answer the question. Do not use boolean operators. "
|
| 61 |
+
"Recent years are 2021, 2022, 2023.\n\n"
|
| 62 |
+
"1.",
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _get_datetime():
|
| 67 |
+
now = datetime.now()
|
| 68 |
+
return now.strftime("%m/%d/%Y")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
citation_prompt = prompts.PromptTemplate(
|
| 72 |
+
input_variables=["text"],
|
| 73 |
+
template="Provide a possible citation for the following text in MLA Format. Today's date is {date}\n"
|
| 74 |
+
"{text}\n\n"
|
| 75 |
+
"Citation:",
|
| 76 |
+
partial_variables={"date": _get_datetime},
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
system_template = """You are an AI chatbot with knowledge of the University of Sydney's legal policies that answers in an unbiased, professional tone.
|
| 80 |
+
You sometimes refuse to answer if there is insufficient information.
|
| 81 |
+
If the user does not specify a language, answer in the language of the user's question. """
|
| 82 |
+
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
|
| 83 |
+
|
| 84 |
+
human_summary_message_prompt = HumanMessagePromptTemplate.from_template(summary_template)
|
| 85 |
+
chat_summary_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_summary_message_prompt])
|
| 86 |
+
|
| 87 |
+
human_qa_message_prompt = HumanMessagePromptTemplate.from_template(qa_template)
|
| 88 |
+
# chat_qa_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_qa_message_prompt]) # TODO: borrar
|
| 89 |
+
|
| 90 |
+
# human_condense_message_prompt = HumanMessagePromptTemplate.from_template(condense_template)
|
| 91 |
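A small sketch of how these templates compose (values are placeholders): chat_summary_prompt pairs the system message with the summary request, while the QA prompt is rebuilt per query in Dataset._query so earlier conversation turns can be spliced in between the two.

    from streamlit_langchain_chat.prompts import chat_summary_prompt

    messages = chat_summary_prompt.format_prompt(
        question="How many days of annual leave do staff receive?",
        context_str="Policy2023: Staff receive 20 days of annual leave ...",
    ).to_messages()
    # messages[0] is the system message, messages[1] the filled-in summary request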
+
# chat_condense_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_condense_message_prompt])
|
streamlit_langchain_chat/streamlit_app.py
ADDED
|
@@ -0,0 +1,561 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
To run:
|
| 4 |
+
- activate the virtual environment
|
| 5 |
+
- streamlit run path\to\streamlit_app.py
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import re
|
| 10 |
+
import sys
|
| 11 |
+
import time
|
| 12 |
+
import warnings
|
| 13 |
+
import shutil
|
| 14 |
+
|
| 15 |
+
from langchain.chat_models import ChatOpenAI
|
| 16 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
| 17 |
+
import openai
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import streamlit as st
|
| 20 |
+
from st_aggrid import GridOptionsBuilder, AgGrid, GridUpdateMode, ColumnsAutoSizeMode
|
| 21 |
+
from streamlit_chat import message
|
| 22 |
+
|
| 23 |
+
from streamlit_langchain_chat.constants import *
|
| 24 |
+
from streamlit_langchain_chat.customized_langchain.llms import OpenAI, AzureOpenAI, AzureOpenAIChat
|
| 25 |
+
from streamlit_langchain_chat.dataset import Dataset
|
| 26 |
+
|
| 27 |
+
# Configure logger
|
| 28 |
+
logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)
|
| 29 |
+
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
| 30 |
+
|
| 31 |
+
warnings.filterwarnings('ignore')
|
| 32 |
+
|
| 33 |
+
if 'generated' not in st.session_state:
|
| 34 |
+
st.session_state['generated'] = []
|
| 35 |
+
if 'past' not in st.session_state:
|
| 36 |
+
st.session_state['past'] = []
|
| 37 |
+
if 'costs' not in st.session_state:
|
| 38 |
+
st.session_state['costs'] = []
|
| 39 |
+
if 'contexts' not in st.session_state:
|
| 40 |
+
st.session_state['contexts'] = []
|
| 41 |
+
if 'chunks' not in st.session_state:
|
| 42 |
+
st.session_state['chunks'] = []
|
| 43 |
+
if 'user_input' not in st.session_state:
|
| 44 |
+
st.session_state['user_input'] = ""
|
| 45 |
+
if 'dataset' not in st.session_state:
|
| 46 |
+
st.session_state['dataset'] = None
|
| 47 |
+
|
| 48 |
+
|
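Streamlit reruns this whole script on every interaction, so the session-state lists initialized above are the app's only memory; get_answer() later rebuilds the conversation by zipping them back together. A one-line sketch of that pattern:

    chat_history = list(zip(st.session_state['past'], st.session_state['generated']))  # [(user, assistant), ...]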
| 49 |
+
def check_api_keys() -> bool:
|
| 50 |
+
source_id = app.params['source_id']
|
| 51 |
+
index_id = app.params['index_id']
|
| 52 |
+
|
| 53 |
+
open_api_key = os.getenv('OPENAI_API_KEY', '')
|
| 54 |
+
openapi_api_key_ready = type(open_api_key) is str and len(open_api_key) > 0
|
| 55 |
+
|
| 56 |
+
pinecone_api_key = os.getenv('PINECONE_API_KEY', '')
|
| 57 |
+
pinecone_api_key_ready = type(pinecone_api_key) is str and len(pinecone_api_key) > 0 if index_id == 2 else True
|
| 58 |
+
|
| 59 |
+
is_ready = True if openapi_api_key_ready and pinecone_api_key_ready else False
|
| 60 |
+
return is_ready
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def check_combination_point() -> bool:
|
| 64 |
+
type_id = app.params['type_id']
|
| 65 |
+
open_api_key = os.getenv('OPENAI_API_KEY', '')
|
| 66 |
+
openapi_api_key_ready = type(open_api_key) is str and len(open_api_key) > 0
|
| 67 |
+
api_base = app.params['api_base']
|
| 68 |
+
|
| 69 |
+
if type_id == 1:
|
| 70 |
+
deployment_id = app.params['deployment_id']
|
| 71 |
+
return True if openapi_api_key_ready and api_base and deployment_id else False
|
| 72 |
+
elif type_id == 2:
|
| 73 |
+
return True if openapi_api_key_ready and api_base else False
|
| 74 |
+
else:
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def check_index() -> bool:
|
| 79 |
+
dataset = st.session_state['dataset']
|
| 80 |
+
|
| 81 |
+
index_built = dataset.index_docstore if hasattr(dataset, "index_docstore") else False
|
| 82 |
+
without_source = app.params['source_id'] == 4
|
| 83 |
+
is_ready = True if index_built or without_source else False
|
| 84 |
+
return is_ready
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def check_index_point() -> bool:
|
| 88 |
+
index_id = app.params['index_id']
|
| 89 |
+
|
| 90 |
+
pinecone_api_key = os.getenv('PINECONE_API_KEY', '')
|
| 91 |
+
pinecone_api_key_ready = type(pinecone_api_key) is str and len(pinecone_api_key) > 0 if index_id == 2 else True
|
| 92 |
+
pinecone_environment = os.getenv('PINECONE_ENVIRONMENT', False) if index_id == 2 else True
|
| 93 |
+
|
| 94 |
+
is_ready = True if index_id and pinecone_api_key_ready and pinecone_environment else False
|
| 95 |
+
return is_ready
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def check_params_point() -> bool:
|
| 99 |
+
max_sources = app.params['max_sources']
|
| 100 |
+
temperature = app.params['temperature']
|
| 101 |
+
|
| 102 |
+
is_ready = True if max_sources and isinstance(temperature, float) else False
|
| 103 |
+
return is_ready
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def check_source_point() -> bool:
|
| 107 |
+
return True
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def clear_chat_history():
|
| 111 |
+
if st.session_state['past'] or st.session_state['generated'] or st.session_state['contexts'] or st.session_state['chunks'] or st.session_state['costs']:
|
| 112 |
+
st.session_state['past'] = []
|
| 113 |
+
st.session_state['generated'] = []
|
| 114 |
+
st.session_state['contexts'] = []
|
| 115 |
+
st.session_state['chunks'] = []
|
| 116 |
+
st.session_state['costs'] = []
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def clear_index():
|
| 120 |
+
if dataset := st.session_state['dataset']:
|
| 121 |
+
# delete directory (with files)
|
| 122 |
+
index_path = dataset.index_path
|
| 123 |
+
if index_path.exists():
|
| 124 |
+
shutil.rmtree(str(index_path))
|
| 125 |
+
|
| 126 |
+
# update variable
|
| 127 |
+
st.session_state['dataset'] = None
|
| 128 |
+
|
| 129 |
+
elif (TEMP_DIR / "default").exists():
|
| 130 |
+
shutil.rmtree(str(TEMP_DIR / "default"))
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def check_sources() -> bool:
|
| 134 |
+
uploaded_files_rows = app.params['uploaded_files_rows']
|
| 135 |
+
urls_df = app.params['urls_df']
|
| 136 |
+
source_id = app.params['source_id']
|
| 137 |
+
|
| 138 |
+
some_files = True if uploaded_files_rows and uploaded_files_rows[-1].get('filepath') != "" else False
|
| 139 |
+
some_urls = bool([True for url, citation in urls_df.to_numpy() if url])
|
| 140 |
+
|
| 141 |
+
only_local_files = some_files and not some_urls
|
| 142 |
+
only_urls = not some_files and some_urls
|
| 143 |
+
is_ready = only_local_files or only_urls or (source_id == 4)
|
| 144 |
+
return is_ready
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def collect_dataset_and_built_index():
|
| 148 |
+
start = time.time()
|
| 149 |
+
uploaded_files_rows = app.params['uploaded_files_rows']
|
| 150 |
+
urls_df = app.params['urls_df']
|
| 151 |
+
type_id = app.params['type_id']
|
| 152 |
+
temperature = app.params['temperature']
|
| 153 |
+
index_id = app.params['index_id']
|
| 154 |
+
api_base = app.params['api_base']
|
| 155 |
+
deployment_id = app.params['deployment_id']
|
| 156 |
+
|
| 157 |
+
some_files = True if uploaded_files_rows and uploaded_files_rows[-1].get('filepath') != "" else False
|
| 158 |
+
some_urls = bool([True for url, citation in urls_df.to_numpy() if url])
|
| 159 |
+
|
| 160 |
+
openai.api_type = "azure" if type_id == 1 else "open_ai"
|
| 161 |
+
openai.api_base = api_base
|
| 162 |
+
openai.api_version = "2023-03-15-preview" if type_id == 1 else None
|
| 163 |
+
|
| 164 |
+
if deployment_id != "text-davinci-003":
|
| 165 |
+
dataset = Dataset(
|
| 166 |
+
llm=ChatOpenAI(
|
| 167 |
+
temperature=temperature,
|
| 168 |
+
max_tokens=512,
|
| 169 |
+
deployment_id=deployment_id,
|
| 170 |
+
)
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
dataset = Dataset(
|
| 174 |
+
llm=OpenAI(
|
| 175 |
+
temperature=temperature,
|
| 176 |
+
max_tokens=512,
|
| 177 |
+
deployment_id=COMBINATIONS_OPTIONS.get(combination_id).get('deployment_name'),
|
| 178 |
+
)
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# get url documents
|
| 182 |
+
if some_urls:
|
| 183 |
+
urls_df = urls_df.reset_index()
|
| 184 |
+
for url_index, url_row in urls_df.iterrows():
|
| 185 |
+
url = url_row.get('urls', '')
|
| 186 |
+
citation = url_row.get('citation string', '')
|
| 187 |
+
if url:
|
| 188 |
+
try:
|
| 189 |
+
dataset.add(
|
| 190 |
+
url,
|
| 191 |
+
citation,
|
| 192 |
+
citation,
|
| 193 |
+
disable_check=True # True to accept Japanese letters
|
| 194 |
+
)
|
| 195 |
+
except Exception as e:
|
| 196 |
+
print(e)
|
| 197 |
+
pass
|
| 198 |
+
|
| 199 |
+
# dataset is pandas dataframe
|
| 200 |
+
if some_files:
|
| 201 |
+
for uploaded_files_row in uploaded_files_rows:
|
| 202 |
+
key = uploaded_files_row.get('citation string') if ',' not in uploaded_files_row.get('citation string') else None
|
| 203 |
+
dataset.add(
|
| 204 |
+
uploaded_files_row.get('filepath'),
|
| 205 |
+
uploaded_files_row.get('citation string'),
|
| 206 |
+
key=key,
|
| 207 |
+
disable_check=True # True to accept Japanese letters
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
openai_embeddings = OpenAIEmbeddings(
|
| 211 |
+
document_model_name="text-embedding-ada-002",
|
| 212 |
+
query_model_name="text-embedding-ada-002",
|
| 213 |
+
)
|
| 214 |
+
if index_id == 1:
|
| 215 |
+
dataset._build_faiss_index(openai_embeddings)
|
| 216 |
+
else:
|
| 217 |
+
dataset._build_pinecone_index(openai_embeddings)
|
| 218 |
+
st.session_state['dataset'] = dataset
|
| 219 |
+
|
| 220 |
+
if OPERATING_MODE == "debug":
|
| 221 |
+
print(f"time to collect dataset: {time.time() - start:.2f} [s]")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def configure_streamlit_and_page():
|
| 225 |
+
# Configure Streamlit page and state
|
| 226 |
+
st.set_page_config(**ST_CONFIG)
|
| 227 |
+
|
| 228 |
+
# Force responsive layout for columns also on mobile
|
| 229 |
+
st.write(
|
| 230 |
+
"""<style>
|
| 231 |
+
[data-testid="column"] {
|
| 232 |
+
width: calc(50% - 1rem);
|
| 233 |
+
flex: 1 1 calc(50% - 1rem);
|
| 234 |
+
min-width: calc(50% - 1rem);
|
| 235 |
+
}
|
| 236 |
+
</style>""",
|
| 237 |
+
unsafe_allow_html=True,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def get_answer():
|
| 242 |
+
query = st.session_state['user_input']
|
| 243 |
+
dataset = st.session_state['dataset']
|
| 244 |
+
type_id = app.params['type_id']
|
| 245 |
+
index_id = app.params['index_id']
|
| 246 |
+
max_sources = app.params['max_sources']
|
| 247 |
+
|
| 248 |
+
if query and dataset and type_id and index_id:
|
| 249 |
+
chat_history = [(past, generated)
|
| 250 |
+
for (past, generated) in zip(st.session_state['past'], st.session_state['generated'])]
|
| 251 |
+
marginal_relevance = False if not index_id == 1 else True
|
| 252 |
+
start = time.time()
|
| 253 |
+
openai_embeddings = OpenAIEmbeddings(
|
| 254 |
+
document_model_name="text-embedding-ada-002",
|
| 255 |
+
query_model_name="text-embedding-ada-002",
|
| 256 |
+
)
|
| 257 |
+
result = dataset.query(
|
| 258 |
+
query,
|
| 259 |
+
openai_embeddings,
|
| 260 |
+
chat_history,
|
| 261 |
+
marginal_relevance=marginal_relevance, # if pinecone is used it must be False
|
| 262 |
+
)
|
| 263 |
+
if OPERATING_MODE == "debug":
|
| 264 |
+
print(f"time to get answer: {time.time() - start:.2f} [s]")
|
| 265 |
+
print("-" * 10)
|
| 266 |
+
# response = {'generated_text': result.formatted_answer}
|
| 267 |
+
# response = {'generated_text': f"test_{len(st.session_state['generated'])} by {query}"} # @debug
|
| 268 |
+
return result
|
| 269 |
+
else:
|
| 270 |
+
return None
|
| 271 |
+
|
| 272 |
+
|
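get_answer() returns the final Answer object, or None when a prerequisite (query, dataset, model or index) is missing. A hedged sketch of how its fields map onto the session-state lists that load_main_page renders below; the on_enter callback that actually does this is defined later in the file and may differ in detail:

    result = get_answer()
    if result is not None:
        st.session_state['past'].append(st.session_state['user_input'])
        st.session_state['generated'].append(result.answer)
        st.session_state['contexts'].append(result.context)
        st.session_state['chunks'].append(result.chunks)
        st.session_state['costs'].append(result.cost_str)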
| 273 |
+
def load_main_page():
|
| 274 |
+
"""
|
| 275 |
+
Load the body of web.
|
| 276 |
+
"""
|
| 277 |
+
# Streamlit HTML Markdown
|
| 278 |
+
# st.title <h1> #
|
| 279 |
+
# st.header <h2> ##
|
| 280 |
+
# st.subheader <h3> ###
|
| 281 |
+
st.markdown(f"## Augmented-Retrieval Q&A ChatGPT ({APP_VERSION})")
|
| 282 |
+
validate_status()
|
| 283 |
+
st.markdown(f"#### **Status**: {app.params['status']}")
|
| 284 |
+
|
| 285 |
+
# hidden div with anchor
|
| 286 |
+
st.markdown("<div id='linkto_top'></div>", unsafe_allow_html=True)
|
| 287 |
+
col1, col2, col3 = st.columns(3)
|
| 288 |
+
col1.button(label="clear index", type="primary", on_click=clear_index)
|
| 289 |
+
col2.button(label="clear conversation", type="primary", on_click=clear_chat_history)
|
| 290 |
+
col3.markdown("<a href='#linkto_bottom'>Link to bottom</a>", unsafe_allow_html=True)
|
| 291 |
+
|
| 292 |
+
if st.session_state["generated"]:
|
| 293 |
+
for i in range(len(st.session_state["generated"])):
|
| 294 |
+
message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
|
| 295 |
+
message(st.session_state['generated'][i], key=str(i))
|
| 296 |
+
with st.expander("See context"):
|
| 297 |
+
st.write(st.session_state['contexts'][i])
|
| 298 |
+
with st.expander("See chunks"):
|
| 299 |
+
st.write(st.session_state['chunks'][i])
|
| 300 |
+
with st.expander("See costs"):
|
| 301 |
+
st.write(st.session_state['costs'][i])
|
| 302 |
+
dataset = st.session_state['dataset']
|
| 303 |
+
index_built = dataset.index_docstore if hasattr(dataset, "index_docstore") else False
|
| 304 |
+
without_source = app.params['source_id'] == 4
|
| 305 |
+
enable_chat_button = index_built or without_source
|
| 306 |
+
st.text_input("You:",
|
| 307 |
+
key='user_input',
|
| 308 |
+
on_change=on_enter,
|
| 309 |
+
disabled=not enable_chat_button
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
st.markdown("<a href='#linkto_top'>Link to top</a>", unsafe_allow_html=True)
|
| 313 |
+
# hidden div with anchor
|
| 314 |
+
st.markdown("<div id='linkto_bottom'></div>", unsafe_allow_html=True)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def load_sidebar_page():
|
| 318 |
+
st.sidebar.markdown("## Instructions")
|
| 319 |
+
|
| 320 |
+
# ############ #
|
| 321 |
+
# SOURCES TYPE #
|
| 322 |
+
# ############ #
|
| 323 |
+
st.sidebar.markdown("1. Select a source:")
|
| 324 |
+
source_selected = st.sidebar.selectbox(
|
| 325 |
+
"Choose the location of your info to give context to chatgpt",
|
| 326 |
+
[key for key, value in SOURCES_IDS.items()])
|
| 327 |
+
app.params['source_id'] = SOURCES_IDS.get(source_selected, None)
|
| 328 |
+
|
| 329 |
+
# ##### #
|
| 330 |
+
# MODEL #
|
| 331 |
+
# ##### #
|
| 332 |
+
st.sidebar.markdown("2. Select a model (LLM):")
|
| 333 |
+
combination_selected = st.sidebar.selectbox(
|
| 334 |
+
"Choose type: MSF Azure OpenAI and model / OpenAI",
|
| 335 |
+
[key for key, value in TYPE_IDS.items()])
|
| 336 |
+
app.params['type_id'] = TYPE_IDS.get(combination_selected, None)
|
| 337 |
+
|
| 338 |
+
if app.params['type_id'] == 1: # with AzureOpenAI endpoint
|
| 339 |
+
# https://docs.streamlit.io/library/api-reference/widgets/st.text_input
|
| 340 |
+
os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
|
| 341 |
+
label="Enter Azure OpenAI API Key",
|
| 342 |
+
type="password"
|
| 343 |
+
).strip()
|
| 344 |
+
app.params['api_base'] = st.sidebar.text_input(
|
| 345 |
+
label="Enter Azure API base",
|
| 346 |
+
placeholder="https://<api_base_endpoint>.openai.azure.com/",
|
| 347 |
+
).strip()
|
| 348 |
+
app.params['deployment_id'] = st.sidebar.text_input(
|
| 349 |
+
label="Enter Azure deployment_id",
|
| 350 |
+
).strip()
|
| 351 |
+
elif app.params['type_id'] == 2: # with OpenAI endpoint
|
| 352 |
+
os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
|
| 353 |
+
label="Enter OpenAI API Key",
|
| 354 |
+
placeholder="sk-...",
|
| 355 |
+
type="password"
|
| 356 |
+
).strip()
|
| 357 |
+
app.params['api_base'] = "https://api.openai.com/v1"
|
| 358 |
+
app.params['deployment_id'] = None
|
| 359 |
+
|
| 360 |
+
# ####### #
|
| 361 |
+
# INDEXES #
|
| 362 |
+
# ####### #
|
| 363 |
+
st.sidebar.markdown("3. Select an index store:")
|
| 364 |
+
index_selected = st.sidebar.selectbox(
|
| 365 |
+
"Type of Index",
|
| 366 |
+
[key for key, value in INDEX_IDS.items()])
|
| 367 |
+
app.params['index_id'] = INDEX_IDS.get(index_selected, None)
|
| 368 |
+
if app.params['index_id'] == 2: # with pinecone
|
| 369 |
+
os.environ['PINECONE_API_KEY'] = st.sidebar.text_input(
|
| 370 |
+
label="Enter pinecone API Key",
|
| 371 |
+
type="password"
|
| 372 |
+
).strip()
|
| 373 |
+
|
| 374 |
+
os.environ['PINECONE_ENVIRONMENT'] = st.sidebar.text_input(
|
| 375 |
+
label="Enter pinecone environment",
|
| 376 |
+
placeholder="eu-west1-gcp",
|
| 377 |
+
).strip()
|
| 378 |
+
|
| 379 |
+
# ############## #
|
| 380 |
+
# CONFIGURATIONS #
|
| 381 |
+
# ############## #
|
| 382 |
+
st.sidebar.markdown("4. Choose configuration:")
|
| 383 |
+
# https://docs.streamlit.io/library/api-reference/widgets/st.number_input
|
| 384 |
+
max_sources = st.sidebar.number_input(
|
| 385 |
+
label="Top-k: Number of chunks/sections (1-5)",
|
| 386 |
+
step=1,
|
| 387 |
+
format="%d",
|
| 388 |
+
value=5
|
| 389 |
+
)
|
| 390 |
+
app.params['max_sources'] = max_sources
|
| 391 |
+
temperature = st.sidebar.number_input(
|
| 392 |
+
label="Temperature (0.0 – 1.0)",
|
| 393 |
+
step=0.1,
|
| 394 |
+
format="%f",
|
| 395 |
+
value=0.0,
|
| 396 |
+
min_value=0.0,
|
| 397 |
+
max_value=1.0
|
| 398 |
+
)
|
| 399 |
+
app.params['temperature'] = round(temperature, 1)
|
| 400 |
+
|
| 401 |
+
# ############## #
|
| 402 |
+
# UPLOAD SOURCES #
|
| 403 |
+
# ############## #
|
| 404 |
+
app.params['uploaded_files_rows'] = []
|
| 405 |
+
if app.params['source_id'] == 1:
|
| 406 |
+
# https://docs.streamlit.io/library/api-reference/widgets/st.file_uploader
|
| 407 |
+
# https://towardsdatascience.com/make-dataframes-interactive-in-streamlit-c3d0c4f84ccb
|
| 408 |
+
st.sidebar.markdown("""5. Upload your local documents and modify citation strings (optional)""")
|
| 409 |
+
uploaded_files = st.sidebar.file_uploader(
|
| 410 |
+
"Choose files",
|
| 411 |
+
accept_multiple_files=True,
|
| 412 |
+
type=['pdf', 'PDF',
|
| 413 |
+
'txt', 'TXT',
|
| 414 |
+
'html',
|
| 415 |
+
'docx', 'DOCX',
|
| 416 |
+
'pptx', 'PPTX',
|
| 417 |
+
],
|
| 418 |
+
)
|
| 419 |
+
uploaded_files_dataset = request_pathname(uploaded_files)
|
| 420 |
+
uploaded_files_df = pd.DataFrame(
|
| 421 |
+
uploaded_files_dataset,
|
| 422 |
+
columns=['filepath', 'citation string'])
|
| 423 |
+
uploaded_files_grid_options_builder = GridOptionsBuilder.from_dataframe(uploaded_files_df)
|
| 424 |
+
uploaded_files_grid_options_builder.configure_selection(
|
| 425 |
+
selection_mode='multiple',
|
| 426 |
+
pre_selected_rows=list(range(uploaded_files_df.shape[0])) if uploaded_files_df.iloc[-1, 0] != "" else [],
|
| 427 |
+
use_checkbox=True,
|
| 428 |
+
)
|
| 429 |
+
uploaded_files_grid_options_builder.configure_column("citation string", editable=True)
|
| 430 |
+
uploaded_files_grid_options_builder.configure_auto_height()
|
| 431 |
+
uploaded_files_grid_options = uploaded_files_grid_options_builder.build()
|
| 432 |
+
with st.sidebar:
|
| 433 |
+
uploaded_files_ag_grid = AgGrid(
|
| 434 |
+
uploaded_files_df,
|
| 435 |
+
gridOptions=uploaded_files_grid_options,
|
| 436 |
+
update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
|
| 437 |
+
)
|
| 438 |
+
app.params['uploaded_files_rows'] = uploaded_files_ag_grid["selected_rows"]
|
| 439 |
+
|
| 440 |
+
app.params['urls_df'] = pd.DataFrame()
|
| 441 |
+
if app.params['source_id'] == 3:
|
| 442 |
+
st.sidebar.markdown("""5. Write some urls and modify citation strings if you want (to look prettier)""")
|
| 443 |
+
# option 1: with streamlit version 1.20.0+
|
| 444 |
+
# app.params['urls_df'] = st.sidebar.experimental_data_editor(
|
| 445 |
+
# pd.DataFrame([["", ""]], columns=['urls', 'citation string']),
|
| 446 |
+
# use_container_width=True,
|
| 447 |
+
# num_rows="dynamic",
|
| 448 |
+
# )
|
| 449 |
+
|
| 450 |
+
# option 2: with streamlit version 1.19.0
|
| 451 |
+
urls_dataset = [["", ""],
|
| 452 |
+
["", ""],
|
| 453 |
+
["", ""],
|
| 454 |
+
["", ""],
|
| 455 |
+
["", ""]]
|
| 456 |
+
urls_df = pd.DataFrame(
|
| 457 |
+
urls_dataset,
|
| 458 |
+
columns=['urls', 'citation string'])
|
| 459 |
+
|
| 460 |
+
urls_grid_options_builder = GridOptionsBuilder.from_dataframe(urls_df)
|
| 461 |
+
urls_grid_options_builder.configure_columns(['urls', 'citation string'], editable=True)
|
| 462 |
+
urls_grid_options_builder.configure_auto_height()
|
| 463 |
+
urls_grid_options = urls_grid_options_builder.build()
|
| 464 |
+
with st.sidebar:
|
| 465 |
+
urls_ag_grid = AgGrid(
|
| 466 |
+
urls_df,
|
| 467 |
+
gridOptions=urls_grid_options,
|
| 468 |
+
update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
|
| 469 |
+
)
|
| 470 |
+
df = urls_ag_grid.data
|
| 471 |
+
df = df[df.urls != ""]
|
| 472 |
+
app.params['urls_df'] = df
|
| 473 |
+
|
| 474 |
+
if app.params['source_id'] in (1, 2, 3):
|
| 475 |
+
st.sidebar.markdown("""6. Build an index where you can ask""")
|
| 476 |
+
api_keys_ready = check_api_keys()
|
| 477 |
+
source_ready = check_sources()
|
| 478 |
+
enable_index_button = api_keys_ready and source_ready
|
| 479 |
+
if st.sidebar.button("Build index", disabled=not enable_index_button):
|
| 480 |
+
collect_dataset_and_built_index()
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def main():
    configure_streamlit_and_page()
    load_sidebar_page()
    load_main_page()


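# Callback for the "You:" text input: fetch an answer with get_answer(), append the
# question, answer, context, chunks and cost to the chat history lists kept in
# st.session_state, then clear the input box.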
def on_enter():
    output = get_answer()
    if output:
        st.session_state.past.append(st.session_state['user_input'])
        st.session_state.generated.append(output.answer)
        st.session_state.contexts.append(output.context)
        st.session_state.chunks.append(output.chunks)
        st.session_state.costs.append(output.cost_str)
        st.session_state['user_input'] = ""


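# Persist each uploaded file under TEMP_DIR and return [filepath, citation string] rows
# that pre-fill the editable file table shown in the sidebar.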
def request_pathname(files):
    if not files:
        return [["", ""]]

    # check if the temporary directory exists; if not, create it
    if not Path.exists(TEMP_DIR):
        TEMP_DIR.mkdir(
            parents=True,
            exist_ok=True,
        )

    file_paths = []
    for file in files:
        # # absolute path
        # file_path = str(TEMP_DIR / file.name)
        # relative path
        file_path = str((TEMP_DIR / file.name).relative_to(ROOT_DIR))
        file_paths.append(file_path)
        with open(file_path, "wb") as f:
            f.write(file.getbuffer())
    return [[filepath, filename.name] for filepath, filename in zip(file_paths, files)]


def validate_status():
    source_point_ready = check_source_point()
    combination_point_ready = check_combination_point()
    index_point_ready = check_index_point()
    params_point_ready = check_params_point()
    sources_ready = check_sources()
    index_ready = check_index()

    if source_point_ready and combination_point_ready and index_point_ready and params_point_ready and sources_ready and index_ready:
        app.params['status'] = "✨Ready✨"
    elif not source_point_ready:
        app.params['status'] = "⚠️Review step 1 on the sidebar."
    elif not combination_point_ready:
        app.params['status'] = "⚠️Review step 2 on the sidebar. API Keys or endpoint, ..."
    elif not index_point_ready:
        app.params['status'] = "⚠️Review step 3 on the sidebar. Index API Key or environment."
    elif not params_point_ready:
        app.params['status'] = "⚠️Review step 4 on the sidebar"
    elif not sources_ready:
        app.params['status'] = "⚠️Review step 5 on the sidebar. Waiting for some source..."
    elif not index_ready:
        app.params['status'] = "⚠️Review step 6 on the sidebar. Waiting for the Build index button to be pressed ..."
    else:
        app.params['status'] = "⚠️Something is not ready..."


class StreamlitLangchainChatApp():
    def __init__(self) -> None:
        """Use __init__ to define instance variables. It cannot have any arguments."""
        self.params = dict()

    def run(self, **state) -> None:
        """Define here all logic required by your application."""
        main()


if __name__ == "__main__":
    app = StreamlitLangchainChatApp()
    app.run()
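For readers unfamiliar with the callback wiring used by the chat box above, here is a minimal, self-contained sketch of the same Streamlit on_change pattern; the widget key "demo_input" and the handle_enter helper are illustrative names, not part of this commit:

import streamlit as st

# Minimal illustration of the on_change callback pattern (hypothetical demo, not part of the commit).
if "history" not in st.session_state:
    st.session_state.history = []

def handle_enter():
    # Read the widget value from session_state, store it, then clear the box.
    text = st.session_state["demo_input"]
    if text:
        st.session_state.history.append(text)
    st.session_state["demo_input"] = ""

st.text_input("You:", key="demo_input", on_change=handle_enter)
for msg in st.session_state.history:
    st.write(msg)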
streamlit_langchain_chat/utils.py
ADDED
@@ -0,0 +1,52 @@
import math
import string


def maybe_is_text(s, thresh=2.5):
    if len(s) == 0:
        return False
    # Calculate the entropy of the string
    entropy = 0
    for c in string.printable:
        p = s.count(c) / len(s)
        if p > 0:
            entropy += -p * math.log2(p)

    # Check if the entropy is within a reasonable range for text
    if entropy > thresh:
        return True
    return False


def maybe_is_code(s):
    if len(s) == 0:
        return False
    # Check if the string contains a lot of non-ascii characters
    if len([c for c in s if ord(c) > 128]) / len(s) > 0.1:
        return True
    return False


def strings_similarity(s1, s2):
    if len(s1) == 0 or len(s2) == 0:
        return 0
    # break the strings into words
    s1 = set(s1.split())
    s2 = set(s2.split())
    # return the similarity ratio
    return len(s1.intersection(s2)) / len(s1.union(s2))


def maybe_is_truncated(s):
    punct = [".", "!", "?", '"']
    if s[-1] in punct:
        return False
    return True


def maybe_is_html(s):
    if len(s) == 0:
        return False
    # check for html tags
    if "<body" in s or "<html" in s or "<div" in s:
        return True
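A quick, hypothetical usage sketch of the helpers above (the expected values follow directly from the definitions; the snippet itself is not part of the commit):

# Hypothetical quick check of the helpers above (not part of the commit).
from streamlit_langchain_chat.utils import maybe_is_text, strings_similarity, maybe_is_truncated

print(maybe_is_text("This is a perfectly ordinary English sentence."))  # True: character entropy is well above the 2.5 threshold
print(strings_similarity("the quick brown fox", "the quick red fox"))   # 0.6: three shared words out of five distinct ones
print(maybe_is_truncated("The answer was cut off mid-"))                # True: last character is not ., !, ? or a double quote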