hlydecker committed
Commit 1ce95c4 · 0 Parent(s):

Duplicate from hlydecker/Augmented-Retrieval-qa-ChatGPT

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ venv*/
+ tempDir/
+ .idea/
+ *.env
+ *.pkl
+ *.pickle
+ *testing*.py
README.md ADDED
@@ -0,0 +1,15 @@
+ ---
+ title: Augmented Retrieval Qa ChatGPT
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: blue
+ sdk: streamlit
+ sdk_version: 1.19.0
+ app_file: streamlit_langchain_chat/streamlit_app.py
+ pinned: false
+ python_version: 3.10.4
+ license: cc-by-nc-sa-4.0
+ duplicated_from: hlydecker/Augmented-Retrieval-qa-ChatGPT
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
File without changes
requirements.txt ADDED
Binary file (5.03 kB)
 
static/__init__.py ADDED
File without changes
static/mini_nttdata.jpg ADDED
streamlit_langchain_chat/__init__.py ADDED
@@ -0,0 +1 @@
+
streamlit_langchain_chat/__version__.py ADDED
@@ -0,0 +1 @@
+ __VERSION__ = "1.0.4"
streamlit_langchain_chat/constants.py ADDED
@@ -0,0 +1,53 @@
+ from pathlib import Path
+ from PIL import Image
+
+ # from dotenv import load_dotenv, find_dotenv  # pip install python-dotenv==1.0.0
+
+ from __version__ import __VERSION__ as APP_VERSION
+
+ _SCRIPT_PATH = Path(__file__).absolute()
+ PARENT_APP_DIR = _SCRIPT_PATH.parent
+ TEMP_DIR = PARENT_APP_DIR / 'tempDir'
+ ROOT_DIR = PARENT_APP_DIR.parent
+ STATIC_DIR = ROOT_DIR / 'static'
+
+ # _env_file_path = find_dotenv(str(CODE_DIR / '.env'))  # Check if this path is correct
+ # if _env_file_path:
+ #     load_dotenv(_env_file_path)
+
+ ST_CONFIG = {
+     "page_title": "NTT Data - Chat Q&A",
+     # "page_icon": Image.open(STATIC_DIR / "mini_nttdata.jpg"),
+ }
+
+ OPERATING_MODE = "debug"  # debug, preproduction, production
+
+ REUSE_ANSWERS = False
+
+ LOAD_INDEX_LOCALLY = False
+ SAVE_INDEX_LOCALLY = False
+
+ # price in $ per 1,000 tokens
+ PRICES = {
+     'text-embedding-ada-002': 0.0004,
+     'text-davinci-003': 0.02,
+     'gpt-3': 0.002,
+     'gpt-4': 0.06,  # 8K context
+ }
+
+ SOURCES_IDS = {
+     # "Without source. Only chat": 4,
+     "local files": 1,
+     "urls": 3
+ }
+
+ TYPE_IDS = {
+     "MSF Azure OpenAI Service": 1,
+     "OpenAI": 2,
+ }
+
+
+ INDEX_IDS = {
+     "FAISS": 1,
+     "Pinecone": 2,
+ }
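PRICES maps a model name to its price in dollars per 1,000 tokens, so a call's cost is token count / 1000 × price. A minimal sketch of how the table is meant to be read; the helper name is hypothetical and not part of this commit:

# Hypothetical helper (not in the commit) showing how PRICES is interpreted:
# dollars per 1,000 tokens, scaled by the actual token count.
def estimate_cost(model_name: str, total_tokens: int) -> float:
    price_per_1k = PRICES.get(model_name, 0.0)
    return total_tokens / 1000 * price_per_1k

# estimate_cost('gpt-4', 1500) -> 0.09 with the values above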
streamlit_langchain_chat/customized_langchain/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore
+ from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS
+ from streamlit_langchain_chat.customized_langchain.vectorstores import Pinecone
+
+
+ __all__ = [
+     "FAISS",
+     "InMemoryDocstore",
+     "Pinecone",
+ ]
streamlit_langchain_chat/customized_langchain/docstore/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """Wrappers on top of docstores."""
+ from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore
+
+
+ __all__ = [
+     "InMemoryDocstore",
+ ]
streamlit_langchain_chat/customized_langchain/docstore/in_memory.py ADDED
@@ -0,0 +1,27 @@
+ """Simple in memory docstore in the form of a dict."""
+ from typing import Dict, Union
+
+ from langchain.docstore.base import AddableMixin, Docstore
+ from langchain.docstore.document import Document
+
+
+ class InMemoryDocstore(Docstore, AddableMixin):
+     """Simple in memory docstore in the form of a dict."""
+
+     def __init__(self, dict_: Dict[str, Document]):
+         """Initialize with dict."""
+         self.dict_ = dict_
+
+     def add(self, texts: Dict[str, Document]) -> None:
+         """Add texts to in memory dictionary."""
+         overlapping = set(texts).intersection(self.dict_)
+         if overlapping:
+             raise ValueError(f"Tried to add ids that already exist: {overlapping}")
+         self.dict_ = dict(self.dict_, **texts)
+
+     def search(self, search: str) -> Union[str, Document]:
+         """Search via direct lookup."""
+         if search not in self.dict_:
+             return f"ID {search} not found."
+         else:
+             return self.dict_[search]
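InMemoryDocstore is a plain dict keyed by document ID: add() refuses IDs that already exist, and search() is a direct lookup that returns a not-found message instead of raising. A minimal usage sketch (illustrative IDs, not part of the commit):

from langchain.docstore.document import Document

store = InMemoryDocstore({"doc-1": Document(page_content="first chunk")})
store.add({"doc-2": Document(page_content="second chunk")})
print(store.search("doc-2").page_content)  # "second chunk"
print(store.search("missing"))             # "ID missing not found."
# store.add({"doc-1": Document(page_content="dup")}) would raise ValueError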
streamlit_langchain_chat/customized_langchain/indexes/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from streamlit_langchain_chat.customized_langchain.indexes.graph import GraphIndexCreator
+ # from streamlit_langchain_chat.customized_langchain.vectorstore import VectorstoreIndexCreator
+
+ __all__ = [
+     "GraphIndexCreator",
+     # "VectorstoreIndexCreator"
+ ]
streamlit_langchain_chat/customized_langchain/indexes/graph.py ADDED
@@ -0,0 +1,20 @@
+ from typing import List
+
+ from langchain.indexes.graph import *
+ from langchain.indexes.graph import GraphIndexCreator as OriginalGraphIndexCreator
+
+
+ class GraphIndexCreator(OriginalGraphIndexCreator):
+     def from_texts(self, texts: List[str]) -> NetworkxEntityGraph:
+         """Create graph index from text."""
+         if self.llm is None:
+             raise ValueError("llm should not be None")
+         graph = self.graph_type()
+         chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
+
+         for text in texts:
+             output = chain.predict(text=text)
+             knowledge = parse_triples(output)
+             for triple in knowledge:
+                 graph.add_triple(triple)
+         return graph
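The override changes from_texts to accept a list of texts instead of a single string, extracting knowledge triples from each text and accumulating them in one NetworkxEntityGraph. A sketch of how it could be called, assuming an OpenAI key is configured (not part of the commit; the example sentences are made up):

from langchain.llms import OpenAI

creator = GraphIndexCreator(llm=OpenAI(temperature=0))  # requires OPENAI_API_KEY
graph = creator.from_texts([
    "Madrid is the capital of Spain.",
    "Spain is a country in Europe.",
])
print(graph.get_triples())  # list of extracted triples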
streamlit_langchain_chat/customized_langchain/llms/__init__.py ADDED
@@ -0,0 +1 @@
+ from streamlit_langchain_chat.customized_langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat, AzureOpenAIChat
streamlit_langchain_chat/customized_langchain/llms/openai.py ADDED
@@ -0,0 +1,708 @@
1
+ """Wrapper around OpenAI APIs."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ import sys
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Dict,
10
+ Generator,
11
+ List,
12
+ Mapping,
13
+ Optional,
14
+ Set,
15
+ Tuple,
16
+ Union,
17
+ )
18
+
19
+ from pydantic import BaseModel, Extra, Field, root_validator
20
+ from tenacity import (
21
+ before_sleep_log,
22
+ retry,
23
+ retry_if_exception_type,
24
+ stop_after_attempt,
25
+ wait_exponential,
26
+ )
27
+
28
+ from langchain.llms.base import BaseLLM
29
+ from langchain.schema import Generation, LLMResult
30
+ from langchain.utils import get_from_dict_or_env
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def update_token_usage(
36
+ keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
37
+ ) -> None:
38
+ """Update token usage."""
39
+ _keys_to_use = keys.intersection(response["usage"])
40
+ for _key in _keys_to_use:
41
+ if _key not in token_usage:
42
+ token_usage[_key] = response["usage"][_key]
43
+ else:
44
+ token_usage[_key] += response["usage"][_key]
45
+
46
+
47
+ def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
48
+ """Update response from the stream response."""
49
+ response["choices"][0]["text"] += stream_response["choices"][0]["text"]
50
+ response["choices"][0]["finish_reason"] = stream_response["choices"][0][
51
+ "finish_reason"
52
+ ]
53
+ response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"]
54
+
55
+
56
+ def _streaming_response_template() -> Dict[str, Any]:
57
+ return {
58
+ "choices": [
59
+ {
60
+ "text": "",
61
+ "finish_reason": None,
62
+ "logprobs": None,
63
+ }
64
+ ]
65
+ }
66
+
67
+
68
+ def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
69
+ import openai
70
+
71
+ min_seconds = 4
72
+ max_seconds = 10
73
+ # Wait 2^x * 1 second between each retry starting with
74
+ # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
75
+ return retry(
76
+ reraise=True,
77
+ stop=stop_after_attempt(llm.max_retries),
78
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
79
+ retry=(
80
+ retry_if_exception_type(openai.error.Timeout)
81
+ | retry_if_exception_type(openai.error.APIError)
82
+ | retry_if_exception_type(openai.error.APIConnectionError)
83
+ | retry_if_exception_type(openai.error.RateLimitError)
84
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
85
+ ),
86
+ before_sleep=before_sleep_log(logger, logging.WARNING),
87
+ )
88
+
89
+
90
+ def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
91
+ """Use tenacity to retry the completion call."""
92
+ retry_decorator = _create_retry_decorator(llm)
93
+
94
+ @retry_decorator
95
+ def _completion_with_retry(**kwargs: Any) -> Any:
96
+ return llm.client.create(**kwargs)
97
+
98
+ return _completion_with_retry(**kwargs)
99
+
100
+
101
+ async def acompletion_with_retry(
102
+ llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
103
+ ) -> Any:
104
+ """Use tenacity to retry the async completion call."""
105
+ retry_decorator = _create_retry_decorator(llm)
106
+
107
+ @retry_decorator
108
+ async def _completion_with_retry(**kwargs: Any) -> Any:
109
+ # Use OpenAI's async api https://github.com/openai/openai-python#async-api
110
+ return await llm.client.acreate(**kwargs)
111
+
112
+ return await _completion_with_retry(**kwargs)
113
+
114
+
115
+ class BaseOpenAI(BaseLLM, BaseModel):
116
+ """Wrapper around OpenAI large language models.
117
+
118
+ To use, you should have the ``openai`` python package installed, and the
119
+ environment variable ``OPENAI_API_KEY`` set with your API key.
120
+
121
+ Any parameters that are valid to be passed to the openai.create call can be passed
122
+ in, even if not explicitly saved on this class.
123
+
124
+ Example:
125
+ .. code-block:: python
126
+
127
+ from langchain.llms import OpenAI
128
+ openai = OpenAI(model_name="text-davinci-003")
129
+ """
130
+
131
+ client: Any #: :meta private:
132
+ model_name: str = "text-davinci-003"
133
+ """Model name to use."""
134
+ temperature: float = 0.7
135
+ """What sampling temperature to use."""
136
+ max_tokens: int = 256
137
+ """The maximum number of tokens to generate in the completion.
138
+ -1 returns as many tokens as possible given the prompt and
139
+ the models maximal context size."""
140
+ top_p: float = 1
141
+ """Total probability mass of tokens to consider at each step."""
142
+ frequency_penalty: float = 0
143
+ """Penalizes repeated tokens according to frequency."""
144
+ presence_penalty: float = 0
145
+ """Penalizes repeated tokens."""
146
+ n: int = 1
147
+ """How many completions to generate for each prompt."""
148
+ best_of: int = 1
149
+ """Generates best_of completions server-side and returns the "best"."""
150
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
151
+ """Holds any model parameters valid for `create` call not explicitly specified."""
152
+ openai_api_key: Optional[str] = None
153
+ batch_size: int = 20
154
+ """Batch size to use when passing multiple documents to generate."""
155
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None
156
+ """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
157
+ logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
158
+ """Adjust the probability of specific tokens being generated."""
159
+ max_retries: int = 6
160
+ """Maximum number of retries to make when generating."""
161
+ streaming: bool = False
162
+ """Whether to stream the results or not."""
163
+
164
+ def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
165
+ """Initialize the OpenAI object."""
166
+ if data.get("model_name", "").startswith("gpt-3.5-turbo"):
167
+ return OpenAIChat(**data)
168
+ return super().__new__(cls)
169
+
170
+ class Config:
171
+ """Configuration for this pydantic object."""
172
+
173
+ extra = Extra.ignore
174
+
175
+ @root_validator(pre=True, allow_reuse=True)
176
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
177
+ """Build extra kwargs from additional params that were passed in."""
178
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
179
+
180
+ extra = values.get("model_kwargs", {})
181
+ for field_name in list(values):
182
+ if field_name not in all_required_field_names:
183
+ if field_name in extra:
184
+ raise ValueError(f"Found {field_name} supplied twice.")
185
+ logger.warning(
186
+ f"""WARNING! {field_name} is not default parameter.
187
+ {field_name} was transferred to model_kwargs.
188
+ Please confirm that {field_name} is what you intended."""
189
+ )
190
+ extra[field_name] = values.pop(field_name)
191
+ values["model_kwargs"] = extra
192
+ return values
193
+
194
+ @root_validator(allow_reuse=True)
195
+ def validate_environment(cls, values: Dict) -> Dict:
196
+ """Validate that api key and python package exists in environment."""
197
+ openai_api_key = get_from_dict_or_env(
198
+ values, "openai_api_key", "OPENAI_API_KEY"
199
+ )
200
+ try:
201
+ import openai
202
+
203
+ openai.api_key = openai_api_key
204
+ values["client"] = openai.Completion
205
+ except ImportError:
206
+ raise ValueError(
207
+ "Could not import openai python package. "
208
+ "Please it install it with `pip install openai`."
209
+ )
210
+ if values["streaming"] and values["n"] > 1:
211
+ raise ValueError("Cannot stream results when n > 1.")
212
+ if values["streaming"] and values.get("best_of") and values["best_of"] > 1:
213
+ raise ValueError("Cannot stream results when best_of > 1.")
214
+ return values
215
+
216
+ @property
217
+ def _default_params(self) -> Dict[str, Any]:
218
+ """Get the default parameters for calling OpenAI API."""
219
+ normal_params = {
220
+ "temperature": self.temperature,
221
+ "max_tokens": self.max_tokens,
222
+ "top_p": self.top_p,
223
+ "frequency_penalty": self.frequency_penalty,
224
+ "presence_penalty": self.presence_penalty,
225
+ "n": self.n,
226
+ # "best_of": self.best_of,
227
+ "request_timeout": self.request_timeout,
228
+ "logit_bias": self.logit_bias,
229
+ }
230
+ return {**normal_params, **self.model_kwargs}
231
+
232
+ def _generate(
233
+ self, prompts: List[str], stop: Optional[List[str]] = None
234
+ ) -> LLMResult:
235
+ """Call out to OpenAI's endpoint with k unique prompts.
236
+
237
+ Args:
238
+ prompts: The prompts to pass into the model.
239
+ stop: Optional list of stop words to use when generating.
240
+
241
+ Returns:
242
+ The full LLM output.
243
+
244
+ Example:
245
+ .. code-block:: python
246
+
247
+ response = openai.generate(["Tell me a joke."])
248
+ """
249
+ # TODO: write a unit test for this
250
+ params = self._invocation_params
251
+ sub_prompts = self.get_sub_prompts(params, prompts, stop)
252
+ choices = []
253
+ token_usage: Dict[str, int] = {}
254
+ # Get the token usage from the response.
255
+ # Includes prompt, completion, and total tokens used.
256
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
257
+ for _prompts in sub_prompts:
258
+ if self.streaming:
259
+ if len(_prompts) > 1:
260
+ raise ValueError("Cannot stream results with multiple prompts.")
261
+ params["stream"] = True
262
+ response = _streaming_response_template()
263
+ for stream_resp in completion_with_retry(
264
+ self, prompt=_prompts, **params
265
+ ):
266
+ self.callback_manager.on_llm_new_token(
267
+ stream_resp["choices"][0]["text"],
268
+ verbose=self.verbose,
269
+ logprobs=stream_resp["choices"][0]["logprobs"],
270
+ )
271
+ _update_response(response, stream_resp)
272
+ choices.extend(response["choices"])
273
+ else:
274
+ response = completion_with_retry(self, prompt=_prompts, **params)
275
+ choices.extend(response["choices"])
276
+ if not self.streaming:
277
+ # Can't update token usage if streaming
278
+ update_token_usage(_keys, response, token_usage)
279
+ return self.create_llm_result(choices, prompts, token_usage)
280
+
281
+ async def _agenerate(
282
+ self, prompts: List[str], stop: Optional[List[str]] = None
283
+ ) -> LLMResult:
284
+ """Call out to OpenAI's endpoint async with k unique prompts."""
285
+ params = self._invocation_params
286
+ sub_prompts = self.get_sub_prompts(params, prompts, stop)
287
+ choices = []
288
+ token_usage: Dict[str, int] = {}
289
+ # Get the token usage from the response.
290
+ # Includes prompt, completion, and total tokens used.
291
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
292
+ for _prompts in sub_prompts:
293
+ if self.streaming:
294
+ if len(_prompts) > 1:
295
+ raise ValueError("Cannot stream results with multiple prompts.")
296
+ params["stream"] = True
297
+ response = _streaming_response_template()
298
+ async for stream_resp in await acompletion_with_retry(
299
+ self, prompt=_prompts, **params
300
+ ):
301
+ if self.callback_manager.is_async:
302
+ await self.callback_manager.on_llm_new_token(
303
+ stream_resp["choices"][0]["text"],
304
+ verbose=self.verbose,
305
+ logprobs=stream_resp["choices"][0]["logprobs"],
306
+ )
307
+ else:
308
+ self.callback_manager.on_llm_new_token(
309
+ stream_resp["choices"][0]["text"],
310
+ verbose=self.verbose,
311
+ logprobs=stream_resp["choices"][0]["logprobs"],
312
+ )
313
+ _update_response(response, stream_resp)
314
+ choices.extend(response["choices"])
315
+ else:
316
+ response = await acompletion_with_retry(self, prompt=_prompts, **params)
317
+ choices.extend(response["choices"])
318
+ if not self.streaming:
319
+ # Can't update token usage if streaming
320
+ update_token_usage(_keys, response, token_usage)
321
+ return self.create_llm_result(choices, prompts, token_usage)
322
+
323
+ def get_sub_prompts(
324
+ self,
325
+ params: Dict[str, Any],
326
+ prompts: List[str],
327
+ stop: Optional[List[str]] = None,
328
+ ) -> List[List[str]]:
329
+ """Get the sub prompts for llm call."""
330
+ if stop is not None:
331
+ if "stop" in params:
332
+ raise ValueError("`stop` found in both the input and default params.")
333
+ params["stop"] = stop
334
+ if params["max_tokens"] == -1:
335
+ if len(prompts) != 1:
336
+ raise ValueError(
337
+ "max_tokens set to -1 not supported for multiple inputs."
338
+ )
339
+ params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
340
+ sub_prompts = [
341
+ prompts[i : i + self.batch_size]
342
+ for i in range(0, len(prompts), self.batch_size)
343
+ ]
344
+ return sub_prompts
345
+
346
+ def create_llm_result(
347
+ self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
348
+ ) -> LLMResult:
349
+ """Create the LLMResult from the choices and prompts."""
350
+ generations = []
351
+ for i, _ in enumerate(prompts):
352
+ sub_choices = choices[i * self.n : (i + 1) * self.n]
353
+ generations.append(
354
+ [
355
+ Generation(
356
+ text=choice["text"],
357
+ generation_info=dict(
358
+ finish_reason=choice.get("finish_reason"),
359
+ logprobs=choice.get("logprobs"),
360
+ ),
361
+ )
362
+ for choice in sub_choices
363
+ ]
364
+ )
365
+ return LLMResult(
366
+ generations=generations, llm_output={"token_usage": token_usage}
367
+ )
368
+
369
+ def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
370
+ """Call OpenAI with streaming flag and return the resulting generator.
371
+
372
+ BETA: this is a beta feature while we figure out the right abstraction.
373
+ Once that happens, this interface could change.
374
+
375
+ Args:
376
+ prompt: The prompts to pass into the model.
377
+ stop: Optional list of stop words to use when generating.
378
+
379
+ Returns:
380
+ A generator representing the stream of tokens from OpenAI.
381
+
382
+ Example:
383
+ .. code-block:: python
384
+
385
+ generator = openai.stream("Tell me a joke.")
386
+ for token in generator:
387
+ yield token
388
+ """
389
+ params = self.prep_streaming_params(stop)
390
+ generator = self.client.create(prompt=prompt, **params)
391
+
392
+ return generator
393
+
394
+ def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
395
+ """Prepare the params for streaming."""
396
+ params = self._invocation_params
397
+ if params.get('best_of') and params["best_of"] != 1:
398
+ raise ValueError("OpenAI only supports best_of == 1 for streaming")
399
+ if stop is not None:
400
+ if "stop" in params:
401
+ raise ValueError("`stop` found in both the input and default params.")
402
+ params["stop"] = stop
403
+ params["stream"] = True
404
+ return params
405
+
406
+ @property
407
+ def _invocation_params(self) -> Dict[str, Any]:
408
+ """Get the parameters used to invoke the model."""
409
+ return self._default_params
410
+
411
+ @property
412
+ def _identifying_params(self) -> Mapping[str, Any]:
413
+ """Get the identifying parameters."""
414
+ return {**{"model_name": self.model_name}, **self._default_params}
415
+
416
+ @property
417
+ def _llm_type(self) -> str:
418
+ """Return type of llm."""
419
+ return "openai"
420
+
421
+ def get_num_tokens(self, text: str) -> int:
422
+ """Calculate num tokens with tiktoken package."""
423
+ # tiktoken NOT supported for Python 3.8 or below
424
+ if sys.version_info[1] <= 8:
425
+ return super().get_num_tokens(text)
426
+ try:
427
+ import tiktoken
428
+ except ImportError:
429
+ raise ValueError(
430
+ "Could not import tiktoken python package. "
431
+ "This is needed in order to calculate get_num_tokens. "
432
+ "Please it install it with `pip install tiktoken`."
433
+ )
434
+ encoder = "gpt2"
435
+ if self.model_name in ("text-davinci-003", "text-davinci-002"):
436
+ encoder = "p50k_base"
437
+ if self.model_name.startswith("code"):
438
+ encoder = "p50k_base"
439
+ # create a GPT-3 encoder instance
440
+ enc = tiktoken.get_encoding(encoder)
441
+
442
+ # encode the text using the GPT-3 encoder
443
+ tokenized_text = enc.encode(text)
444
+
445
+ # calculate the number of tokens in the encoded text
446
+ return len(tokenized_text)
447
+
448
+ def modelname_to_contextsize(self, modelname: str) -> int:
449
+ """Calculate the maximum number of tokens possible to generate for a model.
450
+
451
+ text-davinci-003: 4,097 tokens
452
+ text-curie-001: 2,048 tokens
453
+ text-babbage-001: 2,048 tokens
454
+ text-ada-001: 2,048 tokens
455
+ code-davinci-002: 8,000 tokens
456
+ code-cushman-001: 2,048 tokens
457
+
458
+ Args:
459
+ modelname: The modelname we want to know the context size for.
460
+
461
+ Returns:
462
+ The maximum context size
463
+
464
+ Example:
465
+ .. code-block:: python
466
+
467
+ max_tokens = openai.modelname_to_contextsize("text-davinci-003")
468
+ """
469
+ if modelname == "text-davinci-003":
470
+ return 4097
471
+ elif modelname == "text-curie-001":
472
+ return 2048
473
+ elif modelname == "text-babbage-001":
474
+ return 2048
475
+ elif modelname == "text-ada-001":
476
+ return 2048
477
+ elif modelname == "code-davinci-002":
478
+ return 8000
479
+ elif modelname == "code-cushman-001":
480
+ return 2048
481
+ else:
482
+ return 4097
483
+
484
+ def max_tokens_for_prompt(self, prompt: str) -> int:
485
+ """Calculate the maximum number of tokens possible to generate for a prompt.
486
+
487
+ Args:
488
+ prompt: The prompt to pass into the model.
489
+
490
+ Returns:
491
+ The maximum number of tokens to generate for a prompt.
492
+
493
+ Example:
494
+ .. code-block:: python
495
+
496
+ max_tokens = openai.max_token_for_prompt("Tell me a joke.")
497
+ """
498
+ num_tokens = self.get_num_tokens(prompt)
499
+
500
+ # get max context size for model by name
501
+ max_size = self.modelname_to_contextsize(self.model_name)
502
+ return max_size - num_tokens
503
+
504
+
505
+ class OpenAI(BaseOpenAI):
506
+ """Generic OpenAI class that uses model name."""
507
+
508
+ @property
509
+ def _invocation_params(self) -> Dict[str, Any]:
510
+ return {**{"model": self.model_name}, **super()._invocation_params}
511
+
512
+
513
+ class AzureOpenAI(BaseOpenAI):
514
+ """Azure specific OpenAI class that uses deployment name."""
515
+
516
+ deployment_name: str = ""
517
+ """Deployment name to use."""
518
+
519
+ @property
520
+ def _identifying_params(self) -> Mapping[str, Any]:
521
+ return {
522
+ **{"deployment_name": self.deployment_name},
523
+ **super()._identifying_params,
524
+ }
525
+
526
+ @property
527
+ def _invocation_params(self) -> Dict[str, Any]:
528
+ return {**{"engine": self.deployment_name}, **super()._invocation_params}
529
+
530
+
531
+ class OpenAIChat(BaseLLM, BaseModel):
532
+ """Wrapper around OpenAI Chat large language models.
533
+
534
+ To use, you should have the ``openai`` python package installed, and the
535
+ environment variable ``OPENAI_API_KEY`` set with your API key.
536
+
537
+ Any parameters that are valid to be passed to the openai.create call can be passed
538
+ in, even if not explicitly saved on this class.
539
+
540
+ Example:
541
+ .. code-block:: python
542
+
543
+ from langchain.llms import OpenAIChat
544
+ openaichat = OpenAIChat(model_name="gpt-3.5-turbo")
545
+ """
546
+
547
+ client: Any #: :meta private:
548
+ model_name: str = "gpt-3.5-turbo"
549
+ """Model name to use."""
550
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
551
+ """Holds any model parameters valid for `create` call not explicitly specified."""
552
+ openai_api_key: Optional[str] = None
553
+ max_retries: int = 6
554
+ """Maximum number of retries to make when generating."""
555
+ prefix_messages: List = Field(default_factory=list)
556
+ """Series of messages for Chat input."""
557
+ streaming: bool = False
558
+ """Whether to stream the results or not."""
559
+
560
+ class Config:
561
+ """Configuration for this pydantic object."""
562
+
563
+ extra = Extra.ignore
564
+
565
+ @root_validator(pre=True, allow_reuse=True)
566
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
567
+ """Build extra kwargs from additional params that were passed in."""
568
+ all_required_field_names = {field.alias for field in cls.__fields__.values()}
569
+
570
+ extra = values.get("model_kwargs", {})
571
+ for field_name in list(values):
572
+ if field_name not in all_required_field_names:
573
+ if field_name in extra:
574
+ raise ValueError(f"Found {field_name} supplied twice.")
575
+ extra[field_name] = values.pop(field_name)
576
+ values["model_kwargs"] = extra
577
+ return values
578
+
579
+ @root_validator(allow_reuse=True)
580
+ def validate_environment(cls, values: Dict) -> Dict:
581
+ """Validate that api key and python package exists in environment."""
582
+ openai_api_key = get_from_dict_or_env(
583
+ values, "openai_api_key", "OPENAI_API_KEY"
584
+ )
585
+ try:
586
+ import openai
587
+
588
+ openai.api_key = openai_api_key
589
+ except ImportError:
590
+ raise ValueError(
591
+ "Could not import openai python package. "
592
+ "Please it install it with `pip install openai`."
593
+ )
594
+ try:
595
+ values["client"] = openai.ChatCompletion
596
+ except AttributeError:
597
+ raise ValueError(
598
+ "`openai` has no `ChatCompletion` attribute, this is likely "
599
+ "due to an old version of the openai package. Try upgrading it "
600
+ "with `pip install --upgrade openai`."
601
+ )
602
+ return values
603
+
604
+ @property
605
+ def _default_params(self) -> Dict[str, Any]:
606
+ """Get the default parameters for calling OpenAI API."""
607
+ return self.model_kwargs
608
+
609
+ def _get_chat_params(
610
+ self, prompts: List[str], stop: Optional[List[str]] = None
611
+ ) -> Tuple:
612
+ if len(prompts) > 1:
613
+ raise ValueError(
614
+ f"OpenAIChat currently only supports single prompt, got {prompts}"
615
+ )
616
+ messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
617
+ params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
618
+ if stop is not None:
619
+ if "stop" in params:
620
+ raise ValueError("`stop` found in both the input and default params.")
621
+ params["stop"] = stop
622
+ return messages, params
623
+
624
+ def _generate(
625
+ self, prompts: List[str], stop: Optional[List[str]] = None
626
+ ) -> LLMResult:
627
+ messages, params = self._get_chat_params(prompts, stop)
628
+ if self.streaming:
629
+ response = ""
630
+ params["stream"] = True
631
+ for stream_resp in completion_with_retry(self, messages=messages, **params):
632
+ token = stream_resp["choices"][0]["delta"].get("content", "")
633
+ response += token
634
+ self.callback_manager.on_llm_new_token(
635
+ token,
636
+ verbose=self.verbose,
637
+ )
638
+ return LLMResult(
639
+ generations=[[Generation(text=response)]],
640
+ )
641
+ else:
642
+ full_response = completion_with_retry(self, messages=messages, **params)
643
+ return LLMResult(
644
+ generations=[
645
+ [Generation(text=full_response["choices"][0]["message"]["content"])]
646
+ ],
647
+ llm_output={"token_usage": full_response["usage"]},
648
+ )
649
+
650
+ async def _agenerate(
651
+ self, prompts: List[str], stop: Optional[List[str]] = None
652
+ ) -> LLMResult:
653
+ messages, params = self._get_chat_params(prompts, stop)
654
+ if self.streaming:
655
+ response = ""
656
+ params["stream"] = True
657
+ async for stream_resp in await acompletion_with_retry(
658
+ self, messages=messages, **params
659
+ ):
660
+ token = stream_resp["choices"][0]["delta"].get("content", "")
661
+ response += token
662
+ if self.callback_manager.is_async:
663
+ await self.callback_manager.on_llm_new_token(
664
+ token,
665
+ verbose=self.verbose,
666
+ )
667
+ else:
668
+ self.callback_manager.on_llm_new_token(
669
+ token,
670
+ verbose=self.verbose,
671
+ )
672
+ return LLMResult(
673
+ generations=[[Generation(text=response)]],
674
+ )
675
+ else:
676
+ full_response = await acompletion_with_retry(
677
+ self, messages=messages, **params
678
+ )
679
+ return LLMResult(
680
+ generations=[
681
+ [Generation(text=full_response["choices"][0]["message"]["content"])]
682
+ ],
683
+ llm_output={"token_usage": full_response["usage"]},
684
+ )
685
+
686
+ @property
687
+ def _identifying_params(self) -> Mapping[str, Any]:
688
+ """Get the identifying parameters."""
689
+ return {**{"model_name": self.model_name}, **self._default_params}
690
+
691
+ @property
692
+ def _llm_type(self) -> str:
693
+ """Return type of llm."""
694
+ return "openai-chat"
695
+
696
+
697
+ class AzureOpenAIChat(OpenAIChat):
698
+ """Azure specific OpenAI class that uses deployment name."""
699
+
700
+ deployment_name: str = ""
701
+ """Deployment name to use."""
702
+
703
+ @property
704
+ def _identifying_params(self) -> Mapping[str, Any]:
705
+ return {
706
+ **{"deployment_name": self.deployment_name},
707
+ **super()._identifying_params,
708
+ }
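These classes track langchain's pre-1.0 OpenAI wrappers and add Azure variants that carry a deployment_name. A hypothetical sketch of the chat wrapper; the model name and temperature are placeholders, OPENAI_API_KEY is assumed to be set, and the Azure variants would additionally need the deployment/engine wiring expected by the old openai package:

# Hypothetical usage (not part of the commit); assumes OPENAI_API_KEY is set.
chat_llm = OpenAIChat(model_name="gpt-3.5-turbo", model_kwargs={"temperature": 0.1})
result = chat_llm.generate(["Summarise retrieval augmented generation in one sentence."])
print(result.generations[0][0].text)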
streamlit_langchain_chat/customized_langchain/vectorstores/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """Wrappers on top of vector stores."""
+ from streamlit_langchain_chat.customized_langchain.vectorstores.faiss import FAISS
+ from streamlit_langchain_chat.customized_langchain.vectorstores.pinecone import Pinecone
+
+ __all__ = [
+     "FAISS",
+     "Pinecone",
+ ]
streamlit_langchain_chat/customized_langchain/vectorstores/faiss.py ADDED
@@ -0,0 +1,100 @@
+ # import hashlib
+
+ from langchain.vectorstores.faiss import *
+ from langchain.vectorstores.faiss import FAISS as OriginalFAISS
+
+ from streamlit_langchain_chat.customized_langchain.docstore.in_memory import InMemoryDocstore
+
+
+ class FAISS(OriginalFAISS):
+     def __add(
+         self,
+         texts: Iterable[str],
+         embeddings: Iterable[List[float]],
+         metadatas: Optional[List[dict]] = None,
+         **kwargs: Any,
+     ) -> List[str]:
+         if not isinstance(self.docstore, AddableMixin):
+             raise ValueError(
+                 "If trying to add texts, the underlying docstore should support "
+                 f"adding items, which {self.docstore} does not"
+             )
+         documents = []
+         for i, text in enumerate(texts):
+             metadata = metadatas[i] if metadatas else {}
+             documents.append(Document(page_content=text, metadata=metadata))
+         # Add to the index, the index_to_id mapping, and the docstore.
+         starting_len = len(self.index_to_docstore_id)
+         self.index.add(np.array(embeddings, dtype=np.float32))
+         # Get list of index, id, and docs.
+         full_info = [
+             (starting_len + i, str(uuid.uuid4()), doc)
+             for i, doc in enumerate(documents)
+         ]
+         # Add information to docstore and index.
+         self.docstore.add({_id: doc for _, _id, doc in full_info})
+         index_to_id = {index: _id for index, _id, _ in full_info}
+         self.index_to_docstore_id.update(index_to_id)
+         return [_id for _, _id, _ in full_info]
+
+     @classmethod
+     def __from(
+         cls,
+         texts: List[str],
+         embeddings: List[List[float]],
+         embedding: Embeddings,
+         metadatas: Optional[List[dict]] = None,
+         **kwargs: Any,
+     ) -> FAISS:
+         faiss = dependable_faiss_import()
+         index = faiss.IndexFlatL2(len(embeddings[0]))
+         index.add(np.array(embeddings, dtype=np.float32))
+         documents = []
+         for i, text in enumerate(texts):
+             metadata = metadatas[i] if metadatas else {}
+             documents.append(Document(page_content=text, metadata=metadata))
+         index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
+
+         # # TODO: switch to using the hash, and see where it would go so the chunk is not loaded into the dataset
+         # index_to_id_2 = dict()
+         # for i in range(len(documents)):
+         #     h = hashlib.new('sha256')
+         #     text_ = documents[i].page_content
+         #     h.update(text_.encode())
+         #     index_to_id_2[i] = str(h.hexdigest())
+         # #
+         docstore = InMemoryDocstore(
+             {index_to_id[i]: doc for i, doc in enumerate(documents)}
+         )
+         return cls(embedding.embed_query, index, docstore, index_to_id)
+
+     @classmethod
+     def from_texts(
+         cls,
+         texts: List[str],
+         embedding: Embeddings,
+         metadatas: Optional[List[dict]] = None,
+         **kwargs: Any,
+     ) -> FAISS:
+         """Construct FAISS wrapper from raw documents.
+
+         This is a user friendly interface that:
+             1. Embeds documents.
+             2. Creates an in memory docstore
+             3. Initializes the FAISS database
+
+         This is intended to be a quick way to get started.
+
+         Example:
+             .. code-block:: python
+
+                 from langchain import FAISS
+                 from langchain.embeddings import OpenAIEmbeddings
+                 embeddings = OpenAIEmbeddings()
+                 faiss = FAISS.from_texts(texts, embeddings)
+         """
+         # embeddings = embedding.embed_documents(texts)
+         print(f"len(texts): {len(texts)}")  # TODO: remove
+         embeddings = [embedding.embed_documents([text])[0] for text in texts]
+         print(f"len(embeddings): {len(embeddings)}")  # TODO: remove
+         return cls.__from(texts, embeddings, embedding, metadatas, **kwargs)
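The customization keeps langchain's FAISS interface but embeds each chunk with its own embed_documents call instead of one batched call, and backs the index with the customized InMemoryDocstore. A usage sketch, assuming faiss-cpu is installed and an OpenAI key is configured (not part of the commit; chunk texts and metadata are made up):

from langchain.embeddings import OpenAIEmbeddings

chunks = ["first chunk of a document", "second chunk of a document"]
index = FAISS.from_texts(chunks, OpenAIEmbeddings(),
                         metadatas=[{"key": "doc1"}, {"key": "doc1"}])
hits = index.similarity_search("what does the document say?", k=1)
print(hits[0].page_content, hits[0].metadata)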
streamlit_langchain_chat/customized_langchain/vectorstores/pinecone.py ADDED
@@ -0,0 +1,79 @@
+ from langchain.vectorstores.pinecone import *
+ from langchain.vectorstores.pinecone import Pinecone as OriginalPinecone
+
+
+ class Pinecone(OriginalPinecone):
+     @classmethod
+     def from_texts(
+         cls,
+         texts: List[str],
+         embedding: Embeddings,
+         metadatas: Optional[List[dict]] = None,
+         ids: Optional[List[str]] = None,
+         batch_size: int = 32,
+         text_key: str = "text",
+         index_name: Optional[str] = None,
+         namespace: Optional[str] = None,
+         **kwargs: Any,
+     ) -> Pinecone:
+         """Construct Pinecone wrapper from raw documents.
+
+         This is a user friendly interface that:
+             1. Embeds documents.
+             2. Adds the documents to a provided Pinecone index
+
+         This is intended to be a quick way to get started.
+
+         Example:
+             .. code-block:: python
+
+                 from langchain import Pinecone
+                 from langchain.embeddings import OpenAIEmbeddings
+                 embeddings = OpenAIEmbeddings()
+                 pinecone = Pinecone.from_texts(
+                     texts,
+                     embeddings,
+                     index_name="langchain-demo"
+                 )
+         """
+         try:
+             import pinecone
+         except ImportError:
+             raise ValueError(
+                 "Could not import pinecone python package. "
+                 "Please install it with `pip install pinecone-client`."
+             )
+         _index_name = index_name or str(uuid.uuid4())
+         indexes = pinecone.list_indexes()  # checks if provided index exists
+         if _index_name in indexes:
+             index = pinecone.Index(_index_name)
+         else:
+             index = None
+         for i in range(0, len(texts), batch_size):
+             # set end position of batch
+             i_end = min(i + batch_size, len(texts))
+             # get batch of texts and ids
+             lines_batch = texts[i:i_end]
+             # create ids if not provided
+             if ids:
+                 ids_batch = ids[i:i_end]
+             else:
+                 ids_batch = [str(uuid.uuid4()) for n in range(i, i_end)]
+             # create embeddings
+             # embeds = embedding.embed_documents(lines_batch)
+             embeds = [embedding.embed_documents([line_batch])[0] for line_batch in lines_batch]
+             # prep metadata and upsert batch
+             if metadatas:
+                 metadata = metadatas[i:i_end]
+             else:
+                 metadata = [{} for _ in range(i, i_end)]
+             for j, line in enumerate(lines_batch):
+                 metadata[j][text_key] = line
+             to_upsert = zip(ids_batch, embeds, metadata)
+             # Create index if it does not exist
+             if index is None:
+                 pinecone.create_index(_index_name, dimension=len(embeds[0]))
+                 index = pinecone.Index(_index_name)
+             # upsert to Pinecone
+             index.upsert(vectors=list(to_upsert), namespace=namespace)
+         return cls(index, embedding.embed_query, text_key, namespace)
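As with the FAISS wrapper, from_texts here embeds texts one at a time, creates the Pinecone index on first use if it does not already exist, and upserts each batch with the raw text stored under text_key. A sketch with placeholder credentials and index name (not part of the commit):

import pinecone
from langchain.embeddings import OpenAIEmbeddings

pinecone.init(api_key="YOUR_PINECONE_API_KEY", environment="YOUR_ENVIRONMENT")  # placeholders
store = Pinecone.from_texts(
    ["first chunk", "second chunk"],
    OpenAIEmbeddings(),
    index_name="example-index",  # created automatically if missing
)
print(store.similarity_search("first", k=1))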
streamlit_langchain_chat/dataset.py ADDED
@@ -0,0 +1,740 @@
1
+ import time
2
+ from dataclasses import dataclass
3
+ from datetime import datetime
4
+ from functools import reduce
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ import re
9
+ import requests
10
+ from requests.models import MissingSchema
11
+ import sys
12
+ from typing import List, Optional, Tuple, Dict, Callable, Any
13
+
14
+ from bs4 import BeautifulSoup
15
+ import docx
16
+ from html2text import html2text
17
+ import langchain
18
+ from langchain.callbacks import get_openai_callback
19
+ from langchain.cache import SQLiteCache
20
+ from langchain.chains import LLMChain
21
+ from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
22
+ from langchain.chat_models import ChatOpenAI
23
+ from langchain.chat_models.base import BaseChatModel
24
+ from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
25
+ from langchain.embeddings.base import Embeddings
26
+ from langchain.embeddings.openai import OpenAIEmbeddings
27
+ from langchain.llms import OpenAI
28
+ from langchain.llms.base import LLM, BaseLLM
29
+ from langchain.prompts.chat import AIMessagePromptTemplate
30
+ from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
31
+ from langchain.vectorstores import Pinecone as OriginalPinecone
32
+ import numpy as np
33
+ import openai
34
+ import pinecone
35
+ from pptx import Presentation
36
+ from pypdf import PdfReader
37
+ import trafilatura
38
+
39
+ from streamlit_langchain_chat.constants import *
40
+ from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS
41
+ from streamlit_langchain_chat.customized_langchain.vectorstores import Pinecone
42
+ from streamlit_langchain_chat.utils import maybe_is_text, maybe_is_truncated
43
+ from streamlit_langchain_chat.prompts import *
44
+
45
+
46
+ if REUSE_ANSWERS:
47
+ CACHE_PATH = TEMP_DIR / "llm_cache.db"
48
+ os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
49
+ langchain.llm_cache = SQLiteCache(str(CACHE_PATH))
50
+
51
+ # option 1
52
+ TextSplitter = TokenTextSplitter
53
+ # option 2
54
+ # TextSplitter = RecursiveCharacterTextSplitter  # used by gpt4_pdf_chatbot_langchain (aka GPCL)
55
+
56
+
57
+ @dataclass
58
+ class Answer:
59
+ """A class to hold the answer to a question."""
60
+ question: str = ""
61
+ answer: str = ""
62
+ context: str = ""
63
+ chunks: str = ""
64
+ packages: List[Any] = None
65
+ references: str = ""
66
+ cost_str: str = ""
67
+ passages: Dict[str, str] = None
68
+ tokens: List[Dict] = None
69
+
70
+ def __post_init__(self):
71
+ """Initialize the answer."""
72
+ if self.packages is None:
73
+ self.packages = []
74
+ if self.passages is None:
75
+ self.passages = {}
76
+
77
+ def __str__(self) -> str:
78
+ """Return the answer as a string."""
79
+ return self.answer
80
+
81
+
82
+ def parse_docx(path, citation, key, chunk_chars=2000, overlap=50):
83
+ try:
84
+ document = docx.Document(path)
85
+ fullText = []
86
+ for paragraph in document.paragraphs:
87
+ fullText.append(paragraph.text)
88
+ doc = '\n'.join(fullText) + '\n'
89
+ except Exception as e:
90
+ print(f"code_error: {e}")
91
+ sys.exit(1)
92
+
93
+ if doc:
94
+ text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
95
+ texts = text_splitter.split_text(doc)
96
+ return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
97
+ else:
98
+ return [], []
99
+
100
+
101
+ # TODO: if you add a connector with the format loader = ... ; data = loader.load();
102
+ # you could plug in all of the langchain document loaders
103
+ # https://langchain.readthedocs.io/en/stable/modules/document_loaders/examples/pdf.html
104
+ def parse_pdf(path, citation, key, chunk_chars=2000, overlap=50):
105
+ pdfFileObj = open(path, "rb")
106
+ pdfReader = PdfReader(pdfFileObj)
107
+ splits = []
108
+ split = ""
109
+ pages = []
110
+ metadatas = []
111
+ for i, page in enumerate(pdfReader.pages):
112
+ split += page.extract_text()
113
+ pages.append(str(i + 1))
114
+ # split could be so long it needs to be split
115
+ # into multiple chunks. Or it could be so short
116
+ # that it needs to be combined with the next chunk.
117
+ while len(split) > chunk_chars:
118
+ splits.append(split[:chunk_chars])
119
+ # pretty formatting of pages (e.g. 1-3, 4, 5-7)
120
+ pg = "-".join([pages[0], pages[-1]])
121
+ metadatas.append(
122
+ dict(
123
+ citation=citation,
124
+ dockey=key,
125
+ key=f"{key} pages {pg}",
126
+ )
127
+ )
128
+ split = split[chunk_chars - overlap:]
129
+ pages = [str(i + 1)]
130
+ if len(split) > overlap:
131
+ splits.append(split[:chunk_chars])
132
+ pg = "-".join([pages[0], pages[-1]])
133
+ metadatas.append(
134
+ dict(
135
+ citation=citation,
136
+ dockey=key,
137
+ key=f"{key} pages {pg}",
138
+ )
139
+ )
140
+ pdfFileObj.close()
141
+
142
+ # # ### option 2. PyPDFLoader
143
+ # loader = PyPDFLoader(path)
144
+ # data = loader.load_and_split()
145
+ # # ### option 2.1. PyPDFLoader as used by GPCL, which then splits the loaded documents with the text splitter
146
+ # loader = PyPDFLoader(path)
147
+ # rawDocs = loader.load()
148
+ # text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
149
+ # texts = text_splitter.split_documents(rawDocs)
150
+ # # ### option 3. PDFMiner. This one seems like the best option
151
+ # loader = PyMuPDFLoader(path)
152
+ # data = loader.load()
153
+ return splits, metadatas
154
+
155
+
156
+ def parse_pptx(path, citation, key, chunk_chars=2000, overlap=50):
157
+ try:
158
+ presentation = Presentation(path)
159
+ fullText = []
160
+ for slide in presentation.slides:
161
+ for shape in slide.shapes:
162
+ if hasattr(shape, "text"):
163
+ fullText.append(shape.text)
164
+ doc = ''.join(fullText)
165
+
166
+ if doc:
167
+ text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
168
+ texts = text_splitter.split_text(doc)
169
+ return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
170
+ else:
171
+ return [], []
172
+
173
+ except Exception as e:
174
+ print(f"code_error: {e}")
175
+ sys.exit(1)
176
+
177
+
178
+ def parse_txt(path, citation, key, chunk_chars=2000, overlap=50, html=False):
179
+ try:
180
+ with open(path) as f:
181
+ doc = f.read()
182
+ except UnicodeDecodeError as e:
183
+ with open(path, encoding="utf-8", errors="ignore") as f:
184
+ doc = f.read()
185
+ if html:
186
+ doc = html2text(doc)
187
+ # yo, no idea why but the texts are not split correctly
188
+ text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
189
+ texts = text_splitter.split_text(doc)
190
+ return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
191
+
192
+
193
+ def parse_url(url: str, citation, key, chunk_chars=2000, overlap=50):
194
+ def beautifulsoup_extract_text_fallback(response_content):
195
+ """
196
+ This is a fallback function, so that we can always return a value for text content.
197
+ Even for when both Trafilatura and BeautifulSoup are unable to extract the text from a
198
+ single URL.
199
+ """
200
+
201
+ # Create the beautifulsoup object:
202
+ soup = BeautifulSoup(response_content, 'html.parser')
203
+
204
+ # Finding the text:
205
+ text = soup.find_all(text=True)
206
+
207
+ # Remove unwanted tag elements:
208
+ cleaned_text = ''
209
+ blacklist = [
210
+ '[document]',
211
+ 'noscript',
212
+ 'header',
213
+ 'html',
214
+ 'meta',
215
+ 'head',
216
+ 'input',
217
+ 'script',
218
+ 'style', ]
219
+
220
+ # Then we will loop over every item in the extract text and make sure that the beautifulsoup4 tag
221
+ # is NOT in the blacklist
222
+ for item in text:
223
+ if item.parent.name not in blacklist:
224
+ cleaned_text += f'{item} ' # cleaned_text += '{} '.format(item)
225
+
226
+ # Remove any tab separation and strip the text:
227
+ cleaned_text = cleaned_text.replace('\t', '')
228
+ return cleaned_text.strip()
229
+
230
+ def extract_text_from_single_web_page(url):
231
+ print(f"\n===========\n{url=}\n===========\n")
232
+ downloaded_url = trafilatura.fetch_url(url)
233
+ a = None
234
+ try:
235
+ a = trafilatura.extract(downloaded_url,
236
+ output_format='json',
237
+ with_metadata=True,
238
+ include_comments=False,
239
+ date_extraction_params={'extensive_search': True,
240
+ 'original_date': True})
241
+ except AttributeError:
242
+ a = trafilatura.extract(downloaded_url,
243
+ output_format='json',
244
+ with_metadata=True,
245
+ date_extraction_params={'extensive_search': True,
246
+ 'original_date': True})
247
+ except Exception as e:
248
+ print(f"code_error: {e}")
249
+
250
+ if a:
251
+ json_output = json.loads(a)
252
+ return json_output['text']
253
+ else:
254
+ try:
255
+ headers = {'User-Agent': 'Chrome/83.0.4103.106'}
256
+ resp = requests.get(url, headers=headers)
257
+ print(f"{resp=}\n")
258
+ # We will only extract the text from successful requests:
259
+ if resp.status_code == 200:
260
+ return beautifulsoup_extract_text_fallback(resp.content)
261
+ else:
262
+ # This line will handle for any failures in both the Trafilature and BeautifulSoup4 functions:
263
+ return np.nan
264
+ # Handling for any URLs that don't have the correct protocol
265
+ except MissingSchema:
266
+ return np.nan
267
+
268
+ text_to_split = extract_text_from_single_web_page(url)
269
+ text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
270
+ texts = text_splitter.split_text(text_to_split)
271
+ return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
272
+
273
+
274
+ def read_source(path: str = None,
275
+ citation: str = None,
276
+ key: str = None,
277
+ chunk_chars: int = 3000,
278
+ overlap: int = 100,
279
+ disable_check: bool = False):
280
+ if path.endswith(".pdf"):
281
+ return parse_pdf(path, citation, key, chunk_chars, overlap)
282
+ elif path.endswith(".txt"):
283
+ return parse_txt(path, citation, key, chunk_chars, overlap)
284
+ elif path.endswith(".html"):
285
+ return parse_txt(path, citation, key, chunk_chars, overlap, html=True)
286
+ elif path.endswith(".docx"):
287
+ return parse_docx(path, citation, key, chunk_chars, overlap)
288
+ elif path.endswith(".pptx"):
289
+ return parse_pptx(path, citation, key, chunk_chars, overlap)
290
+ elif path.startswith("http://") or path.startswith("https://"):
291
+ return parse_url(path, citation, key, chunk_chars, overlap)
292
+ # TODO: poner mas conectores
293
+ # else:
294
+ # return parse_code_txt(path, citation, key, chunk_chars, overlap)
295
+ else:
296
+ raise "unknown extension"
297
+
298
+
299
+ class Dataset:
300
+ """A collection of documents to be used for answering questions."""
301
+ def __init__(
302
+ self,
303
+ chunk_size_limit: int = 3000,
304
+ llm: Optional[BaseLLM] | Optional[BaseChatModel] = None,
305
+ summary_llm: Optional[BaseLLM] = None,
306
+ name: str = "default",
307
+ index_path: Optional[Path] = None,
308
+ ) -> None:
309
+ """Initialize the collection of documents.
310
+
311
+ Args:
312
+ chunk_size_limit: The maximum number of characters to use for a single chunk of text.
313
+ llm: The language model to use for answering questions. Default - OpenAI chat-gpt-turbo
314
+ summary_llm: The language model to use for summarizing documents. If None, llm is used.
315
+ name: The name of the collection.
316
+ index_path: The path to the index file IF pickled. If None, defaults to using name in $HOME/.paperqa/name
317
+ """
318
+ self.docs = dict()
319
+ self.keys = set()
320
+ self.chunk_size_limit = chunk_size_limit
321
+
322
+ self.index_docstore = None
323
+
324
+ if llm is None:
325
+ llm = ChatOpenAI(temperature=0.1, max_tokens=512)
326
+ if summary_llm is None:
327
+ summary_llm = llm
328
+ self.update_llm(llm, summary_llm)
329
+
330
+ if index_path is None:
331
+ index_path = TEMP_DIR / name
332
+ self.index_path = index_path
333
+ self.name = name
334
+
335
+ def update_llm(self, llm: BaseLLM | ChatOpenAI, summary_llm: Optional[BaseLLM] = None) -> None:
336
+ """Update the LLM for answering questions."""
337
+ self.llm = llm
338
+ if summary_llm is None:
339
+ summary_llm = llm
340
+ self.summary_llm = summary_llm
341
+ self.summary_chain = LLMChain(prompt=chat_summary_prompt, llm=summary_llm)
342
+ self.search_chain = LLMChain(prompt=search_prompt, llm=llm)
343
+ self.cite_chain = LLMChain(prompt=citation_prompt, llm=llm)
344
+
345
+ def add(
346
+ self,
347
+ path: str,
348
+ citation: Optional[str] = None,
349
+ key: Optional[str] = None,
350
+ disable_check: bool = False,
351
+ chunk_chars: Optional[int] = 3000,
352
+ ) -> None:
353
+ """Add a document to the collection."""
354
+
355
+ if path in self.docs:
356
+ print(f"Document {path} already in collection.")
357
+ return None
358
+
359
+ if citation is None:
360
+ # peak first chunk
361
+ texts, _ = read_source(path, "", "", chunk_chars=chunk_chars)
362
+ with get_openai_callback() as cb:
363
+ citation = self.cite_chain.run(texts[0])
364
+ if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
365
+ citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
366
+
367
+ if key is None:
368
+ # get first name and year from citation
369
+ try:
370
+ author = re.search(r"([A-Z][a-z]+)", citation).group(1)
371
+ except AttributeError:
372
+ # panicking - no word??
373
+ raise ValueError(
374
+ f"Could not parse key from citation {citation}. Consider just passing key explicitly - e.g. docs.py (path, citation, key='mykey')"
375
+ )
376
+ try:
377
+ year = re.search(r"(\d{4})", citation).group(1)
378
+ except AttributeError:
379
+ year = ""
380
+ key = f"{author}{year}"
381
+ suffix = ""
382
+ while key + suffix in self.keys:
383
+ # move suffix to next letter
384
+ if suffix == "":
385
+ suffix = "a"
386
+ else:
387
+ suffix = chr(ord(suffix) + 1)
388
+ key += suffix
389
+ self.keys.add(key)
390
+
391
+ texts, metadata = read_source(path, citation, key, chunk_chars=chunk_chars)
392
+ # loose check to see if document was loaded
393
+ #
394
+ if len("".join(texts)) < 10 or (
395
+ not disable_check and not maybe_is_text("".join(texts))
396
+ ):
397
+ raise ValueError(
398
+ f"This does not look like a text document: {path}. Path disable_check to ignore this error."
399
+ )
400
+
401
+ self.docs[path] = dict(texts=texts, metadata=metadata, key=key)
402
+ if self.index_docstore is not None:
403
+ self.index_docstore.add_texts(texts, metadatas=metadata)
404
+
405
+ def clear(self) -> None:
406
+ """Clear the collection of documents."""
407
+ self.docs = dict()
408
+ self.keys = set()
409
+ self.index_docstore = None
410
+ # delete index file
411
+ pkl = self.index_path / "index.pkl"
412
+ if pkl.exists():
413
+ pkl.unlink()
414
+ fs = self.index_path / "index.faiss"
415
+ if fs.exists():
416
+ fs.unlink()
417
+
418
+ @property
419
+ def doc_previews(self) -> List[Tuple[int, str, str]]:
420
+ """Return a list of tuples of (key, citation) for each document."""
421
+ return [
422
+ (
423
+ len(doc["texts"]),
424
+ doc["metadata"][0]["dockey"],
425
+ doc["metadata"][0]["citation"],
426
+ )
427
+ for doc in self.docs.values()
428
+ ]
429
+
430
+ # to pickle, we have to save the index as a file
431
+ def __getstate__(self):
432
+ # pickle calls __getstate__ with no arguments, so fall back to the default embedding
433
+ embedding = OpenAIEmbeddings()
434
+ if self.index_docstore is None and len(self.docs) > 0:
435
+ self._build_faiss_index(embedding)
436
+ state = self.__dict__.copy()
437
+ if self.index_docstore is not None:
438
+ state["_index"].save_local(self.index_path)
439
+ del state["_index"]
440
+ # remove LLMs (they can have callbacks, which can't be pickled)
441
+ del state["summary_chain"]
442
+ del state["qa_chain"]
443
+ del state["cite_chain"]
444
+ del state["search_chain"]
445
+ return state
446
+
447
+ def __setstate__(self, state):
448
+ self.__dict__.update(state)
449
+ try:
450
+ self.index_docstore = FAISS.load_local(self.index_path, OpenAIEmbeddings())
451
+ except Exception:
452
+ # they use some special exception type, but I don't want to import it
453
+ self.index_docstore = None
454
+ self.update_llm(
455
+ ChatOpenAI(temperature=0.1, max_tokens=512)
456
+ )
457
+
458
+ def _build_faiss_index(self, embedding: Embeddings = None):
459
+ if embedding is None:
460
+ embedding = OpenAIEmbeddings()
461
+ if self.index_docstore is None:
462
+ texts = reduce(
463
+ lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
464
+ )
465
+ metadatas = reduce(
466
+ lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
467
+ )
468
+
469
+ # if the index exists, load it
470
+ if LOAD_INDEX_LOCALLY and (self.index_path / "index.faiss").exists():
471
+ self.index_docstore = FAISS.load_local(self.index_path, embedding)
472
+
473
+ # search if the text and metadata already existed in the index
474
+ for i in reversed(range(len(texts))):
475
+ text = texts[i]
476
+ metadata = metadatas[i]
477
+ for key, value in self.index_docstore.docstore.dict_.items():
478
+ if value.page_content == text:
479
+ if value.metadata.get('citation').split(os.sep)[-1] != metadata.get('citation').split(os.sep)[-1]:
480
+ self.index_docstore.docstore.dict_[key].metadata['citation'] = metadata.get('citation').split(os.sep)[-1]
481
+ self.index_docstore.docstore.dict_[key].metadata['dockey'] = metadata.get('citation').split(os.sep)[-1]
482
+ self.index_docstore.docstore.dict_[key].metadata['key'] = metadata.get('citation').split(os.sep)[-1]
483
+ texts.pop(i)
484
+ metadatas.pop(i)
485
+
486
+ # add remaining texts
487
+ if texts:
488
+ self.index_docstore.add_texts(texts=texts, metadatas=metadatas)
489
+ else:
490
+ # create a new index
491
+ self.index_docstore = FAISS.from_texts(texts, embedding, metadatas=metadatas)
492
+ #
493
+
494
+ if SAVE_INDEX_LOCALLY:
495
+ # save index.
496
+ self.index_docstore.save_local(self.index_path)
497
+
498
+ def _build_pinecone_index(self, embedding: Embeddings = None):
499
+ if embedding is None:
500
+ embedding = OpenAIEmbeddings()
501
+ if self.index_docstore is None:
502
+ pinecone.init(
503
+ api_key=os.environ['PINECONE_API_KEY'], # find at app.pinecone.io
504
+ environment=os.environ['PINECONE_ENVIRONMENT'] # next to api key in console
505
+ )
506
+ texts = reduce(
507
+ lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
508
+ )
509
+ metadatas = reduce(
510
+ lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
511
+ )
512
+
513
+ # TODO: when the index already exists, update it instead of deleting it
514
+ # index_name = "langchain-demo1"
515
+ # if index_name in pinecone.list_indexes():
516
+ # self.index_docstore = pinecone.Index(index_name)
517
+ # vectors = []
518
+ # for text, metadata in zip(texts, metadatas):
519
+ # # embed = <faltaria saber con que embedding se hizo el index que ya existia>
520
+ # self.index_docstore.upsert(vectors=vectors)
521
+ # else:
522
+ # if openai.api_type == 'azure':
523
+ # self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
524
+ # else:
525
+ # self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
526
+
527
+ index_name = "langchain-demo1"
528
+
529
+ # if the index exists, delete it
530
+ if index_name in pinecone.list_indexes():
531
+ pinecone.delete_index(index_name)
532
+
533
+ # create new index
534
+ if openai.api_type == 'azure':
535
+ self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
536
+ else:
537
+ self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
538
+
539
+ def get_evidence(
540
+ self,
541
+ answer: Answer,
542
+ embedding: Embeddings,
543
+ k: int = 3,
544
+ max_sources: int = 5,
545
+ marginal_relevance: bool = True,
546
+ ) -> str:
547
+ if self.index_docstore is None:
548
+ self._build_faiss_index(embedding)
549
+
550
+ init_search_time = time.time()
551
+
552
+ # retrieve k chunks, but consider a larger candidate pool (fetch_k = 5 * k)
553
+ if marginal_relevance:
554
+ docs = self.index_docstore.max_marginal_relevance_search(
555
+ answer.question, k=k, fetch_k=5 * k
556
+ )
557
+ else:
558
+ docs = self.index_docstore.similarity_search(
559
+ answer.question, k=k, fetch_k=5 * k
560
+ )
561
+ if OPERATING_MODE == "debug":
562
+ print(f"time to search docs to build context: {time.time() - init_search_time:.2f} [s]")
563
+ init_summary_time = time.time()
564
+ partial_summary_time = ""
565
+ for i, doc in enumerate(docs):
566
+ with get_openai_callback() as cb:
567
+ init__partial_summary_time = time.time()
568
+ summary_of_chunked_text = self.summary_chain.run(
569
+ question=answer.question, context_str=doc.page_content
570
+ )
571
+ if OPERATING_MODE == "debug":
572
+ partial_summary_time += f"- time to make relevant summary of doc '{i}': {time.time() - init__partial_summary_time:.2f} [s]\n"
573
+ engine = self.summary_chain.llm.model_kwargs.get('deployment_id') or self.summary_chain.llm.model_name
574
+ if not answer.tokens:
575
+ answer.tokens = [{
576
+ 'engine': engine,
577
+ 'total_tokens': cb.total_tokens}]
578
+ else:
579
+ answer.tokens.append({
580
+ 'engine': engine,
581
+ 'total_tokens': cb.total_tokens
582
+ })
583
+ summarized_package = (
584
+ doc.metadata["key"],
585
+ doc.metadata["citation"],
586
+ summary_of_chunked_text,
587
+ doc.page_content,
588
+ )
589
+ if "Not applicable" not in summary_of_chunked_text and summarized_package not in answer.packages:
590
+ answer.packages.append(summarized_package)
591
+ yield answer
592
+ if len(answer.packages) == max_sources:
593
+ break
594
+ if OPERATING_MODE == "debug":
595
+ print(f"time to make all relevant summaries: {time.time() - init_summary_time:.2f} [s]")
596
+ # the last character is not printed because it is a trailing \n
597
+ print(partial_summary_time[:-1])
598
+ context_str = "\n\n".join(
599
+ [f"{citation}: {summary_of_chunked_text}"
600
+ for key, citation, summary_of_chunked_text, chunked_text in answer.packages
601
+ if "Not applicable" not in summary_of_chunked_text]
602
+ )
603
+ chunks_str = "\n\n".join(
604
+ [f"{citation}: {chunked_text}"
605
+ for key, citation, summary_of_chunked_text, chunked_text in answer.packages
606
+ if "Not applicable" not in summary_of_chunked_text]
607
+ )
608
+ valid_keys = [key
609
+ for key, citation, summary_of_chunked_text, chunked_text in answer.packages
610
+ if "Not applicable" not in summary_of_chunked_text]
611
+ if len(valid_keys) > 0:
612
+ context_str += "\n\nValid keys: " + ", ".join(valid_keys)
613
+ chunks_str += "\n\nValid keys: " + ", ".join(valid_keys)
614
+ answer.context = context_str
615
+ answer.chunks = chunks_str
616
+ yield answer
617
+
618
+ def query(
619
+ self,
620
+ query: str,
621
+ embedding: Embeddings,
622
+ chat_history: list[tuple[str, str]],
623
+ k: int = 10,
624
+ max_sources: int = 5,
625
+ length_prompt: str = "about 100 words",
626
+ marginal_relevance: bool = True,
627
+ ):
628
+ for answer in self._query(
629
+ query,
630
+ embedding,
631
+ chat_history,
632
+ k=k,
633
+ max_sources=max_sources,
634
+ length_prompt=length_prompt,
635
+ marginal_relevance=marginal_relevance,
636
+ ):
637
+ pass
638
+ return answer
639
+
640
+ def _query(
641
+ self,
642
+ query: str,
643
+ embedding: Embeddings,
644
+ chat_history: list[tuple[str, str]],
645
+ k: int,
646
+ max_sources: int,
647
+ length_prompt: str,
648
+ marginal_relevance: bool,
649
+ ):
650
+ if k < max_sources:
651
+ k = max_sources + 1
652
+
653
+ answer = Answer(question=query)
654
+
655
+ messages_qa = [system_message_prompt]
656
+ if len(chat_history) != 0:
657
+ for conversation in chat_history:
658
+ messages_qa.append(HumanMessagePromptTemplate.from_template(conversation[0]))
659
+ messages_qa.append(AIMessagePromptTemplate.from_template(conversation[1]))
660
+ messages_qa.append(human_qa_message_prompt)
661
+ chat_qa_prompt = ChatPromptTemplate.from_messages(messages_qa)
662
+ self.qa_chain = LLMChain(prompt=chat_qa_prompt, llm=self.llm)
663
+
664
+ for answer in self.get_evidence(
665
+ answer,
666
+ embedding,
667
+ k=k,
668
+ max_sources=max_sources,
669
+ marginal_relevance=marginal_relevance,
670
+ ):
671
+ yield answer
672
+
673
+ references_dict = dict()
674
+ passages = dict()
675
+ if len(answer.context) < 10:
676
+ answer_text = "I cannot answer this question due to insufficient information."
677
+ else:
678
+ with get_openai_callback() as cb:
679
+ init_qa_time = time.time()
680
+ answer_text = self.qa_chain.run(
681
+ question=answer.question, context_str=answer.context, length=length_prompt
682
+ )
683
+ if OPERATING_MODE == "debug":
684
+ print(f"time to make the Q&A answer: {time.time() - init_qa_time:.2f} [s]")
685
+ engine = self.qa_chain.llm.model_kwargs.get('deployment_id') or self.qa_chain.llm.model_name
686
+ if not answer.tokens:
687
+ answer.tokens = [{
688
+ 'engine': engine,
689
+ 'total_tokens': cb.total_tokens}]
690
+ else:
691
+ answer.tokens.append({
692
+ 'engine': engine,
693
+ 'total_tokens': cb.total_tokens
694
+ })
695
+
696
+ # the model sometimes copies the example citation marker from the prompt; strip it
697
+ if "(Example2012)" in answer_text:
698
+ answer_text = answer_text.replace("(Example2012)", "")
699
+ for key, citation, summary, text in answer.packages:
700
+ # do check for whole key (so we don't catch Callahan2019a with Callahan2019)
701
+ skey = key.split(" ")[0]
702
+ if skey + " " in answer_text or skey + ")" in answer_text:
703
+ references_dict[skey] = citation
704
+ passages[key] = text
705
+ references_str = "\n\n".join(
706
+ [f"{i+1}. ({k}): {c}" for i, (k, c) in enumerate(references_dict.items())]
707
+ )
708
+
709
+ # cost_str = f"{answer_text}\n\n"
710
+ cost_str = ""
711
+ itemized_cost = ""
712
+ total_amount = 0
713
+ for d in answer.tokens:
714
+ total_tokens = d.get('total_tokens')
715
+ if total_tokens:
716
+ engine = d.get('engine')
717
+ key_price = None
718
+ for key in PRICES.keys():
719
+ if re.match(f"{key}", engine):
720
+ key_price = key
721
+ break
722
+ if PRICES.get(key_price):
723
+ partial_amount = total_tokens / 1000 * PRICES.get(key_price)
724
+ total_amount += partial_amount
725
+ itemized_cost += f"- {engine}: {total_tokens} tokens\t ---> ${partial_amount:.4f},\n"
726
+ else:
727
+ itemized_cost += f"- {engine}: {total_tokens} tokens,\n"
728
+ # delete ,\n
729
+ itemized_cost = itemized_cost[:-2]
730
+
731
+ # add tokens to formatted answer
732
+ cost_str += f"Total cost: ${total_amount:.4f}\nItemized cost:\n{itemized_cost}"
733
+
734
+ answer.answer = answer_text
735
+ answer.cost_str = cost_str
736
+ answer.references = references_str
737
+ answer.passages = passages
738
+ yield answer
739
+
740
+
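For orientation, here is a minimal usage sketch of the `Dataset` class defined above, used outside Streamlit. It assumes `OPENAI_API_KEY` is set and that the constructor parameters keep the defaults described in the docstring; the file path and citation below are placeholders.

```python
# Hedged sketch (not part of the repo): exercise Dataset directly.
from langchain.embeddings.openai import OpenAIEmbeddings

from streamlit_langchain_chat.dataset import Dataset

dataset = Dataset()  # assumed defaults: ChatOpenAI(temperature=0.1, max_tokens=512), name="default"
dataset.add(
    "tempDir/example_policy.pdf",             # placeholder local file
    citation="Example Policy Document, 2023",
    key="Example2023",
    disable_check=True,
)

embeddings = OpenAIEmbeddings()
dataset._build_faiss_index(embeddings)         # build (or load) the local FAISS index

answer = dataset.query(
    "What does the policy say about annual leave?",
    embeddings,
    chat_history=[],                           # list of (user, assistant) tuples
    max_sources=5,
)
print(answer.answer)
print(answer.references)
print(answer.cost_str)
```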
streamlit_langchain_chat/inputs/__init__.py ADDED
File without changes
streamlit_langchain_chat/prompts.py ADDED
@@ -0,0 +1,91 @@
1
+ import langchain.prompts as prompts
2
+ from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
3
+ from datetime import datetime
4
+
5
+ summary_template = """Summarize and provide direct quotes from the text below to help answer a question.
6
+ Do not directly answer the question, instead provide a summary and quotes with the context of the user's question.
7
+ Do not use outside sources.
8
+ Reply with "Not applicable" if the text is unrelated to the question.
9
+ Use 75 words or fewer.
10
+ Remember, if the user does not specify a language, reply in the language of the user's question.
11
+
12
+ {context_str}
13
+
14
+ User's question: {question}
15
+ Relevant Information Summary:"""
16
+ summary_prompt = prompts.PromptTemplate(
17
+ input_variables=["question", "context_str"],
18
+ template=summary_template,
19
+ )
20
+
21
+ qa_template = """Write an answer for the user's question below solely based on the provided context.
22
+ If the user does not specify how many words the answer should be, the length of the answer should be {length}.
23
+ If the context is irrelevant, reply "Your question falls outside the scope of University of Sydney policy, so I cannot answer".
24
+ For each sentence in your answer, indicate which sources most support it via valid citation markers at the end of sentences, like (Example2012).
25
+ Answer in an unbiased and professional tone.
26
+ Make clear what is your opinion.
27
+ Use Markdown for formatting code or text, and try to use direct quotes to support arguments.
28
+ Remember, if the user does not specify a language, answer in the language of the user's question.
29
+
30
+ Context:
31
+ {context_str}
32
+
33
+
34
+ User's question: {question}
35
+ Answer:
36
+ """
37
+ qa_prompt = prompts.PromptTemplate(
38
+ input_variables=["question", "context_str", "length"],
39
+ template=qa_template,
40
+ )
41
+
42
+ # used by GPCL
43
+ qa_prompt_GPCL = prompts.PromptTemplate(
44
+ input_variables=["question", "context_str"],
45
+ template="You are an AI assistant providing helpful advice about University of Sydney policy. You are given the following extracted parts of a long document and a question. Provide a conversational answer based on the context provided."
46
+ "You should only provide hyperlinks that reference the context below. Do NOT make up hyperlinks."
47
+ 'If you can not find the answer in the context below, just say "Hmm, I am not sure. Could you please rephrase your question?" Do not try to make up an answer.'
48
+ "If the question is not related to the context, politely respond that you are tuned to only answer questions that are related to the context.\n\n"
49
+ "Question: {question}\n"
50
+ "=========\n"
51
+ "{context_str}\n"
52
+ "=========\n"
53
+ "Answer in Markdown:",
54
+ )
55
+
56
+ search_prompt = prompts.PromptTemplate(
57
+ input_variables=["question"],
58
+ template="We want to answer the following question: {question} \n"
59
+ "Provide three different targeted keyword searches (one search per line) "
60
+ "that will find papers that help answer the question. Do not use boolean operators. "
61
+ "Recent years are 2021, 2022, 2023.\n\n"
62
+ "1.",
63
+ )
64
+
65
+
66
+ def _get_datetime():
67
+ now = datetime.now()
68
+ return now.strftime("%m/%d/%Y")
69
+
70
+
71
+ citation_prompt = prompts.PromptTemplate(
72
+ input_variables=["text"],
73
+ template="Provide a possible citation for the following text in MLA Format. Today's date is {date}\n"
74
+ "{text}\n\n"
75
+ "Citation:",
76
+ partial_variables={"date": _get_datetime},
77
+ )
78
+
79
+ system_template = """You are an AI chatbot with knowledge of the University of Sydney's legal policies that answers in an unbiased, professional tone.
80
+ You sometimes refuse to answer if there is insufficient information.
81
+ If the user does not specify a language, answer in the language of the user's question. """
82
+ system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
83
+
84
+ human_summary_message_prompt = HumanMessagePromptTemplate.from_template(summary_template)
85
+ chat_summary_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_summary_message_prompt])
86
+
87
+ human_qa_message_prompt = HumanMessagePromptTemplate.from_template(qa_template)
88
+ # chat_qa_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_qa_message_prompt]) # TODO: delete
89
+
90
+ # human_condense_message_prompt = HumanMessagePromptTemplate.from_template(condense_template)
91
+ # chat_condense_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_condense_message_prompt])
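As a pointer back to how these templates are consumed, here is a hedged sketch of assembling the Q&A chat prompt with prior turns, mirroring `Dataset._query` in dataset.py; the conversation turns are placeholders.

```python
# Hedged sketch: compose the chat Q&A prompt the way Dataset._query does.
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)

from streamlit_langchain_chat.prompts import human_qa_message_prompt, system_message_prompt

chat_history = [("What is the leave policy?", "Staff are entitled to ...")]  # placeholder turns

messages = [system_message_prompt]
for user_turn, assistant_turn in chat_history:
    messages.append(HumanMessagePromptTemplate.from_template(user_turn))
    messages.append(AIMessagePromptTemplate.from_template(assistant_turn))
messages.append(human_qa_message_prompt)

chat_qa_prompt = ChatPromptTemplate.from_messages(messages)
qa_chain = LLMChain(prompt=chat_qa_prompt, llm=ChatOpenAI(temperature=0.1, max_tokens=512))
# qa_chain.run(question="...", context_str="...", length="about 100 words")
```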
streamlit_langchain_chat/streamlit_app.py ADDED
@@ -0,0 +1,561 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ To run:
4
+ - activate the virtual environment
5
+ - streamlit run path\to\streamlit_app.py
6
+ """
7
+ import logging
8
+ import os
9
+ import re
10
+ import sys
11
+ import time
12
+ import warnings
13
+ import shutil
14
+
15
+ from langchain.chat_models import ChatOpenAI
16
+ from langchain.embeddings.openai import OpenAIEmbeddings
17
+ import openai
18
+ import pandas as pd
19
+ import streamlit as st
20
+ from st_aggrid import GridOptionsBuilder, AgGrid, GridUpdateMode, ColumnsAutoSizeMode
21
+ from streamlit_chat import message
22
+
23
+ from streamlit_langchain_chat.constants import *
24
+ from streamlit_langchain_chat.customized_langchain.llms import OpenAI, AzureOpenAI, AzureOpenAIChat
25
+ from streamlit_langchain_chat.dataset import Dataset
26
+
27
+ # Configure logger
28
+ logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)
29
+ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
30
+
31
+ warnings.filterwarnings('ignore')
32
+
33
+ if 'generated' not in st.session_state:
34
+ st.session_state['generated'] = []
35
+ if 'past' not in st.session_state:
36
+ st.session_state['past'] = []
37
+ if 'costs' not in st.session_state:
38
+ st.session_state['costs'] = []
39
+ if 'contexts' not in st.session_state:
40
+ st.session_state['contexts'] = []
41
+ if 'chunks' not in st.session_state:
42
+ st.session_state['chunks'] = []
43
+ if 'user_input' not in st.session_state:
44
+ st.session_state['user_input'] = ""
45
+ if 'dataset' not in st.session_state:
46
+ st.session_state['dataset'] = None
47
+
48
+
49
+ def check_api_keys() -> bool:
50
+ source_id = app.params['source_id']
51
+ index_id = app.params['index_id']
52
+
53
+ open_api_key = os.getenv('OPENAI_API_KEY', '')
54
+ openapi_api_key_ready = type(open_api_key) is str and len(open_api_key) > 0
55
+
56
+ pinecone_api_key = os.getenv('PINECONE_API_KEY', '')
57
+ pinecone_api_key_ready = type(pinecone_api_key) is str and len(pinecone_api_key) > 0 if index_id == 2 else True
58
+
59
+ is_ready = True if openapi_api_key_ready and pinecone_api_key_ready else False
60
+ return is_ready
61
+
62
+
63
+ def check_combination_point() -> bool:
64
+ type_id = app.params['type_id']
65
+ open_api_key = os.getenv('OPENAI_API_KEY', '')
66
+ openapi_api_key_ready = type(open_api_key) is str and len(open_api_key) > 0
67
+ api_base = app.params['api_base']
68
+
69
+ if type_id == 1:
70
+ deployment_id = app.params['deployment_id']
71
+ return True if openapi_api_key_ready and api_base and deployment_id else False
72
+ elif type_id == 2:
73
+ return True if openapi_api_key_ready and api_base else False
74
+ else:
75
+ return False
76
+
77
+
78
+ def check_index() -> bool:
79
+ dataset = st.session_state['dataset']
80
+
81
+ index_built = dataset.index_docstore if hasattr(dataset, "index_docstore") else False
82
+ without_source = app.params['source_id'] == 4
83
+ is_ready = True if index_built or without_source else False
84
+ return is_ready
85
+
86
+
87
+ def check_index_point() -> bool:
88
+ index_id = app.params['index_id']
89
+
90
+ pinecone_api_key = os.getenv('PINECONE_API_KEY', '')
91
+ pinecone_api_key_ready = type(pinecone_api_key) is str and len(pinecone_api_key) > 0 if index_id == 2 else True
92
+ pinecone_environment = os.getenv('PINECONE_ENVIRONMENT', False) if index_id == 2 else True
93
+
94
+ is_ready = True if index_id and pinecone_api_key_ready and pinecone_environment else False
95
+ return is_ready
96
+
97
+
98
+ def check_params_point() -> bool:
99
+ max_sources = app.params['max_sources']
100
+ temperature = app.params['temperature']
101
+
102
+ is_ready = True if max_sources and isinstance(temperature, float) else False
103
+ return is_ready
104
+
105
+
106
+ def check_source_point() -> bool:
107
+ return True
108
+
109
+
110
+ def clear_chat_history():
111
+ if st.session_state['past'] or st.session_state['generated'] or st.session_state['contexts'] or st.session_state['chunks'] or st.session_state['costs']:
112
+ st.session_state['past'] = []
113
+ st.session_state['generated'] = []
114
+ st.session_state['contexts'] = []
115
+ st.session_state['chunks'] = []
116
+ st.session_state['costs'] = []
117
+
118
+
119
+ def clear_index():
120
+ if dataset := st.session_state['dataset']:
121
+ # delete directory (with files)
122
+ index_path = dataset.index_path
123
+ if index_path.exists():
124
+ shutil.rmtree(str(index_path))
125
+
126
+ # update variable
127
+ st.session_state['dataset'] = None
128
+
129
+ elif (TEMP_DIR / "default").exists():
130
+ shutil.rmtree(str(TEMP_DIR / "default"))
131
+
132
+
133
+ def check_sources() -> bool:
134
+ uploaded_files_rows = app.params['uploaded_files_rows']
135
+ urls_df = app.params['urls_df']
136
+ source_id = app.params['source_id']
137
+
138
+ some_files = True if uploaded_files_rows and uploaded_files_rows[-1].get('filepath') != "" else False
139
+ some_urls = bool([True for url, citation in urls_df.to_numpy() if url])
140
+
141
+ only_local_files = some_files and not some_urls
142
+ only_urls = not some_files and some_urls
143
+ is_ready = only_local_files or only_urls or (source_id == 4)
144
+ return is_ready
145
+
146
+
147
+ def collect_dataset_and_built_index():
148
+ start = time.time()
149
+ uploaded_files_rows = app.params['uploaded_files_rows']
150
+ urls_df = app.params['urls_df']
151
+ type_id = app.params['type_id']
152
+ temperature = app.params['temperature']
153
+ index_id = app.params['index_id']
154
+ api_base = app.params['api_base']
155
+ deployment_id = app.params['deployment_id']
156
+
157
+ some_files = True if uploaded_files_rows and uploaded_files_rows[-1].get('filepath') != "" else False
158
+ some_urls = bool([True for url, citation in urls_df.to_numpy() if url])
159
+
160
+ openai.api_type = "azure" if type_id == 1 else "open_ai"
161
+ openai.api_base = api_base
162
+ openai.api_version = "2023-03-15-preview" if type_id == 1 else None
163
+
164
+ if deployment_id != "text-davinci-003":
165
+ dataset = Dataset(
166
+ llm=ChatOpenAI(
167
+ temperature=temperature,
168
+ max_tokens=512,
169
+ deployment_id=deployment_id,
170
+ )
171
+ )
172
+ else:
173
+ dataset = Dataset(
174
+ llm=OpenAI(
175
+ temperature=temperature,
176
+ max_tokens=512,
177
+ deployment_id=COMBINATIONS_OPTIONS.get(combination_id).get('deployment_name'),
178
+ )
179
+ )
180
+
181
+ # get url documents
182
+ if some_urls:
183
+ urls_df = urls_df.reset_index()
184
+ for url_index, url_row in urls_df.iterrows():
185
+ url = url_row.get('urls', '')
186
+ citation = url_row.get('citation string', '')
187
+ if url:
188
+ try:
189
+ dataset.add(
190
+ url,
191
+ citation,
192
+ citation,
193
+ disable_check=True # True to accept Japanese letters
194
+ )
195
+ except Exception as e:
196
+ print(e)
197
+ pass
198
+
199
+ # add the locally uploaded files (rows selected in the sidebar grid)
200
+ if some_files:
201
+ for uploaded_files_row in uploaded_files_rows:
202
+ key = uploaded_files_row.get('citation string') if ',' not in uploaded_files_row.get('citation string') else None
203
+ dataset.add(
204
+ uploaded_files_row.get('filepath'),
205
+ uploaded_files_row.get('citation string'),
206
+ key=key,
207
+ disable_check=True # True to accept Japanese letters
208
+ )
209
+
210
+ openai_embeddings = OpenAIEmbeddings(
211
+ document_model_name="text-embedding-ada-002",
212
+ query_model_name="text-embedding-ada-002",
213
+ )
214
+ if index_id == 1:
215
+ dataset._build_faiss_index(openai_embeddings)
216
+ else:
217
+ dataset._build_pinecone_index(openai_embeddings)
218
+ st.session_state['dataset'] = dataset
219
+
220
+ if OPERATING_MODE == "debug":
221
+ print(f"time to collect dataset: {time.time() - start:.2f} [s]")
222
+
223
+
224
+ def configure_streamlit_and_page():
225
+ # Configure Streamlit page and state
226
+ st.set_page_config(**ST_CONFIG)
227
+
228
+ # Force responsive layout for columns also on mobile
229
+ st.write(
230
+ """<style>
231
+ [data-testid="column"] {
232
+ width: calc(50% - 1rem);
233
+ flex: 1 1 calc(50% - 1rem);
234
+ min-width: calc(50% - 1rem);
235
+ }
236
+ </style>""",
237
+ unsafe_allow_html=True,
238
+ )
239
+
240
+
241
+ def get_answer():
242
+ query = st.session_state['user_input']
243
+ dataset = st.session_state['dataset']
244
+ type_id = app.params['type_id']
245
+ index_id = app.params['index_id']
246
+ max_sources = app.params['max_sources']
247
+
248
+ if query and dataset and type_id and index_id:
249
+ chat_history = [(past, generated)
250
+ for (past, generated) in zip(st.session_state['past'], st.session_state['generated'])]
251
+ marginal_relevance = index_id == 1  # must be False when Pinecone is used
252
+ start = time.time()
253
+ openai_embeddings = OpenAIEmbeddings(
254
+ document_model_name="text-embedding-ada-002",
255
+ query_model_name="text-embedding-ada-002",
256
+ )
257
+ result = dataset.query(
258
+ query,
259
+ openai_embeddings,
260
+ chat_history,
261
+ marginal_relevance=marginal_relevance, # if pinecone is used it must be False
262
+ )
263
+ if OPERATING_MODE == "debug":
264
+ print(f"time to get answer: {time.time() - start:.2f} [s]")
265
+ print("-" * 10)
266
+ # response = {'generated_text': result.formatted_answer}
267
+ # response = {'generated_text': f"test_{len(st.session_state['generated'])} by {query}"} # @debug
268
+ return result
269
+ else:
270
+ return None
271
+
272
+
273
+ def load_main_page():
274
+ """
275
+ Load the body of web.
276
+ """
277
+ # Streamlit HTML Markdown
278
+ # st.title <h1> #
279
+ # st.header <h2> ##
280
+ # st.subheader <h3> ###
281
+ st.markdown(f"## Augmented-Retrieval Q&A ChatGPT ({APP_VERSION})")
282
+ validate_status()
283
+ st.markdown(f"#### **Status**: {app.params['status']}")
284
+
285
+ # hidden div with anchor
286
+ st.markdown("<div id='linkto_top'></div>", unsafe_allow_html=True)
287
+ col1, col2, col3 = st.columns(3)
288
+ col1.button(label="clear index", type="primary", on_click=clear_index)
289
+ col2.button(label="clear conversation", type="primary", on_click=clear_chat_history)
290
+ col3.markdown("<a href='#linkto_bottom'>Link to bottom</a>", unsafe_allow_html=True)
291
+
292
+ if st.session_state["generated"]:
293
+ for i in range(len(st.session_state["generated"])):
294
+ message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
295
+ message(st.session_state['generated'][i], key=str(i))
296
+ with st.expander("See context"):
297
+ st.write(st.session_state['contexts'][i])
298
+ with st.expander("See chunks"):
299
+ st.write(st.session_state['chunks'][i])
300
+ with st.expander("See costs"):
301
+ st.write(st.session_state['costs'][i])
302
+ dataset = st.session_state['dataset']
303
+ index_built = dataset.index_docstore if hasattr(dataset, "index_docstore") else False
304
+ without_source = app.params['source_id'] == 4
305
+ enable_chat_button = index_built or without_source
306
+ st.text_input("You:",
307
+ key='user_input',
308
+ on_change=on_enter,
309
+ disabled=not enable_chat_button
310
+ )
311
+
312
+ st.markdown("<a href='#linkto_top'>Link to top</a>", unsafe_allow_html=True)
313
+ # hidden div with anchor
314
+ st.markdown("<div id='linkto_bottom'></div>", unsafe_allow_html=True)
315
+
316
+
317
+ def load_sidebar_page():
318
+ st.sidebar.markdown("## Instructions")
319
+
320
+ # ############ #
321
+ # SOURCES TYPE #
322
+ # ############ #
323
+ st.sidebar.markdown("1. Select a source:")
324
+ source_selected = st.sidebar.selectbox(
325
+ "Choose the location of your info to give context to chatgpt",
326
+ [key for key, value in SOURCES_IDS.items()])
327
+ app.params['source_id'] = SOURCES_IDS.get(source_selected, None)
328
+
329
+ # ##### #
330
+ # MODEL #
331
+ # ##### #
332
+ st.sidebar.markdown("2. Select a model (LLM):")
333
+ combination_selected = st.sidebar.selectbox(
334
+ "Choose type: MSF Azure OpenAI and model / OpenAI",
335
+ [key for key, value in TYPE_IDS.items()])
336
+ app.params['type_id'] = TYPE_IDS.get(combination_selected, None)
337
+
338
+ if app.params['type_id'] == 1: # with AzureOpenAI endpoint
339
+ # https://docs.streamlit.io/library/api-reference/widgets/st.text_input
340
+ os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
341
+ label="Enter Azure OpenAI API Key",
342
+ type="password"
343
+ ).strip()
344
+ app.params['api_base'] = st.sidebar.text_input(
345
+ label="Enter Azure API base",
346
+ placeholder="https://<api_base_endpoint>.openai.azure.com/",
347
+ ).strip()
348
+ app.params['deployment_id'] = st.sidebar.text_input(
349
+ label="Enter Azure deployment_id",
350
+ ).strip()
351
+ elif app.params['type_id'] == 2: # with OpenAI endpoint
352
+ os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
353
+ label="Enter OpenAI API Key",
354
+ placeholder="sk-...",
355
+ type="password"
356
+ ).strip()
357
+ app.params['api_base'] = "https://api.openai.com/v1"
358
+ app.params['deployment_id'] = None
359
+
360
+ # ####### #
361
+ # INDEXES #
362
+ # ####### #
363
+ st.sidebar.markdown("3. Select a index store:")
364
+ index_selected = st.sidebar.selectbox(
365
+ "Type of Index",
366
+ [key for key, value in INDEX_IDS.items()])
367
+ app.params['index_id'] = INDEX_IDS.get(index_selected, None)
368
+ if app.params['index_id'] == 2: # with pinecone
369
+ os.environ['PINECONE_API_KEY'] = st.sidebar.text_input(
370
+ label="Enter pinecone API Key",
371
+ type="password"
372
+ ).strip()
373
+
374
+ os.environ['PINECONE_ENVIRONMENT'] = st.sidebar.text_input(
375
+ label="Enter pinecone environment",
376
+ placeholder="eu-west1-gcp",
377
+ ).strip()
378
+
379
+ # ############## #
380
+ # CONFIGURATIONS #
381
+ # ############## #
382
+ st.sidebar.markdown("4. Choose configuration:")
383
+ # https://docs.streamlit.io/library/api-reference/widgets/st.number_input
384
+ max_sources = st.sidebar.number_input(
385
+ label="Top-k: Number of chunks/sections (1-5)",
386
+ step=1,
387
+ format="%d",
388
+ value=5
389
+ )
390
+ app.params['max_sources'] = max_sources
391
+ temperature = st.sidebar.number_input(
392
+ label="Temperature (0.0 – 1.0)",
393
+ step=0.1,
394
+ format="%f",
395
+ value=0.0,
396
+ min_value=0.0,
397
+ max_value=1.0
398
+ )
399
+ app.params['temperature'] = round(temperature, 1)
400
+
401
+ # ############## #
402
+ # UPLOAD SOURCES #
403
+ # ############## #
404
+ app.params['uploaded_files_rows'] = []
405
+ if app.params['source_id'] == 1:
406
+ # https://docs.streamlit.io/library/api-reference/widgets/st.file_uploader
407
+ # https://towardsdatascience.com/make-dataframes-interactive-in-streamlit-c3d0c4f84ccb
408
+ st.sidebar.markdown("""5. Upload your local documents and modify citation strings (optional)""")
409
+ uploaded_files = st.sidebar.file_uploader(
410
+ "Choose files",
411
+ accept_multiple_files=True,
412
+ type=['pdf', 'PDF',
413
+ 'txt', 'TXT',
414
+ 'html',
415
+ 'docx', 'DOCX',
416
+ 'pptx', 'PPTX',
417
+ ],
418
+ )
419
+ uploaded_files_dataset = request_pathname(uploaded_files)
420
+ uploaded_files_df = pd.DataFrame(
421
+ uploaded_files_dataset,
422
+ columns=['filepath', 'citation string'])
423
+ uploaded_files_grid_options_builder = GridOptionsBuilder.from_dataframe(uploaded_files_df)
424
+ uploaded_files_grid_options_builder.configure_selection(
425
+ selection_mode='multiple',
426
+ pre_selected_rows=list(range(uploaded_files_df.shape[0])) if uploaded_files_df.iloc[-1, 0] != "" else [],
427
+ use_checkbox=True,
428
+ )
429
+ uploaded_files_grid_options_builder.configure_column("citation string", editable=True)
430
+ uploaded_files_grid_options_builder.configure_auto_height()
431
+ uploaded_files_grid_options = uploaded_files_grid_options_builder.build()
432
+ with st.sidebar:
433
+ uploaded_files_ag_grid = AgGrid(
434
+ uploaded_files_df,
435
+ gridOptions=uploaded_files_grid_options,
436
+ update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
437
+ )
438
+ app.params['uploaded_files_rows'] = uploaded_files_ag_grid["selected_rows"]
439
+
440
+ app.params['urls_df'] = pd.DataFrame()
441
+ if app.params['source_id'] == 3:
442
+ st.sidebar.markdown("""5. Write some urls and modify citation strings if you want (to look prettier)""")
443
+ # option 1: with streamlit version 1.20.0+
444
+ # app.params['urls_df'] = st.sidebar.experimental_data_editor(
445
+ # pd.DataFrame([["", ""]], columns=['urls', 'citation string']),
446
+ # use_container_width=True,
447
+ # num_rows="dynamic",
448
+ # )
449
+
450
+ # option 2: with streamlit version 1.19.0
451
+ urls_dataset = [["", ""],
452
+ ["", ""],
453
+ ["", ""],
454
+ ["", ""],
455
+ ["", ""]]
456
+ urls_df = pd.DataFrame(
457
+ urls_dataset,
458
+ columns=['urls', 'citation string'])
459
+
460
+ urls_grid_options_builder = GridOptionsBuilder.from_dataframe(urls_df)
461
+ urls_grid_options_builder.configure_columns(['urls', 'citation string'], editable=True)
462
+ urls_grid_options_builder.configure_auto_height()
463
+ urls_grid_options = urls_grid_options_builder.build()
464
+ with st.sidebar:
465
+ urls_ag_grid = AgGrid(
466
+ urls_df,
467
+ gridOptions=urls_grid_options,
468
+ update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
469
+ )
470
+ df = urls_ag_grid.data
471
+ df = df[df.urls != ""]
472
+ app.params['urls_df'] = df
473
+
474
+ if app.params['source_id'] in (1, 2, 3):
475
+ st.sidebar.markdown("""6. Build an index where you can ask""")
476
+ api_keys_ready = check_api_keys()
477
+ source_ready = check_sources()
478
+ enable_index_button = api_keys_ready and source_ready
479
+ if st.sidebar.button("Build index", disabled=not enable_index_button):
480
+ collect_dataset_and_built_index()
481
+
482
+
483
+ def main():
484
+ configure_streamlit_and_page()
485
+ load_sidebar_page()
486
+ load_main_page()
487
+
488
+
489
+ def on_enter():
490
+ output = get_answer()
491
+ if output:
492
+ st.session_state.past.append(st.session_state['user_input'])
493
+ st.session_state.generated.append(output.answer)
494
+ st.session_state.contexts.append(output.context)
495
+ st.session_state.chunks.append(output.chunks)
496
+ st.session_state.costs.append(output.cost_str)
497
+ st.session_state['user_input'] = ""
498
+
499
+
500
+ def request_pathname(files):
501
+ if not files:
502
+ return [["", ""]]
503
+
504
+ # check if the temporary directory exists; if not, create it
505
+ if not TEMP_DIR.exists():
506
+ TEMP_DIR.mkdir(
507
+ parents=True,
508
+ exist_ok=True,
509
+ )
510
+
511
+ file_paths = []
512
+ for file in files:
513
+ # # absolute path
514
+ # file_path = str(TEMP_DIR / file.name)
515
+ # relative path
516
+ file_path = str((TEMP_DIR / file.name).relative_to(ROOT_DIR))
517
+ file_paths.append(file_path)
518
+ with open(file_path, "wb") as f:
519
+ f.write(file.getbuffer())
520
+ return [[filepath, filename.name] for filepath, filename in zip(file_paths, files)]
521
+
522
+
523
+ def validate_status():
524
+ source_point_ready = check_source_point()
525
+ combination_point_ready = check_combination_point()
526
+ index_point_ready = check_index_point()
527
+ params_point_ready = check_params_point()
528
+ sources_ready = check_sources()
529
+ index_ready = check_index()
530
+
531
+ if source_point_ready and combination_point_ready and index_point_ready and params_point_ready and sources_ready and index_ready:
532
+ app.params['status'] = "✨Ready✨"
533
+ elif not source_point_ready:
534
+ app.params['status'] = "⚠️Review step 1 on the sidebar."
535
+ elif not combination_point_ready:
536
+ app.params['status'] = "⚠️Review step 2 on the sidebar. API Keys or endpoint, ..."
537
+ elif not index_point_ready:
538
+ app.params['status'] = "⚠️Review step 3 on the sidebar. Index API Key or environment."
539
+ elif not params_point_ready:
540
+ app.params['status'] = "⚠️Review step 4 on the sidebar"
541
+ elif not sources_ready:
542
+ app.params['status'] = "⚠️Review step 5 on the sidebar. Waiting for some source..."
543
+ elif not index_ready:
544
+ app.params['status'] = "⚠️Review step 6 on the sidebar. Waiting for press button to create index ..."
545
+ else:
546
+ app.params['status'] = "⚠️Something is not ready..."
547
+
548
+
549
+ class StreamlitLangchainChatApp():
550
+ def __init__(self) -> None:
551
+ """Use __init__ to define instance variables. It cannot have any arguments."""
552
+ self.params = dict()
553
+
554
+ def run(self, **state) -> None:
555
+ """Define here all logic required by your application."""
556
+ main()
557
+
558
+
559
+ if __name__ == "__main__":
560
+ app = StreamlitLangchainChatApp()
561
+ app.run()
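For completeness, here is a hedged sketch of the runtime setup the sidebar above expects before the "Build index" and chat inputs are enabled; the key names come from the code above, the values are placeholders.

```python
# Hedged sketch: environment variables read by the app, depending on sidebar choices.
import os

os.environ["OPENAI_API_KEY"] = "sk-..."                # step 2: OpenAI key (or Azure OpenAI key)
# The Azure option additionally asks for an API base and a deployment_id in the sidebar.
# os.environ["PINECONE_API_KEY"] = "..."               # step 3: only if the Pinecone index is chosen
# os.environ["PINECONE_ENVIRONMENT"] = "eu-west1-gcp"  # step 3: Pinecone environment

# Launch from a shell with the virtual environment active:
#   streamlit run streamlit_langchain_chat/streamlit_app.py
```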
streamlit_langchain_chat/utils.py ADDED
@@ -0,0 +1,52 @@
1
+ import math
2
+ import string
3
+
4
+
5
+ def maybe_is_text(s, thresh=2.5):
6
+ if len(s) == 0:
7
+ return False
8
+ # Calculate the entropy of the string
9
+ entropy = 0
10
+ for c in string.printable:
11
+ p = s.count(c) / len(s)
12
+ if p > 0:
13
+ entropy += -p * math.log2(p)
14
+
15
+ # Check if the entropy is within a reasonable range for text
16
+ if entropy > thresh:
17
+ return True
18
+ return False
19
+
20
+
21
+ def maybe_is_code(s):
22
+ if len(s) == 0:
23
+ return False
24
+ # Check if the string contains a lot of non-ascii characters
25
+ if len([c for c in s if ord(c) > 128]) / len(s) > 0.1:
26
+ return True
27
+ return False
28
+
29
+
30
+ def strings_similarity(s1, s2):
31
+ if len(s1) == 0 or len(s2) == 0:
32
+ return 0
33
+ # break the strings into words
34
+ s1 = set(s1.split())
35
+ s2 = set(s2.split())
36
+ # return the similarity ratio
37
+ return len(s1.intersection(s2)) / len(s1.union(s2))
38
+
39
+
40
+ def maybe_is_truncated(s):
41
+ punct = [".", "!", "?", '"']
42
+ if s[-1] in punct:
43
+ return False
44
+ return True
45
+
46
+
47
+ def maybe_is_html(s):
48
+ if len(s) == 0:
49
+ return False
50
+ # check for html tags
51
+ if "<body" in s or "<html" in s or "<div" in s:
52
+ return True
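Finally, a quick hedged illustration of these text heuristics on toy inputs; the expected results in the comments follow from the definitions above and are a sketch, not a test suite.

```python
# Hedged sketch: the heuristics above applied to small example strings.
from streamlit_langchain_chat.utils import (
    maybe_is_text,
    maybe_is_truncated,
    strings_similarity,
)

maybe_is_text("The quick brown fox jumps over the lazy dog.")  # True: character entropy > 2.5
maybe_is_text("aaaaaaaa")                                      # False: entropy is 0

maybe_is_truncated("This answer stops mid-sen")                # True: no closing punctuation
maybe_is_truncated("This answer is complete.")                 # False

strings_similarity("the cat sat", "the cat ran")               # 0.5: word-level Jaccard overlap
```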