phi committed · Commit 6ded56f · Parent(s): a572fd2

update
app.py CHANGED
@@ -32,17 +32,72 @@ from huggingface_hub import snapshot_download
 # @@ constants ================
 
 DEBUG = bool(int(os.environ.get("DEBUG", "1")))
-BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "
-
+BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
 TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
 DTYPE = os.environ.get("DTYPE", "bfloat16")
-# DTYPE = 'float16'
 
-#
-
+# ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
+DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
+# ! uploaded model path, will be downloaded to MODEL_PATH
+HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
+MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
+
+
+
+# gradio config
 PORT = int(os.environ.get("PORT", "7860"))
 STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
-MAX_TOKENS = 2048
+MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
+TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
+FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
+
+
+"""
+TODO:
+need to upload the model as hugginface/models/seal_13b_a
+# https://huggingface.co/docs/hub/spaces-overview#managing-secrets
+set
+MODEL_REPO_ID=hugginface/models/seal_13b_a
+
+# if persistent, then export the following
+HF_HOME=/data/.huggingface
+TRANSFORMERS_CACHE=/data/.huggingface
+MODEL_PATH=/data/.huggingface/seal-13b-chat-a
+HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
+# if not persistent
+MODEL_PATH=./seal-13b-chat-a
+HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
+
+
+
+# download will auto detect and get the most updated one
+if DOWNLOAD_SNAPSHOT:
+    print(f'Download from HF_MODEL_NAME={HF_MODEL_NAME} -> {MODEL_PATH}')
+    snapshot_download(HF_MODEL_NAME, local_dir=MODEL_PATH)
+elif not DEBUG:
+    assert os.path.exists(MODEL_PATH), f'{MODEL_PATH} not found and no snapshot download'
+
+"""
+
+
+
+
+# ==============================
+print(f'DEBUG mode: {DEBUG}')
+
+if DTYPE == "bfloat16" and not DEBUG:
+    try:
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            gpu_name = torch.cuda.get_device_name()
+            print(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
+            DTYPE = "float16"
+    except Exception as e:
+        print(f'Unable to obtain compute_capability: {e}')
+
 
 # @@ constants ================
 if not DEBUG:
@@ -115,7 +170,6 @@ def hf_model_weights_iterator(
         x for x in glob.glob(os.path.join(hf_folder, "*model*.safetensors"))
         if not x.endswith("training_args.bin")
     ]
-    # print(F'Load bin files: {hf_bin_files} // safetensors: {hf_safetensors_files}')
 
     if use_np_cache:
         # Convert the model weights from torch tensors to numpy arrays for
@@ -226,15 +280,8 @@ def llama_load_weights(
     state_dict = self.state_dict()
     need_to_load = len(state_dict)
    loaded = 0
-    # try:
-    #     iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
-    # except Exception as e:
-    #     iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, load_format, revision)
    iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
 
-    # for name, loaded_weight in hf_model_weights_iterator(
-    #         model_name_or_path, cache_dir, load_format, revision):
-    #         model_name_or_path, cache_dir, use_np_cache):
    for name, loaded_weight in iterator:
        if "rotary_emb.inv_freq" in name:
            continue
@@ -253,12 +300,6 @@ def llama_load_weights(
            if num_extra_rows > 0:
                print(f'Add empty to {num_extra_rows} extra row for {name}')
            print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
-
-        # if "embed_tokens" in name or "lm_head" in name:
-        #     param = state_dict[name]
-        #     load_padded_tensor_parallel_vocab(param, loaded_weight, tensor_model_parallel_rank)
-        #     loaded += 1
-        #     continue
 
        is_attention_weight = False
        for weight_name, shard_size, offset in attention_weight_specs:
@@ -385,8 +426,6 @@ if not DEBUG:
 
 set_documentation_group("component")
 
-DATA_ROOT = os.environ.get("dataroot", "/mnt/workspace/workgroup/phi")
-MODEL_CACHE_DIR = os.path.join(DATA_ROOT, "pret_models")
 
 
 DTYPES = {
@@ -397,7 +436,6 @@ DTYPES = {
 llm = None
 demo = None
 
-RELOAD_SIGNAL = '<<<reload:'
 
 BOS_TOKEN = '<s>'
 EOS_TOKEN = '</s>'
@@ -824,28 +862,64 @@ path_markdown = """
 {model_path}
 """
 
+def check_model_path(model_path) -> str:
+    assert os.path.exists(model_path), f'{model_path} not found'
+    ckpt_info = "None"
+    if os.path.isdir(model_path):
+        if os.path.exists(f'{model_path}/info.txt'):
+            with open(f'{model_path}/info.txt', 'r') as f:
+                ckpt_info = f.read()
+                print(f'Checkpoint info:\n{ckpt_info}\n-----')
+        else:
+            print(f'info.txt not found in {model_path}')
+            print(f'model path dir: {list(os.listdir(model_path))}')
+
+    return ckpt_info
+
 
 def launch():
     global demo, llm, DEBUG
     model_desc = MODEL_DESC
     model_path = MODEL_PATH
     model_title = MODEL_TITLE
+    hf_model_name = HF_MODEL_NAME
     tensor_parallel = TENSOR_PARALLEL
     assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
     dtype = DTYPE
     sys_prompt = SYSTEM_PROMPT_1
     max_tokens = MAX_TOKENS
-
+    temperature = TEMPERATURE
+    frequence_penalty = FREQUENCE_PENALTY
+    ckpt_info = "None"
+
+    print(
+        f'Launch config: {model_path=} / {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
+        f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
+        f'\n| frequence_penalty={frequence_penalty} '
+        f'\n| temperature={temperature} '
+        f'\n| hf_model_name={hf_model_name} '
+        f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
+        f'\nsys={SYSTEM_PROMPT_1}'
+        f'\ndesc={model_desc}'
+    )
 
     if DEBUG:
-        model_desc += "\n<br>!!!!! This is in debug mode, responses will
+        model_desc += "\n<br>!!!!! This is in debug mode, responses will copy original"
         response_fn = debug_chat_response_echo
+        print(f'Creating in DEBUG MODE')
    else:
        # ! load the model
        import vllm
-        assert os.path.exists(model_path), f'{model_path} not found'
        print(F'VLLM: {vllm.__version__}')
-
+
+        if DOWNLOAD_SNAPSHOT:
+            print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
+            snapshot_download(hf_model_name, local_dir=model_path)
+
+        assert os.path.exists(model_path), f'{model_path} not found and no snapshot download'
+        ckpt_info = check_model_path(model_path)
+
+        print(f'Load path: {model_path} | {ckpt_info}')
        llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)
 
    print(f'Use system prompt:\n{sys_prompt}')
@@ -871,9 +945,9 @@ def launch():
        description=f"{model_desc}",
        # ! decide if can change the system prompt.
        additional_inputs=[
-            gr.Number(value=
+            gr.Number(value=temperature, label='Temperature (higher -> more random)'),
            gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
-            gr.Number(value=
+            gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
            # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
        ],
    )
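
For context, the TODO block added in this commit sketches how the Space is meant to be configured through environment variables. Below is a minimal, hypothetical launcher (not part of the repository) showing one way those variables could be set for a local, non-DEBUG run with persistent storage; the values are copied from that docstring and should be treated as assumptions rather than verified defaults.

# Hypothetical local-run sketch, not part of the commit: set the new
# environment variables before app.py reads them at import time.
import os
import subprocess

env = dict(os.environ)
env.update({
    "DEBUG": "0",                      # use the real vLLM path instead of the echo stub
    "DOWNLOAD_SNAPSHOT": "1",          # pull HF_MODEL_NAME into MODEL_PATH via snapshot_download
    "HF_MODEL_NAME": "DAMO-NLP-SG/seal-13b-chat-a",
    "MODEL_PATH": "/data/.huggingface/seal-13b-chat-a",   # persistent-storage layout from the TODO note
    "HF_HOME": "/data/.huggingface",
    "TRANSFORMERS_CACHE": "/data/.huggingface",
    "MAX_TOKENS": "2048",
    "TEMPERATURE": "0.1",
    "FREQUENCE_PENALTY": "0.4",
})
# Launch the Gradio app with the configured environment.
subprocess.run(["python", "app.py"], env=env, check=True)

On a hosted Space the same names would instead be set as Space variables/secrets (see the huggingface.co/docs/hub/spaces-overview#managing-secrets link in the docstring); the non-persistent variant simply points MODEL_PATH at ./seal-13b-chat-a.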