Commit 0539589 (1 parent: eeb7ca1)

Update with h2oGPT hash 3513278043665f503945eb05d56c1ec1152d1006

Files changed:
- generate.py      +31 -15
- gpt_langchain.py +40 -8
- gradio_runner.py +8 -6
- requirements.txt +2 -1
- utils.py         +0 -1
generate.py
CHANGED

@@ -33,7 +33,6 @@ from typing import Union
 
 import fire
 import torch
-from peft import PeftModel
 from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
 from accelerate import init_empty_weights, infer_auto_device_map
 
@@ -710,6 +709,7 @@ def get_model(
             base_model,
             **model_kwargs
         )
+        from peft import PeftModel  # loads cuda, so avoid in global scope
         model = PeftModel.from_pretrained(
             model,
             lora_weights,

@@ -727,6 +727,7 @@ def get_model(
             base_model,
             **model_kwargs
         )
+        from peft import PeftModel  # loads cuda, so avoid in global scope
         model = PeftModel.from_pretrained(
             model,
             lora_weights,
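Note: both hunks move `from peft import PeftModel` out of module scope and into `get_model`, so importing generate.py no longer initializes CUDA as a side effect; the cost is paid only when LoRA weights are actually loaded. A minimal sketch of the deferred-import pattern, using a hypothetical `load_lora` helper (not part of this commit):

    def load_lora(model, lora_weights):
        # deferred import: pulling in peft at call time keeps CUDA
        # initialization out of a plain `import generate`
        from peft import PeftModel
        return PeftModel.from_pretrained(model, lora_weights)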
@@ -827,24 +828,27 @@ no_default_param_names = [
                          'iinput_nochat',
                          ]
 
+gen_hyper = ['temperature',
+             'top_p',
+             'top_k',
+             'num_beams',
+             'max_new_tokens',
+             'min_new_tokens',
+             'early_stopping',
+             'max_time',
+             'repetition_penalty',
+             'num_return_sequences',
+             'do_sample',
+             ]
+
 eval_func_param_names = ['instruction',
                          'iinput',
                          'context',
                          'stream_output',
                          'prompt_type',
-                         'prompt_dict',
-                         'temperature',
-                         'top_p',
-                         'top_k',
-                         'num_beams',
-                         'max_new_tokens',
-                         'min_new_tokens',
-                         'early_stopping',
-                         'max_time',
-                         'repetition_penalty',
-                         'num_return_sequences',
-                         'do_sample',
-                         'chat',
+                         'prompt_dict'] + \
+                        gen_hyper + \
+                        ['chat',
                          'instruction_nochat',
                          'iinput_nochat',
                          'langchain_mode',
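Note: the generation hyperparameters now live in one shared `gen_hyper` list, and `eval_func_param_names` is built by concatenation, so adding or reordering a hyperparameter happens in a single place. A short, self-contained sketch of why the composition keeps consumers in sync (the `default_kwargs` dict here is illustrative, not the real defaults table):

    gen_hyper = ['temperature', 'top_p', 'top_k', 'num_beams', 'max_new_tokens',
                 'min_new_tokens', 'early_stopping', 'max_time',
                 'repetition_penalty', 'num_return_sequences', 'do_sample']

    eval_func_param_names = ['instruction', 'iinput', 'context', 'stream_output',
                             'prompt_type', 'prompt_dict'] + gen_hyper + \
                            ['chat', 'instruction_nochat', 'iinput_nochat',
                             'langchain_mode']

    # any defaults table built from the same names passes the set-equality
    # check that evaluate_from_str performs in a later hunk
    default_kwargs = {k: None for k in eval_func_param_names}
    assert set(default_kwargs.keys()) == set(eval_func_param_names)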
@@ -900,6 +904,9 @@ def evaluate_from_str(
     # only used for submit_nochat_api
     user_kwargs['chat'] = False
     user_kwargs['stream_output'] = False
+    if 'langchain_mode' not in user_kwargs:
+        # if user doesn't specify, then assume disabled, not use default
+        user_kwargs['langchain_mode'] = 'Disabled'
 
     assert set(list(default_kwargs.keys())) == set(eval_func_param_names)
     # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
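Note: API callers that omit `langchain_mode` now get 'Disabled' rather than inheriting the server's default, which keeps no-chat API calls from silently querying a document collection. Sketch of the behavior (payload contents illustrative):

    user_kwargs = {'instruction_nochat': 'Why is the sky blue?'}  # caller omitted langchain_mode
    if 'langchain_mode' not in user_kwargs:
        # caller did not opt in, so LangChain-backed retrieval stays off
        user_kwargs['langchain_mode'] = 'Disabled'
    assert user_kwargs['langchain_mode'] == 'Disabled'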
@@ -1083,7 +1090,6 @@ def evaluate(
                 db=db1,
                 user_path=user_path,
                 detect_user_path_changes_every_query=detect_user_path_changes_every_query,
-                max_new_tokens=max_new_tokens,
                 cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64,  # FIXME, too arbitrary
                 use_openai_embedding=use_openai_embedding,
                 use_openai_model=use_openai_model,
@@ -1096,10 +1102,20 @@ def evaluate(
                 document_choice=document_choice,
                 db_type=db_type,
                 top_k_docs=top_k_docs,
+
+                # gen_hyper:
+                do_sample=do_sample,
                 temperature=temperature,
                 repetition_penalty=repetition_penalty,
                 top_k=top_k,
                 top_p=top_p,
+                num_beams=num_beams,
+                min_new_tokens=min_new_tokens,
+                max_new_tokens=max_new_tokens,
+                early_stopping=early_stopping,
+                max_time=max_time,
+                num_return_sequences=num_return_sequences,
+
                 prompt_type=prompt_type,
                 prompt_dict=prompt_dict,
                 n_jobs=n_jobs,
gpt_langchain.py
CHANGED

@@ -22,6 +22,7 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
 from tqdm import tqdm
 
 from enums import DocumentChoices
+from generate import gen_hyper
 from prompter import non_hf_types, PromptType
 from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
     get_device, ProgressParallel, remove, hash_file, clear_torch_cache
@@ -261,11 +262,17 @@ def get_answer_from_sources(chain, sources, question):
 
 def get_llm(use_openai_model=False, model_name=None, model=None,
             tokenizer=None, stream_output=False,
-            max_new_tokens=256,
+            do_sample=False,
             temperature=0.1,
-            repetition_penalty=1.0,
             top_k=40,
             top_p=0.7,
+            num_beams=1,
+            max_new_tokens=256,
+            min_new_tokens=1,
+            early_stopping=False,
+            max_time=180,
+            repetition_penalty=1.0,
+            num_return_sequences=1,
             prompt_type=None,
             prompt_dict=None,
             prompter=None,
@@ -312,10 +319,20 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
                          load_in_8bit=load_8bit)
 
         max_max_tokens = tokenizer.model_max_length
-        gen_kwargs = dict(
+        gen_kwargs = dict(do_sample=do_sample,
+                          temperature=temperature,
+                          top_k=top_k,
+                          top_p=top_p,
+                          num_beams=num_beams,
+                          max_new_tokens=max_new_tokens,
+                          min_new_tokens=min_new_tokens,
+                          early_stopping=early_stopping,
+                          max_time=max_time,
+                          repetition_penalty=repetition_penalty,
+                          num_return_sequences=num_return_sequences,
                           return_full_text=True,
-                          early_stopping=False,
                           handle_long_generation='hole')
+        assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
 
         if stream_output:
             skip_prompt = False
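Note: `gen_kwargs` now carries the full set of generation hyperparameters instead of hard-coded values, and the new assert ties it back to `gen_hyper` from generate.py. A self-contained sketch of the check (values mirror the signature defaults above and are not canonical):

    gen_hyper = ['temperature', 'top_p', 'top_k', 'num_beams', 'max_new_tokens',
                 'min_new_tokens', 'early_stopping', 'max_time',
                 'repetition_penalty', 'num_return_sequences', 'do_sample']

    gen_kwargs = dict(do_sample=False, temperature=0.1, top_k=40, top_p=0.7,
                      num_beams=1, max_new_tokens=256, min_new_tokens=1,
                      early_stopping=False, max_time=180, repetition_penalty=1.0,
                      num_return_sequences=1,
                      return_full_text=True, handle_long_generation='hole')

    # extra pipeline-only keys (return_full_text, handle_long_generation) are fine;
    # a *missing* hyperparameter would make the difference non-empty and fail fast
    assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0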
@@ -1235,11 +1252,17 @@ def _run_qa_db(query=None,
               show_rank=False,
               load_db_if_exists=False,
               db=None,
-              max_new_tokens=256,
+              do_sample=False,
               temperature=0.1,
-              repetition_penalty=1.0,
               top_k=40,
               top_p=0.7,
+              num_beams=1,
+              max_new_tokens=256,
+              min_new_tokens=1,
+              early_stopping=False,
+              max_time=180,
+              repetition_penalty=1.0,
+              num_return_sequences=1,
               langchain_mode=None,
               document_choice=[DocumentChoices.All_Relevant.name],
               n_jobs=-1,
@@ -1274,14 +1297,21 @@ def _run_qa_db(query=None,
         assert prompt_dict is not None  # should at least be {} or ''
     else:
         prompt_dict = ''
+    assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
    llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
                                                         model=model, tokenizer=tokenizer,
                                                         stream_output=stream_output,
-                                                        max_new_tokens=max_new_tokens,
+                                                        do_sample=do_sample,
                                                         temperature=temperature,
-                                                        repetition_penalty=repetition_penalty,
                                                         top_k=top_k,
                                                         top_p=top_p,
+                                                        num_beams=num_beams,
+                                                        max_new_tokens=max_new_tokens,
+                                                        min_new_tokens=min_new_tokens,
+                                                        early_stopping=early_stopping,
+                                                        max_time=max_time,
+                                                        repetition_penalty=repetition_penalty,
+                                                        num_return_sequences=num_return_sequences,
                                                         prompt_type=prompt_type,
                                                         prompt_dict=prompt_dict,
                                                         prompter=prompter,
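Note: this second assert guards the other direction: every name in `gen_hyper` must be an accepted parameter of `get_llm`, checked via signature introspection before the call is made. A runnable sketch with a stand-in `get_llm` (the real one takes many more arguments):

    import inspect

    def get_llm(do_sample=False, temperature=0.1, top_k=40, top_p=0.7, num_beams=1,
                max_new_tokens=256, min_new_tokens=1, early_stopping=False,
                max_time=180, repetition_penalty=1.0, num_return_sequences=1,
                **kwargs):
        pass

    gen_hyper = ['temperature', 'top_p', 'top_k', 'num_beams', 'max_new_tokens',
                 'min_new_tokens', 'early_stopping', 'max_time',
                 'repetition_penalty', 'num_return_sequences', 'do_sample']

    # fails loudly if get_llm's signature drifts away from the shared list
    assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0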
@@ -1609,6 +1639,7 @@ def get_some_dbs_from_hf(dest='.', db_zips=None):
         assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
         assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
 
+
 def _create_local_weaviate_client():
     WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
     WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
@@ -1629,5 +1660,6 @@ def _create_local_weaviate_client():
         print(f"Failed to create Weaviate client: {e}")
         return None
 
+
 if __name__ == '__main__':
     pass
gradio_runner.py
CHANGED

@@ -649,7 +649,7 @@ def go_gradio(**kwargs):
                  inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
                          chunk, chunk_size],
                  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
-                 api_name='add_to_shared' if allow_api else None) \
+                 api_name='add_to_shared' if allow_api and allow_upload_to_user_data else None) \
             .then(clear_file_list, outputs=fileup_output, queue=queue) \
             .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
 

@@ -664,7 +664,7 @@ def go_gradio(**kwargs):
                  inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
                          chunk, chunk_size],
                  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
-                 api_name='add_url_to_shared' if allow_api else None) \
+                 api_name='add_url_to_shared' if allow_api and allow_upload_to_user_data else None) \
             .then(clear_textbox, outputs=url_text, queue=queue) \
             .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
 

@@ -673,7 +673,7 @@ def go_gradio(**kwargs):
                  inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
                          chunk, chunk_size],
                  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
-                 api_name='add_text_to_shared' if allow_api else None) \
+                 api_name='add_text_to_shared' if allow_api and allow_upload_to_user_data else None) \
             .then(clear_textbox, outputs=user_text_text, queue=queue) \
             .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
 
@@ -695,7 +695,7 @@ def go_gradio(**kwargs):
                  inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
                          chunk, chunk_size],
                  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
-                 api_name='add_to_my' if allow_api else None) \
+                 api_name='add_to_my' if allow_api and allow_upload_to_my_data else None) \
             .then(clear_file_list, outputs=fileup_output, queue=queue) \
             .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
         # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)

@@ -706,7 +706,7 @@ def go_gradio(**kwargs):
                  inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
                          chunk, chunk_size],
                  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
-                 api_name='add_url_to_my' if allow_api else None) \
+                 api_name='add_url_to_my' if allow_api and allow_upload_to_my_data else None) \
             .then(clear_textbox, outputs=url_text, queue=queue) \
             .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
 

@@ -715,7 +715,7 @@ def go_gradio(**kwargs):
                  inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
                          chunk, chunk_size],
                  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
-                 api_name='add_txt_to_my' if allow_api else None) \
+                 api_name='add_txt_to_my' if allow_api and allow_upload_to_my_data else None) \
             .then(clear_textbox, outputs=user_text_text, queue=queue) \
             .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
 
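Note: all six upload endpoints previously checked only the global `allow_api` flag; they now also require the matching upload permission (`allow_upload_to_user_data` for the shared collection, `allow_upload_to_my_data` for the per-user one), so disabling uploads also stops the named API route from being registered. Sketch of the gating idiom:

    allow_api = True
    allow_upload_to_user_data = False

    # same conditional expression as in the hunks above
    api_name = 'add_to_shared' if allow_api and allow_upload_to_user_data else None
    assert api_name is None  # no named endpoint for shared uploads in this config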
@@ -1788,6 +1788,8 @@ def get_db(db1, langchain_mode, dbs=None):
 
 def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
     db = get_db(db1, langchain_mode, dbs=dbs)
+    if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
+        return "Sources: N/A"
     return get_source_files(db=db, exceptions=None)
 
 
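Note: listing source files now short-circuits for the pure-LLM modes, which have no document collection behind them, instead of handing `get_source_files` a missing db. A runnable sketch with stand-ins for `get_db`/`get_source_files` (the real helpers take the gradio state objects):

    def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
        db = (dbs or {}).get(langchain_mode)  # stand-in for get_db
        if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
            return "Sources: N/A"
        return "Sources: %s" % db  # stand-in for get_source_files

    print(get_source_files_given_langchain_mode(None, langchain_mode='LLM'))  # Sources: N/A
    print(get_source_files_given_langchain_mode(None, 'UserData', {'UserData': 'db0'}))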
requirements.txt
CHANGED

@@ -56,7 +56,8 @@ einops==0.6.1
 instructorembedding==1.0.1
 
 # for gpt4all .env file, but avoid worrying about imports
-python-dotenv==1.0.0
+python-dotenv==1.0.0
+# optional for chat with PDF
 langchain==0.0.193
 pypdf==3.8.1
 tiktoken==0.3.3
utils.py
CHANGED

@@ -14,7 +14,6 @@ import time
 import traceback
 import zipfile
 from datetime import datetime
-from enum import Enum
 
 import filelock
 import requests, uuid