Update app.py
app.py CHANGED
@@ -172,6 +172,80 @@ EOS_TOKEN = '</s>'
 SYSTEM_PROMPT_1 = """You are a helpful, respectful, honest and safe AI assistant built by Alibaba Group."""
 
 
+
+# ######### RAG PREPARE
+RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
+
+RAG_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+
+
+def load_embeddings():
+    global RAG_EMBED
+    if RAG_EMBED is None:
+        from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
+        print(f'Loading embeddings: {RAG_EMBED_MODEL_NAME}')
+        RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code': True})
+    else:
+        print(f'RAG_EMBED ALREADY EXISTS: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
+    return RAG_EMBED
+
+
+def get_rag_embeddings():
+    return load_embeddings()
+
+_ = get_rag_embeddings()
+
+RAG_CURRENT_VECTORSTORE = None
+
+def load_document_split_vectorstore(file_path):
+    global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
+    from langchain.text_splitter import RecursiveCharacterTextSplitter
+    from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
+    from langchain_community.vectorstores import Chroma, FAISS
+    from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
+    # assert RAG_EMBED is not None
+    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
+    if file_path.endswith('.pdf'):
+        loader = PyPDFLoader(file_path)
+    elif file_path.endswith('.docx'):
+        loader = Docx2txtLoader(file_path)
+    elif file_path.endswith('.txt'):
+        loader = TextLoader(file_path)
+    splits = loader.load_and_split(splitter)
+    RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
+    return RAG_CURRENT_VECTORSTORE
+
+
+def docs_to_rag_context(docs: List[str]):
+    contexts = "\n".join([d.page_content for d in docs])
+    context = f"""### Begin document
+{contexts}
+### End document
+Answer the following query exclusively based on the information provided in the document above. \
+Remember to follow the language of the user query.
+"""
+    return context
+
+def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
+    global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
+    doc_context = None
+    if file_input is not None:
+        assert os.path.exists(file_input), f"not found: {file_input}"
+        if file_input == RAG_CURRENT_FILE:
+            # reuse
+            vectorstore = RAG_CURRENT_VECTORSTORE
+            print(f'Reuse vectorstore: {file_input}')
+        else:
+            vectorstore = load_document_split_vectorstore(file_input)
+            print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
+            RAG_CURRENT_FILE = file_input
+        docs = vectorstore.similarity_search(message, k=rag_num_docs)
+        doc_context = docs_to_rag_context(docs)
+    return doc_context
+
+# ######### RAG PREPARE
+
+
 # ============ CONSTANT ============
 # https://github.com/gradio-app/gradio/issues/884
 MODEL_NAME = "SeaLLM-7B"
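Note on the block above: the module keeps exactly one vector store in memory, keyed by RAG_CURRENT_FILE, so repeated questions about the same document skip re-splitting and re-embedding. Below is a minimal standalone sketch of the same split / embed / FAISS / retrieve path, assuming langchain, langchain_community, faiss-cpu and sentence-transformers are installed; "notes.txt" is a hypothetical stand-in for the uploaded document.

# Minimal sketch of the ingest-and-retrieve flow (assumptions above).
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader

embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)

# Load and chunk the document ("notes.txt" is a hypothetical path).
splits = TextLoader("notes.txt").load_and_split(splitter)
store = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=embed)

# k=3 mirrors the rag_num_docs default above.
docs = store.similarity_search("What does the document conclude?", k=3)
print("\n".join(d.page_content for d in docs))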
@@ -771,7 +845,7 @@ def chat_response_stream_multiturn(
     presence_penalty: float,
     system_prompt: Optional[str] = SYSTEM_PROMPT_1,
     current_time: Optional[float] = None,
-    profile: Optional[gr.OAuthProfile] = None,
+    # profile: Optional[gr.OAuthProfile] = None,
 ) -> str:
     """
     gr.Number(value=temperature, label='Temperature (higher -> more random)'),
@@ -794,7 +868,8 @@ def chat_response_stream_multiturn(
     global llm, RES_PRINTED
     assert llm is not None
     assert system_prompt.strip() != '', f'system prompt is empty'
-    is_by_pass = False if profile is None else profile.username in BYPASS_USERS
+    # is_by_pass = False if profile is None else profile.username in BYPASS_USERS
+    is_by_pass = False
 
     tokenizer = llm.get_tokenizer()
     # force removing all
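For context on the two hunks above: when a Space has OAuth enabled, Gradio auto-fills any handler parameter annotated gr.OAuthProfile with the logged-in user's profile, which is what the removed profile argument and the BYPASS_USERS check relied on. A rough sketch of the pattern being dropped here, assuming an OAuth-enabled Space and treating BYPASS_USERS as a hypothetical allowlist:

import gradio as gr
from typing import Optional

BYPASS_USERS = {"trusted-user"}  # hypothetical allowlist

def greet(profile: Optional[gr.OAuthProfile] = None) -> str:
    # Gradio injects `profile` automatically for logged-in users.
    if profile is None:
        return "Hello, anonymous"
    is_by_pass = profile.username in BYPASS_USERS
    return f"Hello {profile.username} (bypass={is_by_pass})"

with gr.Blocks() as demo:
    gr.LoginButton()
    out = gr.Textbox(label="Greeting")
    gr.Button("Greet").click(greet, None, out)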
@@ -876,6 +951,32 @@ def chat_response_stream_multiturn(
 
 
 
+def chat_response_stream_rag_multiturn(
+    message: str,
+    history: List[Tuple[str, str]],
+    file_input: str,
+    temperature: float,
+    max_tokens: int,
+    # frequency_penalty: float,
+    # presence_penalty: float,
+    system_prompt: Optional[str] = SYSTEM_PROMPT_1,
+    current_time: Optional[float] = None,
+    rag_num_docs: Optional[int] = 3,
+):
+    message = message.strip()
+    frequency_penalty = FREQUENCE_PENALTY
+    presence_penalty = PRESENCE_PENALTY
+    if len(message) == 0:
+        raise gr.Error("The message cannot be empty!")
+    doc_context = maybe_get_doc_context(message, file_input, rag_num_docs=rag_num_docs)
+    if doc_context is not None:
+        message = f"{doc_context}\n\n{message}"
+    yield from chat_response_stream_multiturn(
+        message, history, temperature, max_tokens, frequency_penalty,
+        presence_penalty, system_prompt, current_time
+    )
+
+
 def debug_generate_free_form_stream(message):
     output = " This is a debugging message...."
     for i in range(len(output)):
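The new handler retrieves once per turn, prepends the document context to the user message, and then delegates to the existing streaming handler with yield from, so the token-streaming loop itself is untouched. A self-contained sketch of that delegation shape, using stand-in functions rather than the real handlers:

from typing import Iterator, Optional

def stream_reply(message: str) -> Iterator[str]:
    # Stand-in for chat_response_stream_multiturn's token stream.
    for tok in message.split():
        yield tok + " "

def stream_reply_rag(message: str, doc_context: Optional[str]) -> Iterator[str]:
    # Same shape as chat_response_stream_rag_multiturn: inject, then delegate.
    if doc_context is not None:
        message = f"{doc_context}\n\n{message}"
    yield from stream_reply(message)

for chunk in stream_reply_rag("what is the warranty period", "### Begin document ..."):
    print(chunk, end="")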
@@ -1450,6 +1551,61 @@ def create_chat_demo(title=None, description=None):
     return demo_chat
 
 
+def upload_file(file):
+    # file_paths = [file.name for file in files]
+    # return file_paths
+    return file.name
+
+def create_chat_demo_rag(title=None, description=None):
+    sys_prompt = SYSTEM_PROMPT_1
+    max_tokens = MAX_TOKENS
+    temperature = TEMPERATURE
+    frequence_penalty = FREQUENCE_PENALTY
+    presence_penalty = PRESENCE_PENALTY
+
+    # with gr.Blocks(title="RAG") as rag_demo:
+    additional_inputs = [
+        # gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt', 'json']),
+        gr.Textbox(value=None, label='Document path', lines=1, interactive=False),
+        gr.Number(value=temperature, label='Temperature (higher -> more random)'),
+        gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
+        # gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
+        # gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
+        gr.Textbox(value=sys_prompt, label='System prompt', lines=1, interactive=False),
+        gr.Number(value=0, label='current_time', visible=False),
+    ]
+
+
+    demo_rag_chat = gr.ChatInterface(
+        chat_response_stream_rag_multiturn,
+        chatbot=gr.Chatbot(
+            label=MODEL_NAME + "-RAG",
+            bubble_full_width=False,
+            latex_delimiters=[
+                {"left": "$", "right": "$", "display": False},
+                {"left": "$$", "right": "$$", "display": True},
+            ],
+            show_copy_button=True,
+        ),
+        textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
+        submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
+        # ! consider preventing the stop button
+        # stop_btn=None,
+        title=title,
+        description=description,
+        additional_inputs=additional_inputs,
+        additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
+        # examples=CHAT_EXAMPLES,
+        cache_examples=False
+    )
+    with demo_rag_chat:
+        upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt', 'json'], file_count="single")
+        upload_button.upload(upload_file, upload_button, additional_inputs[0])
+
+    # return demo_chat
+    return demo_rag_chat
+
+
 
 def launch_demo():
     global demo, llm, DEBUG, LOG_FILE
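The wiring detail worth noting in the hunk above: the UploadButton writes the uploaded file's server-side temp path into the first additional_inputs textbox, and gr.ChatInterface then passes that textbox's value to the handler as file_input. A stripped-down sketch of the same pattern with a hypothetical echo handler:

import gradio as gr

def respond(message, history, file_path):
    # Hypothetical handler; file_path arrives via the textbox below.
    return f"message={message!r}, document={file_path or 'none'}"

doc_path = gr.Textbox(label="Document path", interactive=False)
demo = gr.ChatInterface(respond, additional_inputs=[doc_path])

with demo:
    upload = gr.UploadButton("Upload document", file_types=[".pdf", ".docx", ".txt"], file_count="single")
    # file.name is the server-side temp path of the uploaded file.
    upload.upload(lambda f: f.name, upload, doc_path)

demo.launch()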
@@ -1544,18 +1700,29 @@ def launch_demo():
 
     if ENABLE_BATCH_INFER:
 
-        demo_file_upload = create_file_upload_demo()
+        # demo_file_upload = create_file_upload_demo()
 
     demo_free_form = create_free_form_generation_demo()
 
     demo_chat = create_chat_demo()
+    demo_chat_rag = create_chat_demo_rag()
     descriptions = model_desc
     if DISPLAY_MODEL_PATH:
         descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
 
     demo = CustomTabbedInterface(
-        interface_list=[
-
+        interface_list=[
+            demo_chat,
+            demo_chat_rag,
+            demo_free_form,
+            # demo_file_upload,
+        ],
+        tab_names=[
+            "Chat Interface",
+            "RAG Chat Interface",
+            "Text completion",
+            # "Batch Inference",
+        ],
         title=f"{model_title}",
         description=descriptions,
     )
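CustomTabbedInterface is defined elsewhere in app.py; with the stock gr.TabbedInterface the same rule applies: interface_list and tab_names must line up one-to-one, and a forgotten comma between adjacent string literals silently concatenates them into a single name, leaving the two lists mismatched. A small sketch of the requirement, with throwaway echo interfaces:

import gradio as gr

interfaces = [gr.Interface(lambda x: x, "text", "text") for _ in range(3)]

# Pitfall: without the comma, adjacent string literals concatenate,
# leaving two names for three interfaces.
bad = ["Chat Interface", "RAG Chat Interface" "Text completion"]
assert len(bad) == 2

demo = gr.TabbedInterface(
    interface_list=interfaces,
    tab_names=["Chat Interface", "RAG Chat Interface", "Text completion"],
)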
@@ -1582,7 +1749,7 @@
     if ENABLE_AGREE_POPUP:
         demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
 
-    login_btn = gr.LoginButton()
+    # login_btn = gr.LoginButton()
 
     demo.queue(api_open=False)
     return demo