| # Copyright: DAMO Academy, Alibaba Group | |
| # By Xuan Phi Nguyen at DAMO Academy, Alibaba Group | |
| # Description: | |
| """ | |
| VLLM-based demo script to launch a language chat model for Southeast Asian languages | |
| """ | |
| import os | |
| import shutil | |
| import numpy as np | |
| import argparse | |
| import torch | |
| import gradio as gr | |
| from typing import Any, Iterator | |
| from typing import Iterator, List, Optional, Tuple | |
| import filelock | |
| import glob | |
| import json | |
| from gradio_client.documentation import document, set_documentation_group | |
| from typing import List, Optional, Union, Dict, Tuple | |
| from tqdm.auto import tqdm | |
| from huggingface_hub import snapshot_download | |
| # @@ environments ================ | |
| DEBUG = bool(int(os.environ.get("DEBUG", "1"))) | |
| BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1"))) | |
| # for lang block, whether to also block based on history | |
| LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0"))) | |
| TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1")) | |
| DTYPE = os.environ.get("DTYPE", "bfloat16") | |
| # ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH | |
| DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0"))) | |
| LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0"))) | |
| # ! show model path in the demo page, for internal use only | |
| DISPLAY_MODEL_PATH = bool(int(os.environ.get("DISPLAY_MODEL_PATH", "1"))) | |
| # ! name of the uploaded model on the HF Hub, which will be downloaded to MODEL_PATH | |
| HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a") | |
| # ! if the model is private, HF_TOKEN is needed to access it | |
| HF_TOKEN = os.environ.get("HF_TOKEN", None) | |
| # ! path where the model is downloaded, either under ./ or on persistent disk | |
| MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a") | |
| # ! !! Whether to delete the folder, ONLY SET THIS IF YOU WANT TO DELETE THE SAVED MODEL ON PERSISTENT DISK | |
| DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "") | |
| IS_DELETE_FOLDER = bool(DELETE_FOLDER) and os.path.exists(DELETE_FOLDER) | |
| print(f'DELETE_FOLDER: {DELETE_FOLDER} | {DOWNLOAD_SNAPSHOT=}') | |
| # ! list of keywords to block as a security measure to comply with local regulations | |
| KEYWORDS = os.environ.get("KEYWORDS", "").strip() | |
| KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else [] | |
| KEYWORDS = [x.lower() for x in KEYWORDS] | |
| # gradio config | |
| PORT = int(os.environ.get("PORT", "7860")) | |
| # yield the streamed response once every this many engine iterations | |
| STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1")) | |
| # run the safety check on the partial response every this many engine iterations (0 to disable) | |
| STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0")) | |
| # generation defaults (self-explanatory) | |
| MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048")) | |
| TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1")) | |
| FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4")) | |
| gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9")) | |
| # whether to enable quantization, currently not in use | |
| QUANTIZATION = str(os.environ.get("QUANTIZATION", "")) | |
| """ | |
| Internal instructions of how to configure the DEMO | |
| 1. Upload the SFT model as a model to huggingface: huggingface/models/seal_13b_a | |
| 2. If the model weights are private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings | |
| 3. space config env: `HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a` or the underlying model | |
| 4. If persistent storage is enabled, set | |
| HF_HOME=/data/.huggingface | |
| MODEL_PATH=/data/.huggingface/seal-13b-chat-a | |
| if not: | |
| MODEL_PATH=./seal-13b-chat-a | |
| """ | |
| # ============================== | |
| print(f'DEBUG mode: {DEBUG}') | |
| print(f'Torch version: {torch.__version__}') | |
| try: | |
| print(f'Torch CUDA version: {torch.version.cuda}') | |
| except Exception as e: | |
| print(f'Failed to print cuda version: {e}') | |
| try: | |
| compute_capability = torch.cuda.get_device_capability() | |
| print(f'Torch CUDA compute_capability: {compute_capability}') | |
| except Exception as e: | |
| print(f'Failed to print compute_capability: {e}') | |
| # @@ constants ================ | |
| DTYPES = { | |
| 'float16': torch.float16, | |
| 'bfloat16': torch.bfloat16 | |
| } | |
| llm = None | |
| demo = None | |
| BOS_TOKEN = '<s>' | |
| EOS_TOKEN = '</s>' | |
| B_INST, E_INST = "[INST]", "[/INST]" | |
| B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" | |
| SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaLLM and you are built by DAMO Academy, Alibaba Group. \ | |
| Please always answer as helpfully as possible, while being safe. Your \ | |
| answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \ | |
| that your responses are socially unbiased and positive in nature. | |
| If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ | |
| correct. If you don't know the answer to a question, please don't share false information. | |
| As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \ | |
| Your response should adapt to the norms and customs of the respective language and culture. | |
| """ | |
| # ============ CONSTANT ============ | |
| # https://github.com/gradio-app/gradio/issues/884 | |
| MODEL_NAME = "SeaLLM-13B" | |
| MODEL_TITLE = "SeaLLM-13B - An Assistant for Southeast Asian Languages" | |
| MODEL_TITLE = """ | |
| <div class="container" style=" | |
| align-items: center; | |
| justify-content: center; | |
| display: flex; | |
| "> | |
| <div class="image" > | |
| <img src="file/seal_logo.png" style=" | |
| max-width: 10em; | |
| max-height: 5%; | |
| height: 3em; | |
| width: 3em; | |
| float: left; | |
| margin-left: auto; | |
| "> | |
| </div> | |
| <div class="text" style=" | |
| padding-left: 20px; | |
| padding-top: 1%; | |
| float: left; | |
| "> | |
| <h1>SeaLLMs - Large Language Models for Southeast Asia</h1> | |
| </div> | |
| </div> | |
| """ | |
| MODEL_DESC = """ | |
| <div style='display:flex; gap: 0.25rem; '> | |
| <a href=''><img src='https://img.shields.io/badge/Github-Code-success'></a> | |
| <a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-Chat-13b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> | |
| <a href='https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> | |
| <a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a> | |
| </div> | |
| <span style="font-size: larger"> | |
| This is <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">SeaLLM-13B-Chat</a> - a chatbot assistant optimized for Southeast Asian Languages. It produces helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭. | |
| Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">our article</a> for more details. | |
| </span> | |
| <br> | |
| <span > | |
| NOTE: The chatbot may produce inaccurate and harmful information about people, places, or facts. | |
| <u style="color: red">By using our service, you are required to agree to the following terms:</u><br> | |
| <ul> | |
| <li > | |
| You must not use our service to generate any harmful, unethical or illegal content that violates locally applicable and international laws or regulations, | |
| including but not limited to hate speech, violence, pornography and deception.</li> | |
| <li > | |
| The service collects user dialogue data for testing and performance improvement, and reserves the right to distribute it under | |
| <a href="https://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution (CC-BY)</a> or similar license. So do not enter any personal information! | |
| </li> | |
| </ul> | |
| </span> | |
| """.strip() | |
| cite_markdown = """ | |
| ## Citation | |
| If you find our project useful, we hope you will star our repo and cite our paper as follows: | |
| ``` | |
| @article{damonlpsg2023seallm, | |
| author = {Xuan-Phi Nguyen*, Wenxuan Zhang*, Xin Li*, Mahani Aljunied*, Qingyu Tan, Liying Cheng, Guanzheng Chen, Yue Deng, Sen Yang, Chaoqun Liu, Hang Zhang, Lidong Bing}, | |
| title = {SeaLLMs - Large Language Models for Southeast Asia}, | |
| year = 2023, | |
| } | |
| ``` | |
| """ | |
| path_markdown = """ | |
| #### Model path: | |
| {model_path} | |
| """ | |
| def _detect_lang(text): | |
| # Detect the language of the text so that languages with safety risks can be blocked | |
| from langdetect import detect as detect_lang | |
| dlang = None | |
| try: | |
| dlang = detect_lang(text) | |
| except Exception as e: | |
| print(f'Error: {e}') | |
| if "No features in text." in str(e): | |
| return "en" | |
| else: | |
| return "zh" | |
| return dlang | |
| def custom_hf_model_weights_iterator( | |
| model_name_or_path: str, | |
| cache_dir: Optional[str] = None, | |
| use_np_cache: bool = False, | |
| ) -> Iterator[Tuple[str, torch.Tensor]]: | |
| # ! if using vllm==0.1.4, use this to augment the hf_model_weights_iterator loader | |
| from vllm.model_executor.weight_utils import Disabledtqdm | |
| # Prepare file lock directory to prevent multiple processes from | |
| # downloading the same model weights at the same time. | |
| lock_dir = cache_dir if cache_dir is not None else "/tmp" | |
| lock_file_name = model_name_or_path.replace("/", "-") + ".lock" | |
| lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name)) | |
| # Download model weights from huggingface. | |
| is_local = os.path.isdir(model_name_or_path) | |
| if not is_local: | |
| with lock: | |
| hf_folder = snapshot_download(model_name_or_path, | |
| allow_patterns="*.bin", | |
| cache_dir=cache_dir, | |
| local_files_only=True, | |
| tqdm_class=Disabledtqdm) | |
| else: | |
| hf_folder = model_name_or_path | |
| hf_bin_files = [ | |
| x for x in glob.glob(os.path.join(hf_folder, "*model*.bin")) | |
| if not x.endswith("training_args.bin") | |
| ] | |
| hf_safetensors_files = [ | |
| x for x in glob.glob(os.path.join(hf_folder, "*model*.safetensors")) | |
| if not x.endswith("training_args.bin") | |
| ] | |
| if use_np_cache: | |
| # Convert the model weights from torch tensors to numpy arrays for | |
| # faster loading. | |
| np_folder = os.path.join(hf_folder, "np") | |
| os.makedirs(np_folder, exist_ok=True) | |
| weight_names_file = os.path.join(np_folder, "weight_names.json") | |
| with lock: | |
| if not os.path.exists(weight_names_file): | |
| weight_names = [] | |
| for bin_file in hf_bin_files: | |
| state = torch.load(bin_file, map_location="cpu") | |
| for name, param in state.items(): | |
| param_path = os.path.join(np_folder, name) | |
| with open(param_path, "wb") as f: | |
| np.save(f, param.cpu().detach().numpy()) | |
| weight_names.append(name) | |
| with open(weight_names_file, "w") as f: | |
| json.dump(weight_names, f) | |
| with open(weight_names_file, "r") as f: | |
| weight_names = json.load(f) | |
| for name in weight_names: | |
| param_path = os.path.join(np_folder, name) | |
| with open(param_path, "rb") as f: | |
| param = np.load(f) | |
| yield name, torch.from_numpy(param) | |
| else: | |
| if len(hf_bin_files) > 0: | |
| print(F'Load bin files: {hf_bin_files}') | |
| for bin_file in hf_bin_files: | |
| state = torch.load(bin_file, map_location="cpu") | |
| for name, param in state.items(): | |
| yield name, param | |
| del state | |
| torch.cuda.empty_cache() | |
| elif len(hf_safetensors_files) > 0: | |
| print(F'Load safetensor files: {hf_safetensors_files}') | |
| from safetensors.torch import load_file | |
| for safe_file in hf_safetensors_files: | |
| # state = torch.load(bin_file, map_location="cpu") | |
| state = load_file(safe_file) | |
| for name, param in state.items(): | |
| yield name, param | |
| del state | |
| torch.cuda.empty_cache() | |
| else: | |
| raise ValueError('No weight files found: neither *.bin nor *.safetensors are available') | |
| def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: | |
| """convert PySafeSlice object from safetensors to torch.Tensor | |
| PySafeSlice object supports indexing, which is done before loading the | |
| actual tensor and can reduce the amount of memory being read into the | |
| memory. However, it does not support more advanced functionalities | |
| like `.view()` or `.t()`. Therefore, if we need to modify the loaded | |
| tensor with these more complicated operators, we need to convert to | |
| tensor first. | |
| """ | |
| if not isinstance(x, torch.Tensor): | |
| x = x[:] | |
| return x | |
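| # Hedged usage sketch (added for illustration; the names below are made up): | |
| # a PySafeSlice can be row-sliced lazily, but it must be materialized into a | |
| # torch.Tensor before ops like .t() or .view(); the helper above handles both | |
| # plain tensors and lazy slices, so callers do not need to branch themselves. | |
| def _example_materialize_and_transpose(maybe_sliced_weight: Any) -> torch.Tensor: | |
| weight = convert_pyslice_to_tensor(maybe_sliced_weight) | |
| return weight.t() | |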
| def load_padded_tensor_parallel_vocab( | |
| param: torch.Tensor, | |
| loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` | |
| tensor_model_parallel_rank: int, | |
| ) -> None: | |
| shard_size = param.shape[0] | |
| start_idx = tensor_model_parallel_rank * shard_size | |
| end_idx = (tensor_model_parallel_rank + 1) * shard_size | |
| loaded_weight = loaded_weight[start_idx:end_idx] | |
| loaded_weight = convert_pyslice_to_tensor(loaded_weight) | |
| param[:loaded_weight.shape[0]].copy_(loaded_weight) | |
| def llama_load_weights( | |
| self, | |
| model_name_or_path: str, | |
| cache_dir: Optional[str] = None, | |
| use_np_cache: bool = False, | |
| load_format: str = "auto", | |
| revision: Optional[str] = None | |
| ): | |
| # if use vllm==0.1.4 | |
| from vllm.model_executor.weight_utils import ( | |
| load_tensor_parallel_weights | |
| ) | |
| from vllm.model_executor.parallel_utils.parallel_state import ( | |
| get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) | |
| tp_size = get_tensor_model_parallel_world_size() | |
| tensor_model_parallel_rank = get_tensor_model_parallel_rank() | |
| q_proj_shard_size = (self.config.hidden_size // tp_size) | |
| kv_proj_shard_size = (self.config.hidden_size // | |
| self.config.num_attention_heads * | |
| getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size) | |
| attention_weight_specs = [ | |
| # (weight_name, shard_size, offset) | |
| ("q_proj", q_proj_shard_size, 0), | |
| ("k_proj", kv_proj_shard_size, q_proj_shard_size), | |
| ("v_proj", kv_proj_shard_size, | |
| q_proj_shard_size + kv_proj_shard_size), | |
| ] | |
| state_dict = self.state_dict() | |
| need_to_load = len(state_dict) | |
| loaded = 0 | |
| iterator = custom_hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache) | |
| for name, loaded_weight in iterator: | |
| if "rotary_emb.inv_freq" in name: | |
| continue | |
| if "embed_tokens" in name or "lm_head" in name: | |
| param = state_dict[name] | |
| # Consider padding in the vocab size. | |
| padded_vocab_size = (param.shape[0] * tp_size) | |
| # num_extra_rows = padded_vocab_size - self.config.vocab_size | |
| num_extra_rows = padded_vocab_size - loaded_weight.size(0) | |
| load_size = loaded_weight.size() | |
| extra_rows = torch.empty(num_extra_rows, | |
| loaded_weight.shape[1]) | |
| extra_rows = extra_rows.to(loaded_weight) | |
| loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) | |
| if num_extra_rows > 0: | |
| print(f'Added {num_extra_rows} empty extra rows for {name}') | |
| print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}') | |
| is_attention_weight = False | |
| for weight_name, shard_size, offset in attention_weight_specs: | |
| if weight_name not in name or "qkv_proj" in name: | |
| continue | |
| param = state_dict[name.replace(weight_name, "qkv_proj")] | |
| loaded_weight = loaded_weight[ | |
| shard_size * tensor_model_parallel_rank:shard_size * | |
| (tensor_model_parallel_rank + 1)] | |
| param_slice = param.data[offset:offset + shard_size] | |
| assert param_slice.shape == loaded_weight.shape | |
| param_slice.copy_(loaded_weight) | |
| loaded += 1.0 / 3 | |
| is_attention_weight = True | |
| break | |
| if is_attention_weight: | |
| continue | |
| # ! qkv_proj is sharded differently if concatenated into qkv | |
| # qkv: qqqq kkkk vvvv | |
| # lweight: qq0qq1 kk0kk1 vv0vv1 | |
| # q_shard_size: hidden_size // tp_size = qq | |
| # qkv_s0: qq0_kk0_vv0 | |
| # qkv_s1: qq1_kk1_vv1 | |
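| # Worked example (added for illustration, assuming hidden_size=5120, | |
| # num_attention_heads=num_key_value_heads=40 and tp_size=2): | |
| # q_proj_shard_size = 5120 // 2 = 2560, kv_proj_shard_size = 2560, | |
| # qsize = kvsize = 5120, so rank 0 copies rows [0:2560] (q), | |
| # [5120:7680] (k) and [10240:12800] (v) of the fused checkpoint weight. | |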
| if "qkv_proj" in name: | |
| param = state_dict[name] | |
| # loaded_weight | |
| qsize = self.config.hidden_size | |
| kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) | |
| q_offsets = ( | |
| q_proj_shard_size * tensor_model_parallel_rank, | |
| q_proj_shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| k_offsets = ( | |
| qsize + kv_proj_shard_size * tensor_model_parallel_rank, | |
| qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| v_offsets = ( | |
| qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank, | |
| qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| _loaded_weight = torch.cat( | |
| [ | |
| loaded_weight[q_offsets[0]:q_offsets[1]], | |
| loaded_weight[k_offsets[0]:k_offsets[1]], | |
| loaded_weight[v_offsets[0]:v_offsets[1]], | |
| ], 0 | |
| ) | |
| assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}' | |
| param.data.copy_(_loaded_weight) | |
| loaded += 1.0 | |
| is_attention_weight = True | |
| if is_attention_weight: | |
| continue | |
| is_gate_up_weight = False | |
| for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): | |
| if weight_name not in name or "gate_up_proj" in name: | |
| continue | |
| param = state_dict[name.replace(weight_name, "gate_up_proj")] | |
| shard_size = param.shape[0] // 2 | |
| loaded_weight = loaded_weight[ | |
| shard_size * tensor_model_parallel_rank:shard_size * | |
| (tensor_model_parallel_rank + 1)] | |
| param_slice = param.data[shard_size * stride_id:shard_size * | |
| (stride_id + 1)] | |
| assert param_slice.shape == loaded_weight.shape | |
| param_slice.copy_(loaded_weight) | |
| loaded += 1.0 / 2 | |
| is_gate_up_weight = True | |
| break | |
| if is_gate_up_weight: | |
| continue | |
| if "gate_up_proj" in name: | |
| param = state_dict[name] | |
| shard_size = param.shape[0] // 2 | |
| intermediate_size = self.config.intermediate_size | |
| g_offsets = ( | |
| shard_size * tensor_model_parallel_rank, | |
| shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| u_offsets = ( | |
| intermediate_size + shard_size * tensor_model_parallel_rank, | |
| intermediate_size + shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| _loaded_weight = torch.cat( | |
| [ | |
| loaded_weight[g_offsets[0]:g_offsets[1]], | |
| loaded_weight[u_offsets[0]:u_offsets[1]], | |
| ], 0 | |
| ) | |
| assert param.shape == _loaded_weight.shape | |
| param.data.copy_(_loaded_weight) | |
| loaded += 1.0 | |
| is_gate_up_weight = True | |
| if is_gate_up_weight: | |
| continue | |
| param = state_dict[name] | |
| load_tensor_parallel_weights(param, loaded_weight, name, | |
| self._column_parallel_weights, | |
| self._row_parallel_weights, | |
| tensor_model_parallel_rank) | |
| loaded += 1 | |
| if np.abs(loaded - need_to_load) > 0.01: | |
| print(f'WARNING: only {loaded} params loaded out of {need_to_load}') | |
| else: | |
| print(f'Loaded all {loaded} params out of {need_to_load}') | |
| def new_llama_load_weights( | |
| self, | |
| model_name_or_path: str, | |
| cache_dir: Optional[str] = None, | |
| load_format: str = "auto", | |
| revision: Optional[str] = None | |
| ): | |
| # For newer vllm versions; not yet thoroughly tested. | |
| from vllm.model_executor.weight_utils import ( | |
| load_tensor_parallel_weights, hf_model_weights_iterator | |
| ) | |
| from vllm.model_executor.parallel_utils.parallel_state import ( | |
| get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) | |
| if self.quant_config is None: | |
| weight_suffixes = ["weight"] | |
| else: | |
| weight_suffixes = self.quant_config.get_tp_tensor_names() | |
| column_parallel_weights: List[str] = [] | |
| for layer in self._column_parallel_layers: | |
| for suffix in weight_suffixes: | |
| column_parallel_weights.append(f"{layer}.{suffix}") | |
| row_parallel_weights: List[str] = [] | |
| for layer in self._row_parallel_layers: | |
| for suffix in weight_suffixes: | |
| row_parallel_weights.append(f"{layer}.{suffix}") | |
| tp_size = get_tensor_model_parallel_world_size() | |
| tp_rank = get_tensor_model_parallel_rank() | |
| assert tp_size == 1, f'tensorparallel >=2 not allowed. {tp_size}' | |
| q_proj_shard_size = (self.config.hidden_size // tp_size) | |
| num_kv_heads_replicas = max(1, | |
| tp_size // self.config.num_key_value_heads) | |
| num_kv_heads_per_gpu = max(1, | |
| self.config.num_key_value_heads // tp_size) | |
| kv_proj_shard_size = (self.config.hidden_size // | |
| self.config.num_attention_heads * | |
| num_kv_heads_per_gpu) | |
| attention_weight_specs = [ | |
| # (weight_name, shard_size, offset) | |
| ("q_proj", q_proj_shard_size, 0), | |
| ("k_proj", kv_proj_shard_size, q_proj_shard_size), | |
| ("v_proj", kv_proj_shard_size, | |
| q_proj_shard_size + kv_proj_shard_size), | |
| ] | |
| state_dict = self.state_dict() | |
| need_to_load = len(state_dict) | |
| loaded = 0 | |
| for name, loaded_weight in hf_model_weights_iterator( | |
| model_name_or_path, cache_dir, load_format, revision): | |
| if "rotary_emb.inv_freq" in name: | |
| continue | |
| is_packed = False | |
| is_transposed = False | |
| if self.quant_config is not None: | |
| is_packed = self.quant_config.is_packed(name) | |
| is_transposed = self.quant_config.is_transposed(name) | |
| if is_transposed: | |
| loaded_weight = convert_pyslice_to_tensor(loaded_weight) | |
| loaded_weight = loaded_weight.T | |
| is_attention_weight = False | |
| for weight_name, shard_size, offset in attention_weight_specs: | |
| if weight_name not in name or "qkv_proj" in name: | |
| continue | |
| param = state_dict[name.replace(weight_name, "qkv_proj")] | |
| if is_transposed: | |
| param = param.T | |
| if is_packed: | |
| shard_size //= self.quant_config.pack_factor | |
| offset //= self.quant_config.pack_factor | |
| if weight_name in ["k_proj", "v_proj"]: | |
| shard_id = tp_rank // num_kv_heads_replicas | |
| else: | |
| shard_id = tp_rank | |
| loaded_weight = loaded_weight[shard_size * | |
| shard_id:shard_size * | |
| (shard_id + 1)] | |
| param_slice = param.data[offset:offset + shard_size] | |
| assert param_slice.shape == loaded_weight.shape | |
| param_slice.copy_(loaded_weight) | |
| loaded += 1.0 / 3 | |
| is_attention_weight = True | |
| break | |
| if is_attention_weight: | |
| continue | |
| # TODO: need to figure out how to do sharding with fused qkv_proj | |
| is_gate_up_weight = False | |
| for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): | |
| if weight_name not in name or "gate_up_proj" in name: | |
| continue | |
| param = state_dict[name.replace(weight_name, "gate_up_proj")] | |
| if is_transposed: | |
| param = param.T | |
| shard_size = param.shape[0] // 2 | |
| loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * | |
| (tp_rank + 1)] | |
| param_slice = param.data[shard_size * stride_id:shard_size * | |
| (stride_id + 1)] | |
| assert param_slice.shape == loaded_weight.shape | |
| param_slice.copy_(loaded_weight) | |
| loaded += 1.0 / 2 | |
| is_gate_up_weight = True | |
| break | |
| if is_gate_up_weight: | |
| continue | |
| # TODO: need to figure out how to do sharding with fused gate_up_proj | |
| param = state_dict[name] | |
| if is_transposed: | |
| param = param.T | |
| if "embed_tokens" in name or "lm_head" in name: | |
| load_padded_tensor_parallel_vocab(param, loaded_weight, | |
| tp_rank) | |
| loaded += 1 | |
| continue | |
| load_tensor_parallel_weights(param, loaded_weight, name, | |
| column_parallel_weights, | |
| row_parallel_weights, tp_rank) | |
| loaded += 1 | |
| if np.abs(loaded - need_to_load) > 0.01: | |
| print(f'WARNING: only {loaded} params loaded out of {need_to_load}') | |
| else: | |
| print(f'Loaded all {loaded} params out of {need_to_load}') | |
| # Reassign LlamaForCausalLM.load_weights with llama_load_weights | |
| if not DEBUG: | |
| try: | |
| import vllm | |
| from vllm.model_executor.model_loader import _MODEL_REGISTRY | |
| from vllm.model_executor.models import LlamaForCausalLM | |
| _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM | |
| if vllm.__version__ == "0.1.4": | |
| LlamaForCausalLM.load_weights = llama_load_weights | |
| else: | |
| LlamaForCausalLM.load_weights = new_llama_load_weights | |
| if DTYPE == "bfloat16": | |
| try: | |
| compute_capability = torch.cuda.get_device_capability() | |
| if compute_capability[0] < 8: | |
| gpu_name = torch.cuda.get_device_name() | |
| print( | |
| "Bfloat16 is only supported on GPUs with compute capability " | |
| f"of at least 8.0. Your {gpu_name} GPU has compute capability " | |
| f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16") | |
| DTYPE = "float16" | |
| except Exception as e: | |
| print(f'Unable to obtain compute_capability: {e}') | |
| except Exception as e: | |
| print(f'Failed to import and reconfigure VLLM: {str(e)}') | |
| # ! ================================================================== | |
| set_documentation_group("component") | |
| RES_PRINTED = False | |
| def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN): | |
| return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {text} {E_INST}" | |
| def llama_chat_multiturn_sys_input_seq_constructor( | |
| message: str, | |
| history: List[Tuple[str, str]], | |
| sys_prompt=SYSTEM_PROMPT_1, | |
| bos_token=BOS_TOKEN, | |
| eos_token=EOS_TOKEN, | |
| ): | |
| """ | |
| ``` | |
| <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> | |
| <bos>[INST] Prompt [/INST] Answer <eos> | |
| <bos>[INST] Prompt [/INST] | |
| ``` | |
| """ | |
| text = '' | |
| for i, (prompt, res) in enumerate(history): | |
| if i == 0: | |
| text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt} {E_INST}" | |
| else: | |
| text += f"{bos_token}{B_INST} {prompt} {E_INST}" | |
| if res is not None: | |
| text += f" {res} {eos_token} " | |
| if len(history) == 0 or text.strip() == '': | |
| text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message} {E_INST}" | |
| else: | |
| text += f"{bos_token}{B_INST} {message} {E_INST}" | |
| return text | |
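| # Hedged example (added for illustration; the helper name below is made up): | |
| # what the constructor above produces for one finished turn plus a new message. | |
| # Roughly: | |
| # <s>[INST] <<SYS>>\n {sys_prompt} \n<</SYS>>\n\n Hello [/INST] Hi! How can I help? </s> | |
| # <s>[INST] What is the capital of Vietnam? [/INST] | |
| def _example_multiturn_prompt() -> str: | |
| return llama_chat_multiturn_sys_input_seq_constructor( | |
| "What is the capital of Vietnam?", | |
| history=[("Hello", "Hi! How can I help?")], | |
| ) | |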
| class ChatBot(gr.Chatbot): | |
| def _postprocess_chat_messages( | |
| self, chat_message | |
| ): | |
| x = super()._postprocess_chat_messages(chat_message) | |
| if isinstance(x, str): | |
| x = x.strip().replace("\n", "<br>") | |
| return x | |
| from gradio.components import Button | |
| from gradio.events import Dependency, EventListenerMethod | |
| # replace events so that the submit button is disabled during generation when stop_btn is not found; | |
| # this prevents weird behavior | |
| def _setup_stop_events( | |
| self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency | |
| ) -> None: | |
| event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers] | |
| if self.stop_btn and self.is_generator: | |
| if self.submit_btn: | |
| for event_trigger in event_triggers: | |
| event_trigger( | |
| lambda: ( | |
| Button.update(visible=False), | |
| Button.update(visible=True), | |
| ), | |
| None, | |
| [self.submit_btn, self.stop_btn], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| event_to_cancel.then( | |
| lambda: (Button.update(visible=True), Button.update(visible=False)), | |
| None, | |
| [self.submit_btn, self.stop_btn], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| else: | |
| for event_trigger in event_triggers: | |
| event_trigger( | |
| lambda: Button.update(visible=True), | |
| None, | |
| [self.stop_btn], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| event_to_cancel.then( | |
| lambda: Button.update(visible=False), | |
| None, | |
| [self.stop_btn], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| self.stop_btn.click( | |
| None, | |
| None, | |
| None, | |
| cancels=event_to_cancel, | |
| api_name=False, | |
| ) | |
| else: | |
| if self.submit_btn: | |
| for event_trigger in event_triggers: | |
| event_trigger( | |
| lambda: Button.update(interactive=False), | |
| None, | |
| [self.submit_btn], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| event_to_cancel.then( | |
| lambda: Button.update(interactive=True), | |
| None, | |
| [self.submit_btn], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| # upon clear, cancel the submit event as well | |
| if self.clear_btn: | |
| self.clear_btn.click( | |
| lambda: ([], [], None, Button.update(interactive=True)), | |
| None, | |
| [self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn], | |
| queue=False, | |
| api_name=False, | |
| cancels=event_to_cancel, | |
| ) | |
| # TODO: reconfigure clear button as stop and clear button | |
| def _setup_events(self) -> None: | |
| has_on = False | |
| try: | |
| from gradio.events import Dependency, EventListenerMethod, on | |
| has_on = True | |
| except ImportError as ie: | |
| has_on = False | |
| submit_fn = self._stream_fn if self.is_generator else self._submit_fn | |
| if has_on: | |
| # new version | |
| submit_triggers = ( | |
| [self.textbox.submit, self.submit_btn.click] | |
| if self.submit_btn | |
| else [self.textbox.submit] | |
| ) | |
| submit_event = ( | |
| on( | |
| submit_triggers, | |
| self._clear_and_save_textbox, | |
| [self.textbox], | |
| [self.textbox, self.saved_input], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| .then( | |
| self._display_input, | |
| [self.saved_input, self.chatbot_state], | |
| [self.chatbot, self.chatbot_state], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| .then( | |
| submit_fn, | |
| [self.saved_input, self.chatbot_state] + self.additional_inputs, | |
| [self.chatbot, self.chatbot_state], | |
| api_name=False, | |
| ) | |
| ) | |
| self._setup_stop_events(submit_triggers, submit_event) | |
| else: | |
| raise ValueError('Please install a gradio version newer than 3.44.0') | |
| if self.retry_btn: | |
| retry_event = ( | |
| self.retry_btn.click( | |
| self._delete_prev_fn, | |
| [self.chatbot_state], | |
| [self.chatbot, self.saved_input, self.chatbot_state], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| .then( | |
| self._display_input, | |
| [self.saved_input, self.chatbot_state], | |
| [self.chatbot, self.chatbot_state], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| .then( | |
| submit_fn, | |
| [self.saved_input, self.chatbot_state] + self.additional_inputs, | |
| [self.chatbot, self.chatbot_state], | |
| api_name=False, | |
| ) | |
| ) | |
| self._setup_stop_events([self.retry_btn.click], retry_event) | |
| if self.undo_btn: | |
| self.undo_btn.click( | |
| self._delete_prev_fn, | |
| [self.chatbot_state], | |
| [self.chatbot, self.saved_input, self.chatbot_state], | |
| api_name=False, | |
| queue=False, | |
| ).then( | |
| lambda x: x, | |
| [self.saved_input], | |
| [self.textbox], | |
| api_name=False, | |
| queue=False, | |
| ) | |
| # Reconfigure clear_btn to stop and clear text box | |
| # if self.clear_btn: | |
| # self.clear_btn.click( | |
| # lambda: ([], [], None), | |
| # None, | |
| # [self.chatbot, self.chatbot_state, self.saved_input], | |
| # queue=False, | |
| # api_name=False, | |
| # cancels=submit_event, | |
| # ) | |
| # replace | |
| gr.ChatInterface._setup_stop_events = _setup_stop_events | |
| gr.ChatInterface._setup_events = _setup_events | |
| def vllm_abort(self: Any): | |
| from vllm.sequence import SequenceStatus | |
| scheduler = self.llm_engine.scheduler | |
| for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]: | |
| for seq_group in state_queue: | |
| # if seq_group.request_id == request_id: | |
| # Remove the sequence group from the state queue. | |
| state_queue.remove(seq_group) | |
| for seq in seq_group.seqs: | |
| if seq.is_finished(): | |
| continue | |
| scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED) | |
| def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]: | |
| from vllm.outputs import RequestOutput | |
| # Initialize tqdm. | |
| if use_tqdm: | |
| num_requests = self.llm_engine.get_num_unfinished_requests() | |
| pbar = tqdm(total=num_requests, desc="Processed prompts") | |
| # Run the engine. | |
| outputs: Dict[str, RequestOutput] = {} | |
| while self.llm_engine.has_unfinished_requests(): | |
| step_outputs = self.llm_engine.step() | |
| for output in step_outputs: | |
| outputs[output.request_id] = output | |
| if len(outputs) > 0: | |
| yield outputs | |
| def vllm_generate_stream( | |
| self: Any, | |
| prompts: Optional[Union[str, List[str]]] = None, | |
| sampling_params: Optional[Any] = None, | |
| prompt_token_ids: Optional[List[List[int]]] = None, | |
| use_tqdm: bool = False, | |
| ) -> Dict[str, Any]: | |
| """Generates the completions for the input prompts. | |
| NOTE: This class automatically batches the given prompts, considering | |
| the memory constraint. For the best performance, put all of your prompts | |
| into a single list and pass it to this method. | |
| Args: | |
| prompts: A list of prompts to generate completions for. | |
| sampling_params: The sampling parameters for text generation. If | |
| None, we use the default sampling parameters. | |
| prompt_token_ids: A list of token IDs for the prompts. If None, we | |
| use the tokenizer to convert the prompts to token IDs. | |
| use_tqdm: Whether to use tqdm to display the progress bar. | |
| Returns: | |
| A list of `RequestOutput` objects containing the generated | |
| completions in the same order as the input prompts. | |
| """ | |
| from vllm import LLM, SamplingParams | |
| if prompts is None and prompt_token_ids is None: | |
| raise ValueError("Either prompts or prompt_token_ids must be " | |
| "provided.") | |
| if isinstance(prompts, str): | |
| # Convert a single prompt to a list. | |
| prompts = [prompts] | |
| if prompts is not None and prompt_token_ids is not None: | |
| if len(prompts) != len(prompt_token_ids): | |
| raise ValueError("The lengths of prompts and prompt_token_ids " | |
| "must be the same.") | |
| if sampling_params is None: | |
| # Use default sampling params. | |
| sampling_params = SamplingParams() | |
| # Add requests to the engine. | |
| if prompts is not None: | |
| num_requests = len(prompts) | |
| else: | |
| num_requests = len(prompt_token_ids) | |
| for i in range(num_requests): | |
| prompt = prompts[i] if prompts is not None else None | |
| if prompt_token_ids is None: | |
| token_ids = None | |
| else: | |
| token_ids = prompt_token_ids[i] | |
| self._add_request(prompt, sampling_params, token_ids) | |
| # return self._run_engine(use_tqdm) | |
| yield from _vllm_run_engine(self, use_tqdm) | |
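| # Hedged consumption sketch (added for illustration; the helper name is made up). | |
| # Each item yielded by vllm_generate_stream is a dict of request_id -> RequestOutput; | |
| # with a single prompt, the latest partial completion is its only value, which is | |
| # how chat_response_stream_multiturn below consumes it. | |
| def _example_stream_partial_texts(llm_wrapper: Any, prompt: str): | |
| from vllm import SamplingParams | |
| params = SamplingParams(temperature=0.1, max_tokens=64) | |
| for step_outputs in vllm_generate_stream(llm_wrapper, prompt, params): | |
| request_output = next(iter(step_outputs.values())) | |
| yield request_output.outputs[0].text | |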
| # ! canned messages returned when a request is blocked | |
| LANG_BLOCK_MESSAGE = """Sorry, the language you have asked in is currently not supported. If you have questions in other supported languages, I'll be glad to help. \ | |
| Please also consider clearing the chat box for a better experience.""" | |
| KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated question, I'll be glad to help." | |
| def block_zh( | |
| message: str, | |
| history: List[Tuple[str, str]] = None, | |
| ) -> str: | |
| # history-based block: keep blocking if a previous turn was already blocked | |
| if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history): | |
| return True | |
| elif 'zh' in _detect_lang(message): | |
| print(f'Detect zh: {message}') | |
| return True | |
| else: | |
| return False | |
| def log_responses(history, message, response): | |
| pass | |
| def safety_check(text, history=None, ) -> Optional[str]: | |
| """ | |
| Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content. | |
| This provides an additional security measure to enhance safety and compliance with local regulations. | |
| """ | |
| if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS): | |
| return KEYWORD_BLOCK_MESSAGE | |
| if BLOCK_ZH: | |
| if history is not None: | |
| if block_zh(text, history): | |
| return LANG_BLOCK_MESSAGE | |
| else: | |
| if "zh" in _detect_lang(text): | |
| return LANG_BLOCK_MESSAGE | |
| return None | |
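| # Hedged usage sketch (added for illustration; the helper name is made up): | |
| # safety_check returns a canned block message when the text must not be sent to | |
| # (or returned from) the model, and None when it is allowed through. | |
| def _example_safety_gate(user_message: str, history=None) -> Optional[str]: | |
| blocked = safety_check(user_message, history=history) | |
| # a non-None value is the message to show instead of generating a response | |
| return blocked | |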
| def chat_response_stream_multiturn( | |
| message: str, | |
| history: List[Tuple[str, str]], | |
| temperature: float, | |
| max_tokens: int, | |
| frequency_penalty: float, | |
| system_prompt: Optional[str] = SYSTEM_PROMPT_1 | |
| ) -> str: | |
| from vllm import LLM, SamplingParams | |
| """Build multi turn | |
| <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> | |
| <bos>[INST] Prompt [/INST] Answer <eos> | |
| <bos>[INST] Prompt [/INST] | |
| message is the incoming prompt | |
| history does not include the current message | |
| """ | |
| global llm, RES_PRINTED | |
| assert llm is not None | |
| assert system_prompt.strip() != '', f'system prompt is empty' | |
| # force-abort all in-flight requests before starting a new one | |
| vllm_abort(llm) | |
| temperature = float(temperature) | |
| frequency_penalty = float(frequency_penalty) | |
| max_tokens = int(max_tokens) | |
| message = message.strip() | |
| message_safety = safety_check(message, history=history) | |
| if message_safety is not None: | |
| yield message_safety | |
| return | |
| # the current message will be appended to history later on | |
| full_prompt = llama_chat_multiturn_sys_input_seq_constructor( | |
| message, history, sys_prompt=system_prompt | |
| ) | |
| sampling_params = SamplingParams( | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| frequency_penalty=frequency_penalty, | |
| stop=['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'] | |
| ) | |
| cur_out = None | |
| for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)): | |
| if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0: | |
| cur_out = cur_out.replace("\\n", "\n") | |
| # optionally check safety, and respond | |
| if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0: | |
| message_safety = safety_check(cur_out, history=None) | |
| if message_safety is not None: | |
| yield message_safety | |
| return | |
| yield cur_out | |
| assert len(gen) == 1, f'{gen}' | |
| item = next(iter(gen.values())) | |
| cur_out = item.outputs[0].text | |
| print(f'@@@@@@@@@@\n{full_prompt}<<<{cur_out}>>>\n@@@@@@@@@@\n') | |
| if cur_out is not None and "\\n" in cur_out: | |
| print(f'double slash-n in cur_out:\n{cur_out}') | |
| cur_out = cur_out.replace("\\n", "\n") | |
| if cur_out is not None: | |
| yield cur_out | |
| message_safety = safety_check(cur_out, history=None) | |
| if message_safety is not None: | |
| yield message_safety | |
| return | |
| if LOG_RESPONSE: | |
| log_responses(history, message, cur_out) | |
| def debug_chat_response_echo( | |
| message: str, | |
| history: List[Tuple[str, str]], | |
| temperature: float = 0.0, | |
| max_tokens: int = 4096, | |
| frequency_penalty: float = 0.4, | |
| system_prompt: str = SYSTEM_PROMPT_1, | |
| ) -> str: | |
| import time | |
| time.sleep(0.5) | |
| yield f"repeat: {message}" | |
| def check_model_path(model_path) -> str: | |
| assert os.path.exists(model_path), f'{model_path} not found' | |
| ckpt_info = "None" | |
| if os.path.isdir(model_path): | |
| if os.path.exists(f'{model_path}/info.txt'): | |
| with open(f'{model_path}/info.txt', 'r') as f: | |
| ckpt_info = f.read() | |
| print(f'Checkpoint info:\n{ckpt_info}\n-----') | |
| else: | |
| print(f'info.txt not found in {model_path}') | |
| print(f'model path dir: {list(os.listdir(model_path))}') | |
| return ckpt_info | |
| def maybe_delete_folder(): | |
| if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT: | |
| print(f'DELETE ALL FILES IN {DELETE_FOLDER}') | |
| for filename in os.listdir(DELETE_FOLDER): | |
| file_path = os.path.join(DELETE_FOLDER, filename) | |
| try: | |
| if os.path.isfile(file_path) or os.path.islink(file_path): | |
| os.unlink(file_path) | |
| elif os.path.isdir(file_path): | |
| shutil.rmtree(file_path) | |
| except Exception as e: | |
| print('Failed to delete %s. Reason: %s' % (file_path, e)) | |
| def launch(): | |
| global demo, llm, DEBUG | |
| model_desc = MODEL_DESC | |
| model_path = MODEL_PATH | |
| model_title = MODEL_TITLE | |
| hf_model_name = HF_MODEL_NAME | |
| tensor_parallel = TENSOR_PARALLEL | |
| assert tensor_parallel > 0, f'{tensor_parallel} invalid' | |
| dtype = DTYPE | |
| sys_prompt = SYSTEM_PROMPT_1 | |
| max_tokens = MAX_TOKENS | |
| temperature = TEMPERATURE | |
| frequence_penalty = FREQUENCE_PENALTY | |
| ckpt_info = "None" | |
| print( | |
| f'Launch config: {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} ' | |
| f'\n| model_title=`{model_title}` ' | |
| f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} ' | |
| f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} ' | |
| f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} ' | |
| f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} ' | |
| f'\n| frequence_penalty={frequence_penalty} ' | |
| f'\n| temperature={temperature} ' | |
| f'\n| hf_model_name={hf_model_name} ' | |
| f'\n| model_path={model_path} ' | |
| f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} ' | |
| f'\n| gpu_memory_utilization={gpu_memory_utilization} ' | |
| f'\n| KEYWORDS={KEYWORDS} ' | |
| f'\n| Sys={SYSTEM_PROMPT_1}' | |
| f'\n| Desc={model_desc}' | |
| ) | |
| if DEBUG: | |
| model_desc += "\n<br>!!!!! This is debug mode; responses will simply echo your input" | |
| response_fn = debug_chat_response_echo | |
| print(f'Creating in DEBUG MODE') | |
| else: | |
| # ! load the model | |
| if DOWNLOAD_SNAPSHOT: | |
| print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}') | |
| if HF_TOKEN is not None: | |
| print('Load with HF_TOKEN set (value hidden)') | |
| snapshot_download(hf_model_name, local_dir=model_path, use_auth_token=True, token=HF_TOKEN) | |
| else: | |
| snapshot_download(hf_model_name, local_dir=model_path) | |
| import vllm | |
| from vllm import LLM | |
| print(F'VLLM: {vllm.__version__}') | |
| ckpt_info = check_model_path(model_path) | |
| print(f'Load path: {model_path} | {ckpt_info}') | |
| if QUANTIZATION == 'awq': | |
| print(F'Load model in int4 quantization') | |
| llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, quantization="awq") | |
| else: | |
| llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization) | |
| try: | |
| print(llm.llm_engine.workers[0].model) | |
| except Exception as e: | |
| print(f'Cannot print model worker: {e}') | |
| try: | |
| llm.llm_engine.scheduler_config.max_model_len = 4096 | |
| llm.llm_engine.scheduler_config.max_num_batched_tokens = 4096 | |
| llm.llm_engine.tokenizer.add_special_tokens = False | |
| except Exception as e: | |
| print(f'Cannot set parameters: {e}') | |
| print(f'Use system prompt:\n{sys_prompt}') | |
| response_fn = chat_response_stream_multiturn | |
| print(F'respond: {response_fn}') | |
| demo = gr.ChatInterface( | |
| response_fn, | |
| chatbot=ChatBot( | |
| label=MODEL_NAME, | |
| bubble_full_width=False, | |
| latex_delimiters=[ | |
| { "left": "$", "right": "$", "display": False}, | |
| { "left": "$$", "right": "$$", "display": True}, | |
| ] | |
| ), | |
| textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200), | |
| submit_btn=gr.Button(value='Submit', variant="primary", scale=0), | |
| # ! consider preventing the stop button | |
| stop_btn=None, | |
| title=f"{model_title}", | |
| description=f"{model_desc}", | |
| additional_inputs=[ | |
| gr.Number(value=temperature, label='Temperature (higher -> more random)'), | |
| gr.Number(value=max_tokens, label='Max generated tokens (increase for longer responses)'), | |
| gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourages new tokens)'), | |
| # ! Remove the system prompt textbox to avoid jailbreaking | |
| # gr.Textbox(value=sys_prompt, label='System prompt', lines=8) | |
| ], | |
| ) | |
| demo.title = MODEL_NAME | |
| with demo: | |
| # gr.Markdown(warning_markdown) | |
| gr.Markdown(cite_markdown) | |
| if DISPLAY_MODEL_PATH: | |
| gr.Markdown(path_markdown.format(model_path=model_path)) | |
| demo.queue() | |
| demo.launch(server_port=PORT) | |
| def main(): | |
| launch() | |
| if __name__ == "__main__": | |
| main() | |