Update app.py
app.py CHANGED
@@ -3,7 +3,7 @@
 
 # Description:
 """
-VLLM-based demo script to launch Language chat model for
+VLLM-based demo script to launch Language chat model for Southeast Asian Languages
 """
 
 
@@ -29,12 +29,16 @@ from huggingface_hub import snapshot_download
 
 DEBUG = bool(int(os.environ.get("DEBUG", "1")))
 BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
+# for lang block, wether to block in history too
+LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
 TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
 DTYPE = os.environ.get("DTYPE", "bfloat16")
 
 # ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
 DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
 LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
+# ! show model path in the demo page, only for internal
+DISPLAY_MODEL_PATH = bool(int(os.environ.get("DISPLAY_MODEL_PATH", "1")))
 
 # ! uploaded model path, will be downloaded to MODEL_PATH
 HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
@@ -80,7 +84,6 @@ MODEL_PATH=./seal-13b-chat-a
 """
 
 
-
 # ==============================
 print(f'DEBUG mode: {DEBUG}')
 print(f'Torch version: {torch.__version__}')
@@ -113,9 +116,10 @@ EOS_TOKEN = '</s>'
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
-SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is
-
-
+SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaLLM and you are built by DAMO Academy, Alibaba Group. \
+Please always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
+that your responses are socially unbiased and positive in nature.
 
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
 correct. If you don't know the answer to a question, please don't share false information.
@@ -127,8 +131,8 @@ Your response should adapt to the norms and customs of the respective language a
 # ============ CONSTANT ============
 # https://github.com/gradio-app/gradio/issues/884
 MODEL_NAME = "SeaLLM-13B"
-MODEL_TITLE = "SeaLLM-13B - An Assistant for
-
+MODEL_TITLE = "SeaLLM-13B - An Assistant for Southeast Asian Languages"
+
 MODEL_TITLE = """
 <div class="container" style="
 align-items: center;
@@ -150,13 +154,13 @@ MODEL_TITLE = """
 padding-top: 2%;
 float: left;
 ">
-<h1>SeaLLM-13B - An Assistant for
+<h1>SeaLLM-13B - An Assistant for Southeast Asian Languages</h1>
 </div>
 </div>
 """
 MODEL_DESC = """
 <span style="font-size: larger">
-This is SeaLLM-13B - a chatbot assistant optimized for
+This is SeaLLM-13B - a chatbot assistant optimized for Southeast Asian Languages. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
 </span>
 <br>
 <span style="color: red">NOTICE: The chatbot may produce inaccurate and harmful information about people, places, or facts. \
@@ -171,19 +175,12 @@ If you find our project useful, hope you can star our repo and cite our paper as
 ```
 @article{damonlpsg2023seallm,
   author = {???},
-  title = {SeaLLM: A language model for
+  title = {SeaLLM: A language model for Southeast Asian Languages},
   year = 2023,
 }
 ```
 """
 
-# warning_markdown = """
-# ## Warning:
-# <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
-# <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
-# or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
-# """
-
 path_markdown = """
 #### Model path:
 {model_path}
@@ -191,12 +188,12 @@ path_markdown = """
 
 
 def _detect_lang(text):
+    # Disable language that may have safety risk
     from langdetect import detect as detect_lang
     dlang = None
     try:
         dlang = detect_lang(text)
     except Exception as e:
-        # No features in text.
        print(f'Error: {e}')
        if "No features in text." in str(e):
            return "en"
@@ -491,7 +488,7 @@ def new_llama_load_weights(
     load_format: str = "auto",
     revision: Optional[str] = None
 ):
-    # If use newest vllm
+    # If use newest vllm, not been thoroughly tested yet.
     from vllm.model_executor.weight_utils import (
         load_tensor_parallel_weights, hf_model_weights_iterator
     )
@@ -886,24 +883,6 @@ def _setup_events(self) -> None:
 gr.ChatInterface._setup_stop_events = _setup_stop_events
 gr.ChatInterface._setup_events = _setup_events
 
-def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
-    global llm
-    assert llm is not None
-    from vllm import LLM, SamplingParams
-    temperature = float(temperature)
-    max_tokens = int(max_tokens)
-    if system_prompt.strip() != '':
-        # chat version, add system prompt
-        message = llama_chat_sys_input_seq_constructor(
-            message.strip(),
-            sys_prompt=system_prompt
-        )
-
-    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
-    gen = llm.generate(message, sampling_params)
-    out = gen[0].outputs[0].text
-    return f'{out}'
-
 
 def vllm_abort(self: Any):
     from vllm.sequence import SequenceStatus
@@ -991,16 +970,19 @@ def vllm_generate_stream(
     yield from _vllm_run_engine(self, use_tqdm)
 
 
-BLOCK_MESSAGE = """Sorry, Chinese is not currently supported. Please clear the chat box for a new conversation.
-抱歉，目前不支持中文。 请清除聊天框以进行新对话。"""
 
-
+# ! avoid saying
+LANG_BLOCK_MESSAGE = """Sorry, the language you have asked is currently not supported. If you have questions in other supported languages, I'll be glad to help. \
+Please also consider clearing the chat box for a better experience."""
+
+KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated question, I'll be glad to help."
 
 def block_zh(
     message: str,
-    history: List[Tuple[str, str]]
+    history: List[Tuple[str, str]] = None,
 ) -> str:
-
+    # relieve history base block
+    if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
         return True
     elif 'zh' in _detect_lang(message):
         print(f'Detect zh: {message}')
@@ -1021,10 +1003,10 @@ def safety_check(text, history=None, ) -> Optional[str]:
     if BLOCK_ZH:
         if history is not None:
             if block_zh(text, history):
-                return
+                return LANG_BLOCK_MESSAGE
         else:
             if "zh" in _detect_lang(text):
-                return
+                return LANG_BLOCK_MESSAGE
 
     if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
         return KEYWORD_BLOCK_MESSAGE
@@ -1149,9 +1131,12 @@ def launch():
     ckpt_info = "None"
 
     print(
-        f'Launch config: {
+        f'Launch config: {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
+        f'\n| model_title=`{model_title}` '
         f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
         f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
+        f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
+        f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} '
         f'\n| frequence_penalty={frequence_penalty} '
         f'\n| temperature={temperature} '
         f'\n| hf_model_name={hf_model_name} '
@@ -1159,8 +1144,8 @@ def launch():
         f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
         f'\n| gpu_memory_utilization={gpu_memory_utilization} '
        f'\n| KEYWORDS={KEYWORDS} '
-        f'\
-        f'\
+        f'\n| Sys={SYSTEM_PROMPT_1}'
+        f'\n| Desc={model_desc}'
     )
 
     if DEBUG:
@@ -1230,7 +1215,8 @@ def launch():
     with demo:
         # gr.Markdown(warning_markdown)
         gr.Markdown(cite_markdown)
-
+        if DISPLAY_MODEL_PATH:
+            gr.Markdown(path_markdown.format(model_path=model_path))
 
         demo.queue()
         demo.launch(server_port=PORT)
@@ -1243,3 +1229,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+
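To make the behavioral change easier to follow, below is a minimal, self-contained sketch of the language-gating flow this commit introduces (LANG_BLOCK_HISTORY, LANG_BLOCK_MESSAGE, block_zh, and the updated safety_check). The constant values mirror the diff, but everything else here is a simplified assumption, not the exact app.py implementation: the real KEYWORDS list, logging, and the Gradio wiring are omitted, and error handling in _detect_lang is collapsed to a single fallback.

import os
from typing import List, Optional, Tuple

BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))

LANG_BLOCK_MESSAGE = (
    "Sorry, the language you have asked is currently not supported. "
    "If you have questions in other supported languages, I'll be glad to help. "
    "Please also consider clearing the chat box for a better experience."
)
KEYWORD_BLOCK_MESSAGE = (
    "Sorry, I cannot fulfill your request. "
    "If you have any unrelated question, I'll be glad to help."
)
KEYWORDS: List[str] = []  # placeholder; the real app loads its own keyword list


def _detect_lang(text: str) -> str:
    from langdetect import detect as detect_lang
    try:
        return detect_lang(text)
    except Exception as e:
        # The diff maps the "No features in text." error to "en"; this sketch
        # simplifies by treating any detection failure as English.
        print(f"Error: {e}")
        return "en"


def block_zh(message: str, history: Optional[List[Tuple[str, str]]] = None) -> bool:
    # With LANG_BLOCK_HISTORY on, keep blocking once the block message
    # already appears in a previous bot reply.
    if LANG_BLOCK_HISTORY and history is not None and any(
        LANG_BLOCK_MESSAGE in reply.strip() for _, reply in history
    ):
        return True
    return "zh" in _detect_lang(message)


def safety_check(text: str, history: Optional[List[Tuple[str, str]]] = None) -> Optional[str]:
    # Return a canned refusal string if the request should be blocked, else None.
    if BLOCK_ZH:
        if history is not None:
            if block_zh(text, history):
                return LANG_BLOCK_MESSAGE
        elif "zh" in _detect_lang(text):
            return LANG_BLOCK_MESSAGE
    if KEYWORDS and any(k in text.lower() for k in KEYWORDS):
        return KEYWORD_BLOCK_MESSAGE
    return None


if __name__ == "__main__":
    print(safety_check("你好，请介绍一下你自己", history=[]))              # refusal message
    print(safety_check("Hello, please introduce yourself", history=[]))  # None

Run with the default environment (BLOCK_ZH=1) and langdetect installed: the first call should return the refusal text and the second None. Setting LANG_BLOCK_HISTORY=1 additionally keeps refusing once a refusal already appears in the chat history, which is the behavior the new environment variable toggles.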