| # Copyright (c) 2023 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| import bisect | |
| import re | |
| from collections import defaultdict | |
| from typing import List | |
| import numpy as np | |
| import torch | |
| import nltk | |
| import tiktoken | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForCausalLM, | |
| AutoModelForTokenClassification, | |
| AutoTokenizer, | |
| ) | |
| import torch.nn.functional as F | |
| import string | |
| import copy | |
| from torch.utils.data import DataLoader | |
| from .utils import TokenClfDataset, seed_everything, is_begin_of_new_word, replace_added_token, get_pure_token | |
| class PromptCompressor: | |
| """ | |
| PromptCompressor is designed for compressing prompts based on a given language model. | |
| This class initializes with the language model and its configuration, preparing it for prompt compression tasks. | |
| The PromptCompressor class is versatile and can be adapted for various models and specific requirements in prompt processing. | |
| Users can specify different model names and configurations as needed for their particular use case. The architecture is | |
| based on the paper "LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models". Jiang, Huiqiang, Qianhui Wu, | |
| Chin-Yew Lin, Yuqing Yang, and Lili Qiu. "Llmlingua: Compressing prompts for accelerated inference of large language models." | |
| arXiv preprint arXiv:2310.05736 (2023). | |
| Args: | |
| model_name (str, optional): The name of the language model to be loaded. Default is "NousResearch/Llama-2-7b-hf". | |
| device_map (str, optional): The device to load the model onto, e.g., "cuda" for GPU. Default is "cuda". | |
| model_config (dict, optional): A dictionary containing the configuration parameters for the model. Default is an empty dictionary. | |
| open_api_config (dict, optional): A dictionary containing configuration for OpenAI APIs that may be used in conjunction with the model. Default is an empty dictionary. | |
| use_llmlingua2 (bool, optional): Whether to use the LLMLingua-2 compressor, based on the paper | |
| "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression". | |
| Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang. | |
| "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression". arXiv preprint arXiv:, | |
| Default is True. | |
| llmlingua2_config (dict, optional): A dictionary containing the configuration parameters for llmlingua-2. Default is | |
| { | |
| "max_batch_size": 50, | |
| "max_force_token": 100, # max number of the tokens which will be forcely preserved | |
| } | |
| Example: | |
| >>> compress_method = PromptCompressor(model_name="xxx/llmlingua-2-xlm-roberta-large-meetingbank", use_llmlingua2=True, ) | |
| >>> context = ["This is the first context sentence.", "Here is another context sentence."] | |
| >>> result = compress_method.compress_prompt(context, use_context_level_filter=True, target_token=5) | |
| >>> print(result["compressed_prompt"]) | |
| # This will print the compressed version of the context. | |
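| # A rough, illustrative sketch of the original perplexity-based path (not from the official docs); | |
| # it reuses the `context` above together with the class's default causal-LM checkpoint. | |
| >>> ppl_compressor = PromptCompressor(model_name="NousResearch/Llama-2-7b-hf", use_llmlingua2=False) | |
| >>> result = ppl_compressor.compress_prompt(context, instruction="", question="", target_token=20) | |
| >>> print(result["rate"]) | |
| # Prints the achieved compression rate as a percentage string. | |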
| Note: | |
| The `PromptCompressor` class requires the Hugging Face Transformers library and an appropriate environment to load and run the models. | |
| """ | |
| def __init__( | |
| self, | |
| model_name: str = "NousResearch/Llama-2-7b-hf", | |
| device_map: str = "cuda", | |
| model_config: dict = {}, | |
| open_api_config: dict = {}, | |
| use_llmlingua2: bool = True, | |
| llmlingua2_config: dict = {}, | |
| ): | |
| self.model_name = model_name | |
| self.use_llmlingua2 = use_llmlingua2 | |
| self.retrieval_model = None | |
| self.retrieval_model_name = None | |
| self.open_api_config = open_api_config | |
| self.cache_bos_num = 10 | |
| self.prefix_bos_num = 100 | |
| self.oai_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") | |
| self.load_model(model_name, device_map, model_config) | |
| if use_llmlingua2: | |
| self.init_llmlingua2(**llmlingua2_config) | |
| def init_llmlingua2( | |
| self, | |
| max_batch_size: int = 50, | |
| max_force_token: int = 100, | |
| ): | |
| seed_everything(42) | |
| self.max_batch_size = max_batch_size | |
| self.max_seq_len = 512 | |
| self.max_force_token = max_force_token | |
| self.special_tokens = set(self.tokenizer.special_tokens_map.values()) | |
| self.added_tokens = [f"[NEW{i}]" for i in range(max_force_token)] | |
| self.tokenizer.add_special_tokens( | |
| {"additional_special_tokens": self.added_tokens} | |
| ) | |
| self.model.resize_token_embeddings(len(self.tokenizer)) | |
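| # The [NEW{i}] placeholders act as single-token stand-ins for entries of `force_tokens` that the | |
| # tokenizer would otherwise split into several pieces (see the token_map built in | |
| # compress_prompt_llmlingua2); resizing the embeddings makes room for them. | |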
| def load_model( | |
| self, model_name: str, device_map: str = "cuda", model_config: dict = {} | |
| ): | |
| trust_remote_code = model_config.get("trust_remote_code", True) | |
| if "trust_remote_code" not in model_config: | |
| model_config["trust_remote_code"] = trust_remote_code | |
| config = AutoConfig.from_pretrained(model_name, **model_config) | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, **model_config) | |
| if model_config.get("pad_to_left", True): | |
| tokenizer.padding_side = "left" | |
| tokenizer.pad_token_id = ( | |
| config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id | |
| ) | |
| MODEL_CLASS = ( | |
| AutoModelForTokenClassification | |
| if any("ForTokenClassification" in ar for ar in config.architectures) | |
| else AutoModelForCausalLM | |
| ) | |
| self.device = ( | |
| device_map | |
| if any(key in device_map for key in ["cuda", "cpu", "mps"]) | |
| else "cuda" | |
| ) | |
| if "cuda" in device_map or "cpu" in device_map: | |
| model = MODEL_CLASS.from_pretrained( | |
| model_name, | |
| torch_dtype=model_config.get( | |
| "torch_dtype", "auto" if device_map == "cuda" else torch.float32 | |
| ), | |
| device_map=device_map, | |
| config=config, | |
| ignore_mismatched_sizes=True, | |
| **model_config, | |
| ) | |
| else: | |
| model = MODEL_CLASS.from_pretrained( | |
| model_name, | |
| device_map=device_map, | |
| torch_dtype=model_config.get("torch_dtype", "auto"), | |
| pad_token_id=tokenizer.pad_token_id, | |
| **model_config, | |
| ) | |
| self.tokenizer = tokenizer | |
| self.model = model | |
| self.context_idxs = [] | |
| self.max_position_embeddings = config.max_position_embeddings | |
| def get_ppl( | |
| self, | |
| text: str, | |
| granularity: str = "sentence", | |
| input_ids=None, | |
| attention_mask=None, | |
| past_key_values=None, | |
| return_kv=False, | |
| end=None, | |
| condition_mode: str = "none", | |
| condition_pos_id: int = 0, | |
| ): | |
| if input_ids is None: | |
| tokenized_text = self.tokenizer(text, return_tensors="pt") | |
| input_ids = tokenized_text["input_ids"].to(self.device) | |
| attention_mask = tokenized_text["attention_mask"].to(self.device) | |
| if past_key_values is not None: | |
| past_length = past_key_values[0][0].shape[2] | |
| else: | |
| past_length = 0 | |
| if end is None: | |
| end = input_ids.shape[1] | |
| end = min(end, past_length + self.max_position_embeddings) | |
| with torch.no_grad(): | |
| response = self.model( | |
| input_ids[:, past_length:end], | |
| attention_mask=attention_mask[:, :end], | |
| past_key_values=past_key_values, | |
| use_cache=True, | |
| ) | |
| past_key_values = response.past_key_values | |
| shift_logits = response.logits[..., :-1, :].contiguous() | |
| shift_labels = input_ids[..., past_length + 1 : end].contiguous() | |
| # Flatten the tokens | |
| active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1) | |
| active_logits = shift_logits.view(-1, shift_logits.size(-1))[active] | |
| active_labels = shift_labels.view(-1)[active] | |
| loss_fct = torch.nn.CrossEntropyLoss(reduction="none") | |
| loss = loss_fct(active_logits, active_labels) | |
| if condition_mode == "before": | |
| loss = loss[:condition_pos_id] | |
| elif condition_mode == "after": | |
| loss = loss[condition_pos_id:] | |
| res = loss.mean() if granularity == "sentence" else loss | |
| return (res, past_key_values) if return_kv else res | |
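| # Usage sketch (illustrative; `compressor` stands for any PromptCompressor built with a causal LM): | |
| # with granularity="sentence", get_ppl returns the mean token-level cross-entropy, so the | |
| # conventional perplexity is simply its exponential: | |
| #   nll = compressor.get_ppl("The quick brown fox jumps over the lazy dog.") | |
| #   ppl = torch.exp(nll) | |
| # Any other granularity returns the per-token losses used by the token-level filter. | |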
| def __call__(self, *args, **kwargs): | |
| return self.compress_prompt(*args, **kwargs) | |
| def structured_compress_prompt( | |
| self, | |
| context: List[str], | |
| instruction: str = "", | |
| question: str = "", | |
| rate: float = 0.5, | |
| target_token: float = -1, | |
| iterative_size: int = 200, | |
| force_context_ids: List[int] = None, | |
| force_context_number: int = None, | |
| use_sentence_level_filter: bool = False, | |
| use_context_level_filter: bool = True, | |
| use_token_level_filter: bool = True, | |
| keep_split: bool = False, | |
| keep_first_sentence: int = 0, | |
| keep_last_sentence: int = 0, | |
| keep_sentence_number: int = 0, | |
| high_priority_bonus: int = 100, | |
| context_budget: str = "+100", | |
| token_budget_ratio: float = 1.4, | |
| condition_in_question: str = "none", | |
| reorder_context: str = "original", | |
| dynamic_context_compression_ratio: float = 0.0, | |
| condition_compare: bool = False, | |
| add_instruction: bool = False, | |
| rank_method: str = "llmlingua", | |
| concate_question: bool = True, | |
| ): | |
| """ | |
| Compresses the given prompt context based on a specified structure. | |
| Each element of context should be segmented using one or more non-nested '<llmlingua></llmlingua>' tags. | |
| Each '<llmlingua>' tag can include optional parameters 'rate' and 'compress' (e.g., '<llmlingua, rate=0.3, compress=True>'), | |
| indicating the compression rate for that segment. Default values are 'rate=rate' and 'compress=True'. | |
| When 'compress' is set to False, it overrides the 'rate' parameter, resulting in no compression for that segment. | |
| Args: | |
| context (List[str]): List of context strings divided by '<llmlingua></llmlingua>' tags with optional compression settings. | |
| instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string. | |
| question (str, optional): A specific question that the prompt is addressing. Default is an empty string. | |
| rate (float, optional): The compression rate is defined the same as in the paper "Language Modeling Is Compression". | |
| Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne, Elliot Catt, Tim Genewein, Christopher Mattern, | |
| Jordi Grau-Moya et al. "Language modeling is compression." arXiv preprint arXiv:2309.10668 (2023): | |
| .. math:: \text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}} | |
| Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be | |
| fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal | |
| to 1.0, representing the target compression rate. ``rate`` is applicable only within the context-level filter | |
| and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global rate. | |
| However, for segments where no specific rate is defined, the global rate serves as the default value. The final | |
| compression rate of the entire text is a composite result of multiple compression rates applied across different sections. | |
| target_token (float, optional): The global maximum number of tokens to be achieved. Default is -1, indicating no | |
| specific target. The actual number of tokens after compression should generally be less than the specified target_token, | |
| but there can be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as | |
| the sole criterion, overriding the ``rate``. ``target_token`` is applicable only within the context-level | |
| filter and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global target token. | |
| However, for segments where no specific rate is defined, the global rate calculated from the global target token serves | |
| as the default value. The final target token of the entire text is a composite result of multiple compression rates | |
| applied across different sections. | |
| iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200. | |
| force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None. | |
| force_context_number (int, optional): The number of context sections to forcibly include. Default is None. | |
| use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False. | |
| use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True. | |
| use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True. | |
| keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False. | |
| keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0. | |
| keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0. | |
| keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0. | |
| high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100. | |
| context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100". | |
| token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4. | |
| condition_in_question (str, optional): Specific condition to apply to the question in the context. Default is "none". | |
| reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original". | |
| dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0. | |
| condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False. | |
| add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False. | |
| rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua". | |
| concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True. | |
| Returns: | |
| dict: A dictionary containing: | |
| - "compressed_prompt" (str): The resulting compressed prompt. | |
| - "origin_tokens" (int): The original number of tokens in the input. | |
| - "compressed_tokens" (int): The number of tokens in the compressed output. | |
| - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression. | |
| - "rate" (str): The compression rate achieved, in a human-readable format. | |
| - "saving" (str): Estimated savings in GPT-4 token usage. | |
| """ | |
| if not context: | |
| context = [" "] | |
| if isinstance(context, str): | |
| context = [context] | |
| context = [ | |
| self.tokenizer.decode(self.tokenizer(c, add_special_tokens=False).input_ids) | |
| for c in context | |
| ] | |
| context_tokens_length = [self.get_token_length(c) for c in context] | |
| instruction_tokens_length, question_tokens_length = self.get_token_length( | |
| instruction | |
| ), self.get_token_length(question) | |
| if target_token == -1: | |
| target_token = ( | |
| ( | |
| instruction_tokens_length | |
| + question_tokens_length | |
| + sum(context_tokens_length) | |
| ) | |
| * rate | |
| - instruction_tokens_length | |
| - (question_tokens_length if concate_question else 0) | |
| ) | |
| else: | |
| rate = target_token / sum(context_tokens_length) | |
| ( | |
| context, | |
| context_segs, | |
| context_segs_rate, | |
| context_segs_compress, | |
| ) = self.segment_structured_context(context, rate) | |
| return self.compress_prompt( | |
| context, | |
| instruction, | |
| question, | |
| rate, | |
| target_token, | |
| iterative_size, | |
| force_context_ids, | |
| force_context_number, | |
| use_sentence_level_filter, | |
| use_context_level_filter, | |
| use_token_level_filter, | |
| keep_split, | |
| keep_first_sentence, | |
| keep_last_sentence, | |
| keep_sentence_number, | |
| high_priority_bonus, | |
| context_budget, | |
| token_budget_ratio, | |
| condition_in_question, | |
| reorder_context, | |
| dynamic_context_compression_ratio, | |
| condition_compare, | |
| add_instruction, | |
| rank_method, | |
| concate_question, | |
| context_segs=context_segs, | |
| context_segs_rate=context_segs_rate, | |
| context_segs_compress=context_segs_compress, | |
| ) | |
| def compress_prompt( | |
| self, | |
| context: List[str], | |
| instruction: str = "", | |
| question: str = "", | |
| rate: float = 0.5, | |
| target_token: float = -1, | |
| iterative_size: int = 200, | |
| force_context_ids: List[int] = None, | |
| force_context_number: int = None, | |
| use_sentence_level_filter: bool = False, | |
| use_context_level_filter: bool = True, | |
| use_token_level_filter: bool = True, | |
| keep_split: bool = False, | |
| keep_first_sentence: int = 0, | |
| keep_last_sentence: int = 0, | |
| keep_sentence_number: int = 0, | |
| high_priority_bonus: int = 100, | |
| context_budget: str = "+100", | |
| token_budget_ratio: float = 1.4, | |
| condition_in_question: str = "none", | |
| reorder_context: str = "original", | |
| dynamic_context_compression_ratio: float = 0.0, | |
| condition_compare: bool = False, | |
| add_instruction: bool = False, | |
| rank_method: str = "llmlingua", | |
| concate_question: bool = True, | |
| context_segs: List[str] = None, | |
| context_segs_rate: List[float] = None, | |
| context_segs_compress: List[bool] = None, | |
| target_context: int = -1, | |
| context_level_rate: float = 1.0, | |
| context_level_target_token: int = -1, | |
| return_word_label: bool = False, | |
| word_sep: str = "\t\t|\t\t", | |
| label_sep: str = " ", | |
| token_to_word: str = "mean", | |
| force_tokens: List[str] = [], | |
| force_reserve_digit: bool = False, | |
| drop_consecutive: bool = False, | |
| chunk_end_tokens: List[str] = [".", "\n"], | |
| ): | |
| """ | |
| Compresses the given context. | |
| Args: | |
| context (List[str]): List of context strings that form the basis of the prompt. | |
| instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string. | |
| question (str, optional): A specific question that the prompt is addressing. Default is an empty string. | |
| rate (float, optional): The maximum compression rate target to be achieved. The compression rate is defined | |
| the same as in the paper "Language Modeling Is Compression". Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne, | |
| Elliot Catt, Tim Genewein, Christopher Mattern, Jordi Grau-Moya et al. "Language modeling is compression." | |
| arXiv preprint arXiv:2309.10668 (2023): | |
| .. math:: \text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}} | |
| Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be | |
| fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal | |
| to 1.0, representing the target compression rate. | |
| target_token (float, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target. | |
| The actual number of tokens after compression should generally be less than the specified target_token, but there can | |
| be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as | |
| the sole criterion, overriding the ``rate``. | |
| iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200. | |
| force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None. | |
| force_context_number (int, optional): The number of context sections to forcibly include. Default is None. | |
| use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False. | |
| use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True. | |
| use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True. | |
| keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False. | |
| keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0. | |
| keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0. | |
| keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0. | |
| high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100. | |
| context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100". | |
| token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4. | |
| condition_in_question (str, optional): Specific condition to apply to the question in the context. Default is "none". | |
| reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original". | |
| dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0. | |
| condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False. | |
| add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False. | |
| rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua". | |
| concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True. | |
| target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target. | |
| context_level_rate (float, optional): The minimum compression rate target to be achieved at the context level. Default is 1.0. | |
| context_level_target_token (float, optional): The maximum number of tokens to be achieved in context-level compression. | |
| Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario. | |
| return_word_label (bool, optional): Whether to return each word with its corresponding label. Default is False. | |
| word_sep (str, optional): The separator token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t". | |
| label_sep (str, optional): The separator token used in fn_labeled_original_prompt to partition word and label. Default is " ". | |
| token_to_word (str, optional): How to convert token probability to word probability. Default is "mean". | |
| force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is []. | |
| force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0-9). Default is False. | |
| drop_consecutive (bool, optional): Whether to drop tokens which are in 'force_tokens' but appear consecutively in the compressed prompt. | |
| Default is False. | |
| chunk_end_tokens (List[str], optional): The early-stop tokens for segmenting a chunk. Default is [".", "\n"]. | |
| Returns: | |
| dict: A dictionary containing: | |
| - "compressed_prompt" (str): The resulting compressed prompt. | |
| - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt. Only used in llmlingua2. | |
| - "fn_labeled_original_prompt" (str): original words along with their labels | |
| indicating whether to reserve in compressed prompt, in the format (word label_sep label) | |
| Only used in llmlingua2 when return_word_label = True. | |
| - "origin_tokens" (int): The original number of tokens in the input. | |
| - "compressed_tokens" (int): The number of tokens in the compressed output. | |
| - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression. | |
| - "rate" (str): The compression rate achieved, in a human-readable format. | |
| - "saving" (str): Estimated savings in GPT-4 token usage. | |
| """ | |
| if self.use_llmlingua2: | |
| return self.compress_prompt_llmlingua2( | |
| context, | |
| rate=rate, | |
| target_token=target_token, | |
| use_context_level_filter=use_context_level_filter, | |
| use_token_level_filter=use_token_level_filter, | |
| target_context=target_context, | |
| context_level_rate=context_level_rate, | |
| context_level_target_token=context_level_target_token, | |
| force_context_ids=force_context_ids, | |
| return_word_label=return_word_label, | |
| word_sep=word_sep, | |
| label_sep=label_sep, | |
| token_to_word=token_to_word, | |
| force_tokens=force_tokens, | |
| force_reserve_digit=force_reserve_digit, | |
| drop_consecutive=drop_consecutive, | |
| chunk_end_tokens=chunk_end_tokens, | |
| ) | |
| assert ( | |
| rate <= 1.0 | |
| ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]." | |
| if not context: | |
| context = [" "] | |
| if isinstance(context, str): | |
| context = [context] | |
| assert not ( | |
| rank_method == "longllmlingua" and not question | |
| ), "In the LongLLMLingua, it is necessary to set a question." | |
| if condition_compare and "_condition" not in condition_in_question: | |
| condition_in_question += "_condition" | |
| if rank_method == "longllmlingua": | |
| if condition_in_question == "none": | |
| condition_in_question = "after" | |
| elif rank_method == "llmlingua": | |
| condition_in_question = ( | |
| "none" | |
| if "_condition" not in condition_in_question | |
| else "none_condition" | |
| ) | |
| origin_tokens = len( | |
| self.oai_tokenizer.encode( | |
| "\n\n".join([instruction] + context + [question]).strip() | |
| ) | |
| ) | |
| context_tokens_length = [self.get_token_length(c) for c in context] | |
| instruction_tokens_length, question_tokens_length = self.get_token_length( | |
| instruction | |
| ), self.get_token_length(question) | |
| if target_token == -1: | |
| target_token = ( | |
| ( | |
| instruction_tokens_length | |
| + question_tokens_length | |
| + sum(context_tokens_length) | |
| ) | |
| * rate | |
| - instruction_tokens_length | |
| - (question_tokens_length if concate_question else 0) | |
| ) | |
| condition_flag = "_condition" in condition_in_question | |
| condition_in_question = condition_in_question.replace("_condition", "") | |
| if len(context) > 1 and use_context_level_filter: | |
| context, dynamic_ratio, context_used = self.control_context_budget( | |
| context, | |
| context_tokens_length, | |
| target_token, | |
| force_context_ids, | |
| force_context_number, | |
| question, | |
| condition_in_question, | |
| reorder_context=reorder_context, | |
| dynamic_context_compression_ratio=dynamic_context_compression_ratio, | |
| rank_method=rank_method, | |
| context_budget=context_budget, | |
| context_segs=context_segs, | |
| context_segs_rate=context_segs_rate, | |
| context_segs_compress=context_segs_compress, | |
| ) | |
| if context_segs is not None: | |
| context_segs = [context_segs[idx] for idx in context_used] | |
| context_segs_rate = [context_segs_rate[idx] for idx in context_used] | |
| context_segs_compress = [ | |
| context_segs_compress[idx] for idx in context_used | |
| ] | |
| else: | |
| dynamic_ratio = [0.0] * len(context) | |
| segments_info = [] | |
| if use_sentence_level_filter: | |
| context, segments_info = self.control_sentence_budget( | |
| context, | |
| target_token, | |
| keep_first_sentence=keep_first_sentence, | |
| keep_last_sentence=keep_last_sentence, | |
| keep_sentence_number=keep_sentence_number, | |
| high_priority_bonus=high_priority_bonus, | |
| token_budget_ratio=token_budget_ratio, | |
| question=question, | |
| condition_in_question=condition_in_question, | |
| rank_method=rank_method, | |
| context_segs=context_segs, | |
| context_segs_rate=context_segs_rate, | |
| context_segs_compress=context_segs_compress, | |
| ) | |
| elif context_segs is not None: | |
| for context_idx in range(len(context)): | |
| segments_info.append( | |
| [ | |
| (len(seg_text), seg_rate, seg_compress) | |
| for seg_text, seg_rate, seg_compress in zip( | |
| context_segs[context_idx], | |
| context_segs_rate[context_idx], | |
| context_segs_compress[context_idx], | |
| ) | |
| ] | |
| ) | |
| segments_info = [ | |
| self.concate_segment_info(segment_info) for segment_info in segments_info | |
| ] | |
| if condition_flag: | |
| prefix = question + "\n\n" + instruction if add_instruction else question | |
| if ( | |
| self.get_token_length(prefix + "\n\n") + iterative_size * 2 | |
| > self.max_position_embeddings | |
| ): | |
| tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids | |
| prefix = self.tokenizer.decode( | |
| tokens[: self.prefix_bos_num] | |
| + tokens[ | |
| len(tokens) | |
| - self.max_position_embeddings | |
| + 2 | |
| + self.prefix_bos_num | |
| + 2 * iterative_size : | |
| ] | |
| ) | |
| start = self.get_prefix_length(prefix + "\n\n", context[0]) | |
| context = [prefix] + context | |
| else: | |
| start = 0 | |
| if use_token_level_filter: | |
| context = self.iterative_compress_prompt( | |
| context, | |
| target_token, | |
| iterative_size=iterative_size, | |
| keep_split=keep_split, | |
| start=start, | |
| dynamic_ratio=dynamic_ratio, | |
| condition_compare=condition_compare, | |
| segments_info=segments_info, | |
| ) | |
| compressed_prompt = ( | |
| self.tokenizer.batch_decode(context[0])[0] | |
| .replace("<s> ", "") | |
| .replace("<s>", "") | |
| ) | |
| else: | |
| if condition_flag: | |
| context = context[1:] | |
| compressed_prompt = "\n\n".join(context) | |
| res = [] | |
| if instruction: | |
| res.append(instruction) | |
| if compressed_prompt.strip(): | |
| res.append(compressed_prompt) | |
| if question and concate_question: | |
| res.append(question) | |
| compressed_prompt = "\n\n".join(res) | |
| compressed_tokens = len(self.oai_tokenizer.encode(compressed_prompt)) | |
| saving = (origin_tokens - compressed_tokens) * 0.06 / 1000 | |
| ratio = 1 if compressed_tokens == 0 else origin_tokens / compressed_tokens | |
| rate = 1 / ratio | |
| return { | |
| "compressed_prompt": compressed_prompt, | |
| "origin_tokens": origin_tokens, | |
| "compressed_tokens": compressed_tokens, | |
| "ratio": f"{ratio:.1f}x", | |
| "rate": f"{rate * 100:.1f}%", | |
| "saving": f", Saving ${saving:.1f} in GPT-4.", | |
| } | |
| def compress_prompt_llmlingua2( | |
| self, | |
| context: List[str], | |
| rate: float = 0.5, | |
| target_token: int = -1, | |
| use_context_level_filter: bool = False, | |
| use_token_level_filter: bool = True, | |
| target_context: int = -1, | |
| context_level_rate: float = 1.0, | |
| context_level_target_token: int = -1, | |
| force_context_ids: List[int] = [], | |
| return_word_label: bool = False, | |
| word_sep: str = "\t\t|\t\t", | |
| label_sep: str = " ", | |
| token_to_word: str = "mean", | |
| force_tokens: List[str] = [], | |
| force_reserve_digit: bool = False, | |
| drop_consecutive: bool = False, | |
| chunk_end_tokens: List[str] = [".", "\n"], | |
| ): | |
| """ | |
| Compresses the given context. | |
| Args: | |
| context (List[str]): List of context strings that form the basis of the prompt. | |
| rate (float, optional): The minimum compression rate target to be achieved. Default is 0.5. The actual compression rate | |
| generally exceeds the specified target, but there can be fluctuations due to differences in tokenizers. If specified, | |
| it should be a float less than or equal to 1.0, representing the target compression rate. | |
| target_token (int, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target. | |
| The actual number of tokens after compression should generally be less than the specified target_token, but there can | |
| be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as | |
| the sole criterion, overriding the rate. | |
| target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target. | |
| Only used in the coarse-to-fine compression. | |
| context_level_rate (float, optional): The minimum compression rate target to be achieved at the context level. Default is 1.0. | |
| Only used in the coarse-to-fine compression. | |
| context_level_target_token (float, optional): The maximum number of tokens to be achieved in context-level compression. | |
| Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario. | |
| force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is []. | |
| return_word_label (bool, optional): Whether to return word with corresponding label. Default is False. | |
| word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t". | |
| label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ". | |
| token_to_word (str, optional): How to convert token probability to word probability. Default is "mean". | |
| force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is []. | |
| force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0-9). Default is False. | |
| drop_consecutive (bool, optional): Whether to drop tokens which are in 'force_tokens' but appear consecutively in the compressed prompt. | |
| Default is False. | |
| chunk_end_tokens (List[str], optional): The early-stop tokens for segmenting a chunk. Default is [".", "\n"]. | |
| Returns: | |
| dict: A dictionary containing: | |
| - "compressed_prompt" (str): The resulting compressed prompt. | |
| - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt. | |
| - "fn_labeled_original_prompt" (str): original words along with their labels | |
| indicating whether to reserve in compressed prompt, in the format (word label_sep label) | |
| - "origin_tokens" (int): The original number of tokens in the input. | |
| - "compressed_tokens" (int): The number of tokens in the compressed output. | |
| - "ratio" (str): The compression ratio achieved, in a human-readable format. | |
| - "rate" (str): The compression rate achieved, in a human-readable format. | |
| - "saving" (str): Estimated savings in GPT-4 token usage. | |
| """ | |
| assert len(force_tokens) <= self.max_force_token | |
| token_map = {} | |
| for i, t in enumerate(force_tokens): | |
| if len(self.tokenizer.tokenize(t)) != 1: | |
| token_map[t] = self.added_tokens[i] | |
| chunk_end_tokens = copy.deepcopy(chunk_end_tokens) | |
| for c in chunk_end_tokens: | |
| if c in token_map: | |
| chunk_end_tokens.append(token_map[c]) | |
| chunk_end_tokens = set(chunk_end_tokens) | |
| if type(context) == str: | |
| context = [context] | |
| context = copy.deepcopy(context) | |
| if len(context) == 1 and use_context_level_filter: | |
| use_context_level_filter = False | |
| n_original_token = 0 | |
| context_chunked = [] | |
| for i in range(len(context)): | |
| n_original_token += self.get_token_length(context[i], use_oai_tokenizer=True) | |
| for ori_token, new_token in token_map.items(): | |
| context[i] = context[i].replace(ori_token, new_token) | |
| context_chunked.append(self.__chunk_context(context[i], chunk_end_tokens=chunk_end_tokens)) | |
| if use_context_level_filter: | |
| # use_context_level_filter is requested but no context-level parameters were specified: | |
| # derive them from the token-level settings, using context_level_rate = (rate + 1.0) / 2 when rate is given, | |
| # or context_level_target_token = target_token * 2 when target_token is given. | |
| if ( | |
| target_context <= 0 | |
| and context_level_rate >= 1.0 | |
| and context_level_target_token <= 0 | |
| ): | |
| if target_token < 0 and rate < 1.0: | |
| context_level_rate = ( | |
| (rate + 1.0) / 2 if use_token_level_filter else rate | |
| ) | |
| print( | |
| f"set context level compression rate to {context_level_rate}." | |
| ) | |
| if target_token >= 0: | |
| context_level_target_token = ( | |
| target_token * 2 if use_token_level_filter else target_token | |
| ) | |
| print( | |
| f"set context level target token to {context_level_target_token}." | |
| ) | |
| if target_context >= 0: | |
| context_level_rate = min(target_context / len(context), 1.0) | |
| # print(f'override context level compression rate to {context_level_rate} because you specified target_context = {target_context}.') | |
| if context_level_target_token >= 0: | |
| context_level_rate = min( | |
| context_level_target_token / n_original_token, 1.0 | |
| ) | |
| # print(f'override context level compression rate to {context_level_rate} because you specified context_level_target_token = {context_level_target_token}.') | |
| context_probs, context_words = self.__get_context_prob( | |
| context_chunked, | |
| token_to_word=token_to_word, | |
| force_tokens=force_tokens, | |
| token_map=token_map, | |
| force_reserve_digit=force_reserve_digit, | |
| ) | |
| threshold = np.percentile( | |
| context_probs, int(100 * (1 - context_level_rate)) | |
| ) | |
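| # Contexts scoring at or above the (1 - context_level_rate) percentile are kept, so roughly a | |
| # context_level_rate fraction of the contexts survives this coarse filter (forced ids are always kept). | |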
| reserved_context = [] | |
| context_label = [False] * len(context_probs) | |
| for i, p in enumerate(context_probs): | |
| if p >= threshold or ( | |
| force_context_ids is not None and i in force_context_ids | |
| ): | |
| reserved_context.append(context_chunked[i]) | |
| context_label[i] = True | |
| n_reserved_token = 0 | |
| for chunks in reserved_context: | |
| for c in chunks: | |
| n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True) | |
| if target_token >= 0: | |
| rate = min(target_token / n_reserved_token, 1.0) | |
| print( | |
| f"override compression rate to {rate} because you specified target_token = {target_token}." | |
| ) | |
| if use_token_level_filter: | |
| compressed_context, word_list, word_label_list = self.__compress( | |
| reserved_context, | |
| reduce_rate=max(0, 1 - rate), | |
| token_to_word=token_to_word, | |
| force_tokens=force_tokens, | |
| token_map=token_map, | |
| force_reserve_digit=force_reserve_digit, | |
| drop_consecutive=drop_consecutive, | |
| ) | |
| else: | |
| compressed_context, word_list, word_label_list = self.__compress( | |
| reserved_context, | |
| reduce_rate=0, | |
| token_to_word=token_to_word, | |
| force_tokens=force_tokens, | |
| token_map=token_map, | |
| force_reserve_digit=force_reserve_digit, | |
| drop_consecutive=drop_consecutive, | |
| ) | |
| print( | |
| "return the original text because you specify use_token_level_filter=False" | |
| ) | |
| n_compressed_token = 0 | |
| for c in compressed_context: | |
| n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True) | |
| saving = (n_original_token - n_compressed_token) * 0.06 / 1000 | |
| ratio = ( | |
| 1 if n_compressed_token == 0 else n_original_token / n_compressed_token | |
| ) | |
| res = { | |
| "compressed_prompt": "\n\n".join(compressed_context), | |
| "compressed_prompt_list": compressed_context, | |
| "origin_tokens": n_original_token, | |
| "compressed_tokens": n_compressed_token, | |
| "ratio": f"{ratio:.1f}x", | |
| "rate": f"{1 / ratio * 100:.1f}%", | |
| "saving": f", Saving ${saving:.1f} in GPT-4.", | |
| } | |
| if return_word_label: | |
| words = [] | |
| labels = [] | |
| j = 0 | |
| for i in range(len(context)): | |
| if context_label[i]: | |
| words.extend(word_list[j]) | |
| labels.extend(word_label_list[j]) | |
| j += 1 | |
| else: | |
| words.extend(context_words[i]) | |
| labels.extend([0] * len(context_words[i])) | |
| word_label_lines = word_sep.join( | |
| [f"{word}{label_sep}{label}" for word, label in zip(words, labels)] | |
| ) | |
| res["fn_labeled_original_prompt"] = word_label_lines | |
| return res | |
| if target_token > 0: | |
| rate = min(target_token / n_original_token, 1.0) | |
| print( | |
| f"override compression rate to {rate} \ | |
| because you specified target_token = {target_token}." | |
| ) | |
| if use_token_level_filter: | |
| compressed_context, word_list, word_label_list = self.__compress( | |
| context_chunked, | |
| reduce_rate=max(0, 1 - rate), | |
| token_to_word=token_to_word, | |
| force_tokens=force_tokens, | |
| token_map=token_map, | |
| force_reserve_digit=force_reserve_digit, | |
| drop_consecutive=drop_consecutive, | |
| ) | |
| else: | |
| compressed_context, word_list, word_label_list = self.__compress( | |
| context_chunked, | |
| reduce_rate=0, | |
| token_to_word=token_to_word, | |
| force_tokens=force_tokens, | |
| token_map=token_map, | |
| force_reserve_digit=force_reserve_digit, | |
| drop_consecutive=drop_consecutive, | |
| ) | |
| print( | |
| "return the original text because you specify use_token_level_filter=False" | |
| ) | |
| n_compressed_token = 0 | |
| for c in compressed_context: | |
| n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True) | |
| saving = (n_original_token - n_compressed_token) * 0.06 / 1000 | |
| ratio = 1 if n_compressed_token == 0 else n_original_token / n_compressed_token | |
| res = { | |
| "compressed_prompt": "\n\n".join(compressed_context), | |
| "compressed_prompt_list": compressed_context, | |
| "origin_tokens": n_original_token, | |
| "compressed_tokens": n_compressed_token, | |
| "ratio": f"{ratio:.1f}x", | |
| "rate": f"{1 / ratio * 100:.1f}%", | |
| "saving": f", Saving ${saving:.1f} in GPT-4.", | |
| } | |
| if return_word_label: | |
| words = [] | |
| labels = [] | |
| for w_list, l_list in zip(word_list, word_label_list): | |
| words.extend(w_list) | |
| labels.extend(l_list) | |
| # new_words = [] | |
| # new_labels = [] | |
| # for i in range(len(words)): | |
| # word, label = words[i], labels[i] | |
| # if word in string.punctuation: | |
| # if labels[i-1] == 1 and label == 1 and i > 0: | |
| # new_words[-1] += word | |
| # else: | |
| # new_words.append(word) | |
| # new_labels.append(label) | |
| # word_label_lines = word_sep.join([f'{word}{label_sep}{label}' for word, label in zip(new_words, new_labels)]) | |
| word_label_lines = word_sep.join( | |
| [f"{word}{label_sep}{label}" for word, label in zip(words, labels)] | |
| ) | |
| res["fn_labeled_original_prompt"] = word_label_lines | |
| return res | |
| def get_token_length(self, text: str, add_special_tokens: bool = True, use_oai_tokenizer: bool = False): | |
| if use_oai_tokenizer: | |
| return len(self.oai_tokenizer.encode(text)) | |
| else: | |
| return len( | |
| self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids | |
| ) | |
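| # Illustrative note: counts from the model tokenizer and the tiktoken (gpt-3.5-turbo) tokenizer differ, | |
| # e.g. compressor.get_token_length(text) vs. compressor.get_token_length(text, use_oai_tokenizer=True); | |
| # the latter is what the reported "origin_tokens"/"compressed_tokens" statistics use. | |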
| def get_prefix_length(self, prefix: str, text: str): | |
| possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1) | |
| full_input_ids = self.tokenizer( | |
| prefix + text[:100], add_special_tokens=False | |
| ).input_ids | |
| for i in range(possible_prefix_token, len(full_input_ids)): | |
| cur_prefix = self.tokenizer.decode(full_input_ids[:i]) | |
| if cur_prefix == prefix: | |
| break | |
| assert self.tokenizer.decode(full_input_ids[i:]) == text[:100] | |
| return i | |
| def get_condition_ppl( | |
| self, | |
| text: str, | |
| question: str, | |
| condition_in_question: str = "none", | |
| granularity: str = "sentence", | |
| ): | |
| if condition_in_question == "none": | |
| return self.get_ppl(text, granularity=granularity) | |
| elif condition_in_question == "before": | |
| return self.get_ppl( | |
| question + text, | |
| granularity=granularity, | |
| condition_mode="after", | |
| condition_pos_id=self.get_token_length(question) - 1, | |
| ) | |
| elif condition_in_question == "after": | |
| return self.get_ppl( | |
| text + question, | |
| granularity=granularity, | |
| condition_mode="after", | |
| condition_pos_id=self.get_token_length(text) - 1, | |
| ) | |
| def get_dynamic_compression_ratio( | |
| self, | |
| context: list, | |
| target_token: float, | |
| iterative_size: int, | |
| dynamic_ratio: list, | |
| start: int, | |
| seg_info: List[List[tuple]] = None, | |
| ): | |
| def get_ratio(base: float, delta: float): | |
| return max(min(1, base + delta), 0) | |
| context_length = [self.get_token_length(ii, False) + 2 for ii in context] | |
| if start: | |
| context_length = context_length[1:] | |
| tau = target_token / (sum(context_length) + 1) | |
| res, idx, last, last_target = [], 0, 1, [] | |
| while idx < len(context_length): | |
| if last + context_length[idx] >= iterative_size: | |
| last_target.append( | |
| (iterative_size - last, get_ratio(tau, dynamic_ratio[idx])) | |
| ) | |
| res.append(last_target) | |
| last = last + context_length[idx] - iterative_size | |
| if last > iterative_size: | |
| k = last // iterative_size | |
| res.extend( | |
| [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k | |
| ) | |
| last -= k * iterative_size | |
| last_target = ( | |
| [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else [] | |
| ) | |
| else: | |
| last += context_length[idx] | |
| last_target.append( | |
| (context_length[idx], get_ratio(tau, dynamic_ratio[idx])) | |
| ) | |
| idx += 1 | |
| if last_target: | |
| res.append(last_target) | |
| return res | |
| def get_structured_dynamic_compression_ratio( | |
| self, | |
| context: list, | |
| iterative_size: int, | |
| dynamic_ratio: list, | |
| start: int, | |
| seg_info: List[List[tuple]] = None, | |
| ): | |
| if start: | |
| pure_context = context[1:] | |
| else: | |
| pure_context = context | |
| global_dynamic_rate, global_dynamic_compress, segments = [], [], [] | |
| for context_idx, text in enumerate(pure_context): | |
| text_seen = 0 | |
| for seg_idx, (seg_len, seg_rate, seg_compress) in enumerate( | |
| seg_info[context_idx] | |
| ): | |
| seg_text = text[text_seen : text_seen + seg_len] | |
| if ( | |
| seg_idx == len(seg_info[context_idx]) - 1 | |
| and context_idx != len(pure_context) - 1 | |
| ): | |
| seg_text += "\n\n" | |
| segments.append(seg_text) | |
| if seg_compress: | |
| global_dynamic_rate.append(seg_rate) | |
| else: | |
| global_dynamic_rate.append(1.0) | |
| global_dynamic_compress.append(seg_compress) | |
| text_seen += seg_len | |
| origin_text = "\n\n".join(pure_context) | |
| assert len("".join(segments)) == len(origin_text) | |
| assert len(segments) == len(global_dynamic_rate) == len(global_dynamic_compress) | |
| text_input_ids = self.tokenizer( | |
| "\n\n".join(context), add_special_tokens=False | |
| ).input_ids[start:] | |
| assert self.tokenizer.decode(text_input_ids) == origin_text | |
| dynamic_compression_ratio = self.token_segment( | |
| text_input_ids, | |
| iterative_size, | |
| segments, | |
| global_dynamic_rate, | |
| global_dynamic_compress, | |
| ) | |
| return dynamic_compression_ratio | |
| def token_segment( | |
| self, | |
| text_input_ids: List[int], | |
| iterative_size: int, | |
| segments: List[str], | |
| global_dynamic_rate: List[float], | |
| global_dynamic_compress: List[bool], | |
| ): | |
| decode_window = 3 | |
| seg_idx, seg_seen, token_seen_num, last_rate = 0, 0, 0, -1 | |
| dynamic_compression_rate, local_compresssion_rate = [], [] | |
| for i in range(len(text_input_ids)): | |
| if i < decode_window: | |
| id_pre, id_cur = text_input_ids[:i], text_input_ids[: i + 1] | |
| else: | |
| id_pre, id_cur = ( | |
| text_input_ids[i - decode_window + 1 : i], | |
| text_input_ids[i - decode_window + 1 : i + 1], | |
| ) | |
| cur_word = self.tokenizer.decode(id_cur)[ | |
| len(self.tokenizer.decode(id_pre)) : | |
| ] | |
| cur_word_len = len(cur_word) | |
| if cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen: | |
| possible_rate, possible_compress = [], [] | |
| while ( | |
| cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen | |
| ): | |
| possible_rate.append(global_dynamic_rate[seg_idx]) | |
| possible_compress.append(global_dynamic_compress[seg_idx]) | |
| cur_word_len -= len(segments[seg_idx]) - seg_seen | |
| seg_idx += 1 | |
| seg_seen = 0 | |
| if cur_word_len: | |
| possible_rate.append(global_dynamic_rate[seg_idx]) | |
| possible_compress.append(global_dynamic_compress[seg_idx]) | |
| new_rate = 1.0 if False in possible_compress else min(possible_rate) | |
| else: | |
| new_rate = global_dynamic_rate[seg_idx] | |
| if new_rate != last_rate and i - token_seen_num: | |
| local_compresssion_rate.append((i - token_seen_num, last_rate)) | |
| token_seen_num = i | |
| last_rate = new_rate | |
| seg_seen += cur_word_len | |
| if (i + 1) % iterative_size == 0: | |
| if token_seen_num != i + 1: | |
| local_compresssion_rate.append((i + 1 - token_seen_num, last_rate)) | |
| token_seen_num = i + 1 | |
| dynamic_compression_rate.append(local_compresssion_rate[:]) | |
| local_compresssion_rate = [] | |
| if token_seen_num != len(text_input_ids): | |
| local_compresssion_rate.append( | |
| (len(text_input_ids) - token_seen_num, last_rate) | |
| ) | |
| if local_compresssion_rate != []: | |
| dynamic_compression_rate.append(local_compresssion_rate[:]) | |
| return dynamic_compression_rate | |
| def control_context_budget( | |
| self, | |
| context: List[str], | |
| context_tokens_length: List[int], | |
| target_token: float, | |
| force_context_ids: List[int] = None, | |
| force_context_number: int = None, | |
| question: str = "", | |
| condition_in_question: str = "none", | |
| reorder_context: str = "original", | |
| dynamic_context_compression_ratio: float = 0.0, | |
| rank_method: str = "longllmlingua", | |
| context_budget: str = "+100", | |
| context_segs: List[List[str]] = None, | |
| context_segs_rate: List[List[float]] = None, | |
| context_segs_compress: List[List[bool]] = None, | |
| ): | |
| demostrations_sort = self.get_rank_results( | |
| context, | |
| question, | |
| rank_method, | |
| condition_in_question, | |
| context_tokens_length, | |
| ) | |
| if target_token < 0: | |
| target_token = 100 | |
| target_token = eval("target_token" + context_budget) | |
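| # With the default context_budget="+100" this evaluates the expression "target_token+100", i.e. the | |
| # context-level filter is allowed a slightly larger budget than the final token target. | |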
| res = [] | |
| used = force_context_ids if force_context_ids is not None else [] | |
| if context_segs is not None: | |
| for idx, _ in enumerate(context): | |
| if False in context_segs_compress[idx]: | |
| used.append(idx) | |
| self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)]) | |
| for idx, _ in demostrations_sort: | |
| if idx >= len(context_tokens_length): | |
| continue | |
| target_token -= context_tokens_length[idx] | |
| if idx not in used: | |
| used.append(idx) | |
| if target_token < 0 or ( | |
| force_context_number is not None and len(res) >= force_context_number | |
| ): | |
| break | |
| original_used = used | |
| if reorder_context == "original": | |
| used = sorted(used) | |
| elif reorder_context == "two_stage": | |
| l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [ | |
| _ for idx, _ in enumerate(used) if idx % 2 == 1 | |
| ] | |
| used = l + r[::-1] | |
| if dynamic_context_compression_ratio > 0: | |
| N = len(used) | |
| dynamic_ratio = [ | |
| i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0 | |
| for i in range(-(N - 1), N, 2) | |
| ][::-1] | |
| dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)} | |
| dynamic_ratio = [dynamic_ratio_map[i] for i in used] | |
| else: | |
| dynamic_ratio = [0.0] * len(used) | |
| res = [context[idx] for idx in used if idx < len(context)] | |
| return res, dynamic_ratio, used | |
| def control_sentence_budget( | |
| self, | |
| context: List[str], | |
| target_token: float, | |
| keep_first_sentence: int = 0, | |
| keep_last_sentence: int = 0, | |
| keep_sentence_number: int = 0, | |
| high_priority_bonus: int = 100, | |
| token_budget_ratio: float = 1.4, | |
| question: str = "", | |
| condition_in_question: str = "none", | |
| rank_method: str = "longllmlingua", | |
| context_segs: List[List[str]] = None, | |
| context_segs_rate: List[List[float]] = None, | |
| context_segs_compress: List[List[bool]] = None, | |
| ): | |
| def keep_sentence(dem_idx: int, sent_keep: int): | |
| idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep] | |
| for idx in idxs: | |
| sentence_ppl[idx] += high_priority_bonus | |
| def sync_sentence(segments, text): | |
| seg_num = len(segments) | |
| new_segments = [] | |
| text_seen = 0 | |
| seg_idx, cur_seg_seen = 0, 0 | |
| for i, s in enumerate(text): | |
| while seg_idx < seg_num and s != segments[seg_idx][cur_seg_seen]: | |
| if cur_seg_seen < len(segments[seg_idx]) - 1: | |
| cur_seg_seen += 1 | |
| continue | |
| new_segments.append(text[text_seen:i]) | |
| text_seen = i | |
| seg_idx += 1 | |
| cur_seg_seen = 0 | |
| cur_seg_seen += 1 | |
| if seg_idx == seg_num: | |
| break | |
| if cur_seg_seen == len(segments[seg_idx]): | |
| new_segments.append(text[text_seen : i + 1]) | |
| text_seen = i + 1 | |
| seg_idx += 1 | |
| cur_seg_seen = 0 | |
| if text_seen < len(text): | |
| new_segments.append(text[text_seen:]) | |
| assert len("".join(new_segments)) == len(text) | |
| return new_segments | |
| sentences = [nltk.sent_tokenize(c) for c in context] | |
| dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0 | |
| for idx_d, s in enumerate(sentences): | |
| for _ in s: | |
| dem_g[idx_d].add(idx) | |
| s2de[idx] = idx_d | |
| idx += 1 | |
| if context_segs is not None: | |
| context_segs = [ | |
| sync_sentence(s, "".join(c)) for s, c in zip(context_segs, sentences) | |
| ] | |
| sen2seg_ratio = {} | |
| idx = 0 | |
| for idx_d, sentences_each_context in enumerate(sentences): | |
| segments_length = [len(s) for s in context_segs[idx_d]] | |
| seg_idx, cur_seg_seen = 0, 0 | |
| for sentence in sentences_each_context: | |
| sentence_seg_ratio = [] | |
| remain = len(sentence) | |
| while remain: | |
| if segments_length[seg_idx] - cur_seg_seen <= remain: | |
| new_seg_len = segments_length[seg_idx] - cur_seg_seen | |
| sentence_seg_ratio.append( | |
| ( | |
| new_seg_len, | |
| context_segs_rate[idx_d][seg_idx], | |
| context_segs_compress[idx_d][seg_idx], | |
| ) | |
| ) | |
| seg_idx += 1 | |
| cur_seg_seen = 0 | |
| remain -= new_seg_len | |
| else: | |
| sentence_seg_ratio.append( | |
| ( | |
| remain, | |
| context_segs_rate[idx_d][seg_idx], | |
| context_segs_compress[idx_d][seg_idx], | |
| ) | |
| ) | |
| cur_seg_seen += remain | |
| remain = 0 | |
| sen2seg_ratio[idx] = sentence_seg_ratio | |
| idx += 1 | |
| context_sentences = [s for ii in sentences for s in ii] | |
| sentence_tokens_length = [ | |
| self.get_token_length(sentence) for sentence in context_sentences | |
| ] | |
| N = len(context_sentences) | |
| flags = list(range(len(context_sentences))) | |
| if len(sentence_tokens_length) == 1: | |
| return context | |
| if rank_method == "longllmlingua": | |
| sentence_ppl = [ | |
| self.get_condition_ppl(sentence, question, condition_in_question) | |
| .cpu() | |
| .numpy() | |
| .item() | |
| for sentence in context_sentences | |
| ] | |
| if keep_first_sentence: | |
| sentence_ppl[:keep_first_sentence] = [ | |
| ii + high_priority_bonus | |
| for ii in sentence_ppl[:keep_first_sentence] | |
| ] | |
| if keep_last_sentence: | |
| sentence_ppl[-keep_last_sentence:] = [ | |
| ii + high_priority_bonus | |
| for ii in sentence_ppl[-keep_last_sentence:] | |
| ] | |
| if keep_sentence_number: | |
| for dem_idx in range(len(sentences)): | |
| keep_sentence(dem_idx, keep_sentence_number) | |
| sort_direct = -1 if condition_in_question == "none" else 1 | |
| sent_sort = sorted( | |
| enumerate(sentence_ppl), key=lambda x: sort_direct * x[1] | |
| ) | |
| else: | |
| sent_sort = self.get_rank_results( | |
| context_sentences, | |
| question, | |
| rank_method, | |
| condition_in_question, | |
| [0] * len(context_sentences), | |
| ) | |
| sentence_flags = [False] * N | |
| if target_token < 0: | |
| target_token = 100 | |
| target_token *= token_budget_ratio | |
| res = [] | |
| for idx, _ in sent_sort: | |
| idx = flags[idx] | |
| target_token -= sentence_tokens_length[idx] | |
| sentence_flags[idx] = True | |
| if target_token < 0: | |
| break | |
| if context_segs is not None: | |
| for idx in range(N): | |
| preserved = [sen_seg_info[2] for sen_seg_info in sen2seg_ratio[idx]] | |
| if False in preserved: | |
| sentence_flags[idx] = True | |
| idx = 0 | |
| res = [] | |
| new_segments_info = [] | |
| for s in sentences: | |
| tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]] | |
| res.append("".join(tmp)) | |
| if context_segs is not None: | |
| segment_ratio = [] | |
| for ii in range(len(s)): | |
| if sentence_flags[idx + ii]: | |
| segment_ratio.extend(sen2seg_ratio[idx + ii]) | |
| new_segments_info.append(segment_ratio) | |
| idx += len(s) | |
| if context_segs is not None: | |
| new_segments_info = [ | |
| self.concate_segment_info(segment_info) | |
| for segment_info in new_segments_info | |
| ] | |
| return res, new_segments_info | |
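| # Note on the routine above: sentences are ranked (by conditional perplexity for | |
| # "longllmlingua", otherwise by the selected retrieval method) and kept greedily | |
| # until the budget target_token * token_budget_ratio is spent; any sentence that | |
| # overlaps a segment marked compress=False is force-kept regardless of its rank. | |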
| def get_compressed_input( | |
| self, | |
| loss, | |
| input_ids, | |
| attention_mask, | |
| end=200, | |
| iterative_size=200, | |
| threshold=0.5, | |
| keep_flag=None, | |
| split_token_id: int = 13, | |
| start: int = 0, | |
| self_loss=None, | |
| self_input_ids=None, | |
| self_attention_mask=None, | |
| ): | |
| if self_loss is not None: | |
| need_idx = torch.concat( | |
| [ | |
| loss[:start] > 0, | |
| self_loss[: loss[start:].shape[0]] - loss[start:] > threshold, | |
| loss[:1] > 0, | |
| ] | |
| ) | |
| else: | |
| need_idx = torch.concat([loss > threshold, loss[:1] > 0]) | |
| need_idx[end:] = 1 | |
| need_idx[: end - iterative_size] = 1 | |
| loss = loss[need_idx[:-1]] | |
| if self_loss is not None: | |
| if need_idx.shape[0] < self_loss.shape[0] + start + 1: | |
| need_idx = torch.cat( | |
| [ | |
| need_idx, | |
| torch.ones( | |
| self_loss.shape[0] - need_idx.shape[0] + start + 1, | |
| dtype=torch.bool, | |
| ).to(need_idx.device), | |
| ] | |
| ) | |
| self_loss = self_loss[need_idx[start:-1]] | |
| if need_idx.shape[0] < input_ids.shape[1]: | |
| need_idx = torch.cat( | |
| [ | |
| need_idx, | |
| torch.ones( | |
| input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool | |
| ).to(need_idx.device), | |
| ] | |
| ) | |
| elif need_idx.shape[0] > input_ids.shape[1]: | |
| need_idx = need_idx[: input_ids.shape[1]] | |
| if keep_flag is not None: | |
| need_idx[keep_flag == 1] = 1 | |
| last = -1 | |
| if keep_flag is not None: | |
| for ii in range(max(0, end - iterative_size), end): | |
| if need_idx[ii] != 1: | |
| continue | |
| now = input_ids[0][ii].detach().cpu().item() | |
| if ( | |
| now == split_token_id | |
| and last == split_token_id | |
| and keep_flag[ii].detach().cpu().item() == 0 | |
| ): | |
| need_idx[ii] = 0 | |
| else: | |
| last = now | |
| compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0) | |
| compressed_attention_mask = attention_mask[attention_mask == 1][ | |
| need_idx | |
| ].unsqueeze(0) | |
| if self_loss is not None: | |
| self_compressed_input_ids = self_input_ids[self_attention_mask == 1][ | |
| need_idx[start:] | |
| ].unsqueeze(0) | |
| self_compressed_attention_mask = self_attention_mask[ | |
| self_attention_mask == 1 | |
| ][need_idx[start:]].unsqueeze(0) | |
| else: | |
| self_compressed_input_ids, self_compressed_attention_mask = None, None | |
| if keep_flag is not None: | |
| if len(keep_flag) > len(need_idx): | |
| keep_flag = torch.cat( | |
| [ | |
| keep_flag[:start], | |
| keep_flag[start : len(need_idx) + start][need_idx], | |
| keep_flag[start + len(need_idx) :], | |
| ] | |
| ) | |
| else: | |
| keep_flag = keep_flag[need_idx] | |
| end -= (need_idx[:end] == 0).sum() | |
| return ( | |
| compressed_input_ids, | |
| compressed_attention_mask, | |
| keep_flag, | |
| end, | |
| loss, | |
| self_loss, | |
| self_compressed_input_ids, | |
| self_compressed_attention_mask, | |
| ) | |
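| # Note on get_compressed_input: need_idx marks tokens to keep. Tokens outside the | |
| # active window [end - iterative_size, end) are always kept in this round; inside it | |
| # a token survives when its loss (or the gap self_loss - loss, when self_loss is | |
| # given) exceeds threshold or keep_flag marks it, while consecutive split tokens | |
| # (default id 13, a newline for LLaMA-style tokenizers) are collapsed to one. | |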
| def get_estimate_threshold_base_distribution( | |
| self, ppl, ratio: float, condition_flag: bool = False | |
| ): | |
| if ratio == 1.0: | |
| return float("-inf") | |
| ppl = ppl[ppl != 10000] | |
| target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1)) | |
| return ( | |
| ppl.sort(descending=not condition_flag) | |
| .values[target_token] | |
| .detach() | |
| .cpu() | |
| .item() | |
| ) | |
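| # Illustrative trace (hand-checked): with ppl = torch.tensor([5., 3., 8., 1.]), | |
| # ratio = 0.5 and condition_flag = False, values are sorted descending to | |
| # [8, 5, 3, 1] and index int(4 * 0.5) - 1 = 1 is returned, i.e. 5.0; tokens whose | |
| # loss exceeds this threshold are the ones later kept by get_compressed_input. | |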
| def iterative_compress_prompt( | |
| self, | |
| context: List[str], | |
| target_token: float, | |
| iterative_size: int = 200, | |
| keep_split: bool = False, | |
| split_token_id: int = 13, | |
| start: int = 0, | |
| dynamic_ratio: list = None, | |
| condition_compare: bool = False, | |
| segments_info: List[List[tuple]] = None, | |
| ): | |
| if segments_info is None or segments_info == []: | |
| iterative_ratios = self.get_dynamic_compression_ratio( | |
| context, target_token, iterative_size, dynamic_ratio, start | |
| ) | |
| else: | |
| iterative_ratios = self.get_structured_dynamic_compression_ratio( | |
| context, iterative_size, dynamic_ratio, start, segments_info | |
| ) | |
| context = "\n\n".join(context) | |
| tokenized_text = self.tokenizer( | |
| context, return_tensors="pt", add_special_tokens=False | |
| ) | |
| input_ids = tokenized_text["input_ids"].to(self.device) | |
| attention_mask = tokenized_text["attention_mask"].to(self.device) | |
| N = (attention_mask == 1).sum() | |
| compressed_input_ids, compressed_attention_mask = input_ids, attention_mask | |
| if condition_compare: | |
| self_input_ids, self_attention_mask = ( | |
| input_ids[:, start:], | |
| attention_mask[:, start:], | |
| ) | |
| self_compressed_input_ids, self_compressed_attention_mask = ( | |
| self_input_ids, | |
| self_attention_mask, | |
| ) | |
| end = min(iterative_size + start, compressed_input_ids.shape[1]) | |
| threshold, keep_flag = None, None | |
| if keep_split: | |
| input_ids_numpy = input_ids.cpu().detach().numpy()[0] | |
| N = len(input_ids_numpy) | |
| keep_flag = [ | |
| int( | |
| ( | |
| ii > 0 | |
| and input_ids_numpy[ii] == split_token_id | |
| and input_ids_numpy[ii - 1] == split_token_id | |
| ) | |
| or ( | |
| ii < N - 1 | |
| and input_ids_numpy[ii] == split_token_id | |
| and input_ids_numpy[ii + 1] == split_token_id | |
| ) | |
| ) | |
| for ii in range(N) | |
| ] | |
| keep_flag = torch.tensor(keep_flag).to(self.device) | |
| past_key_values, past_loss, ready_end = None, None, 0 | |
| self_past_key_values, self_past_loss, self_ready_end = None, None, 0 | |
| pop_compressed_input_ids, pop_self_compressed_input_ids = None, None | |
| idx = 0 | |
| while end <= compressed_input_ids.shape[1]: | |
| if end > self.max_position_embeddings and past_key_values is not None: | |
| # KV-Cache Compression | |
| e, s = end - self.max_position_embeddings, min( | |
| self.cache_bos_num + start, self.max_position_embeddings | |
| ) | |
| if pop_compressed_input_ids is None: | |
| pop_compressed_input_ids = compressed_input_ids[:, :e] | |
| else: | |
| pop_compressed_input_ids = torch.cat( | |
| [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1 | |
| ) | |
| compressed_input_ids = compressed_input_ids[:, e:] | |
| compressed_attention_mask = compressed_attention_mask[:, e:] | |
| past_key_values = [ | |
| [ | |
| torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2), | |
| torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2), | |
| ] | |
| for k, v in past_key_values | |
| ] | |
| if keep_flag is not None: | |
| keep_flag = keep_flag[e:] | |
| end, ready_end = end - e, ready_end - e | |
| if condition_compare: | |
| s = min(s, self_past_key_values[0][0].shape[2] - e) | |
| self_ready_end -= e | |
| if pop_self_compressed_input_ids is None: | |
| pop_self_compressed_input_ids = self_compressed_input_ids[:, :e] | |
| else: | |
| pop_self_compressed_input_ids = torch.cat( | |
| [ | |
| pop_self_compressed_input_ids, | |
| self_compressed_input_ids[:, :e], | |
| ], | |
| dim=-1, | |
| ) | |
| self_compressed_input_ids = self_compressed_input_ids[:, e:] | |
| self_compressed_attention_mask = self_compressed_attention_mask[ | |
| :, e: | |
| ] | |
| self_past_key_values = [ | |
| [ | |
| torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2), | |
| torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2), | |
| ] | |
| for k, v in self_past_key_values | |
| ] | |
| loss, past_key_values = self.get_ppl( | |
| "", | |
| "token", | |
| compressed_input_ids, | |
| compressed_attention_mask, | |
| past_key_values=past_key_values, | |
| return_kv=True, | |
| end=end if idx else None, | |
| ) | |
| if loss.shape[0] == 0: | |
| break | |
| if past_loss is not None: | |
| if end - 1 > len(past_loss): | |
| past_loss = torch.cat( | |
| [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]] | |
| ) | |
| past_loss[ready_end : end - 1] = loss | |
| loss = past_loss | |
| else: | |
| past_loss = loss | |
| if idx: | |
| past_key_values = [ | |
| [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]] | |
| for k, v in past_key_values | |
| ] | |
| else: | |
| past_key_values = None | |
| if condition_compare: | |
| self_loss, self_past_key_values = self.get_ppl( | |
| "", | |
| "token", | |
| self_compressed_input_ids, | |
| self_compressed_attention_mask, | |
| past_key_values=self_past_key_values, | |
| return_kv=True, | |
| end=end - start if idx else None, | |
| ) | |
| if self_past_loss is not None: | |
| if end - start - 1 > len(self_past_loss): | |
| self_past_loss = torch.cat( | |
| [ | |
| self_past_loss, | |
| torch.zeros_like(self_loss)[ | |
| : end - 1 - start - len(self_past_loss) | |
| ], | |
| ] | |
| ) | |
| self_past_loss[self_ready_end : end - start - 1] = self_loss | |
| self_loss = self_past_loss | |
| else: | |
| self_past_loss = self_loss | |
| if idx: | |
| self_past_key_values = [ | |
| [ | |
| k[:, :, : end - iterative_size - start], | |
| v[:, :, : end - iterative_size - start], | |
| ] | |
| for k, v in self_past_key_values | |
| ] | |
| else: | |
| self_past_key_values = None | |
| self_ready_end = ( | |
| end - start - iterative_size if not (start and idx == 0) else 0 | |
| ) | |
| ready_end = end - iterative_size if not (start and idx == 0) else 0 | |
| for delta_end, ratio in iterative_ratios[idx]: | |
| loss = past_loss | |
| if condition_compare: | |
| self_loss = self_past_loss | |
| threshold = self.get_estimate_threshold_base_distribution( | |
| self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False | |
| ) | |
| else: | |
| threshold = self.get_estimate_threshold_base_distribution( | |
| loss, ratio, False | |
| ) | |
| ( | |
| compressed_input_ids, | |
| compressed_attention_mask, | |
| keep_flag, | |
| end, | |
| past_loss, | |
| self_past_loss, | |
| self_compressed_input_ids, | |
| self_compressed_attention_mask, | |
| ) = self.get_compressed_input( | |
| loss, | |
| compressed_input_ids, | |
| compressed_attention_mask, | |
| end - iterative_size + delta_end, | |
| iterative_size=delta_end, | |
| threshold=threshold, | |
| keep_flag=keep_flag, | |
| split_token_id=split_token_id, | |
| start=start, | |
| self_loss=self_loss if condition_compare else None, | |
| self_input_ids=( | |
| self_compressed_input_ids if condition_compare else None | |
| ), | |
| self_attention_mask=( | |
| self_compressed_attention_mask if condition_compare else None | |
| ), | |
| ) | |
| end += iterative_size | |
| idx += 1 | |
| if pop_compressed_input_ids is not None: | |
| compressed_input_ids = torch.cat( | |
| [pop_compressed_input_ids, compressed_input_ids], dim=-1 | |
| ) | |
| return compressed_input_ids[:, start:], compressed_attention_mask[:, start:] | |
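| # Note on the loop above: the prompt is scored window by window (iterative_size | |
| # tokens at a time) while reusing the KV cache between windows; once the scored | |
| # prefix would exceed max_position_embeddings, the oldest non-BOS cache entries are | |
| # evicted and the already-finalised token ids are parked in pop_compressed_input_ids, | |
| # then re-attached before returning. | |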
| def recover( | |
| self, | |
| original_prompt: str, | |
| compressed_prompt: str, | |
| response: str, | |
| ): | |
| def match_from_compressed(response_word): | |
| response_input_ids = self.tokenizer( | |
| response_word, add_special_tokens=False | |
| )["input_ids"] | |
| response_set, response_c = set(response_input_ids), defaultdict(list) | |
| for idx in range(M): | |
| if original_input_ids[idx] in response_set: | |
| response_c[original_input_ids[idx]].append(idx) | |
| res, res_min, res_c = None, float("inf"), 1 | |
| n = len(response_input_ids) | |
| for l in response_c[response_input_ids[0]]: | |
| x, y, c = 0, l, 1 | |
| for x in range(1, n): | |
| idx = bisect.bisect_right(response_c[response_input_ids[x]], y) | |
| if ( | |
| idx >= len(response_c[response_input_ids[x]]) | |
| or response_c[response_input_ids[x]][idx] - y > 10 | |
| ): | |
| continue | |
| c += 1 | |
| y = response_c[response_input_ids[x]][idx] | |
| if c > res_c: | |
| res_c = c | |
| res_min = y - l + 1 | |
| res = (l, y + 1) | |
| elif c == res_c and y - l + 1 < res_min: | |
| res_min = y - l + 1 | |
| res = (l, y + 1) | |
| if res is None: | |
| return response_word | |
| # while l > 0 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"): | |
| # l -= 1 | |
| # while r < M - 1 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[r]).startswith("_"): | |
| # r += 1 | |
| return self.tokenizer.decode(original_input_ids[res[0] : res[1]]) | |
| response_words = response.split(" ") | |
| original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[ | |
| "input_ids" | |
| ] | |
| N, M = len(response_words), len(original_input_ids) | |
| recovered_response_words = [] | |
| l = 0 | |
| while l < N: | |
| if response_words[l] not in compressed_prompt: | |
| recovered_response_words.append(response_words[l]) | |
| l += 1 | |
| continue | |
| r = l | |
| while ( | |
| r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt | |
| ): | |
| r += 1 | |
| match_words = match_from_compressed(" ".join(response_words[l : r + 1])) | |
| recovered_response_words.append(match_words) | |
| l = r + 1 | |
| return " ".join(recovered_response_words) | |
| def get_rank_results( | |
| self, | |
| context: list, | |
| question: str, | |
| rank_method: str, | |
| condition_in_question: str, | |
| context_tokens_length: list, | |
| ): | |
| def get_distance_bm25(corpus, query): | |
| from rank_bm25 import BM25Okapi | |
| tokenized_corpus = [doc.split(" ") for doc in corpus] | |
| bm25 = BM25Okapi(tokenized_corpus) | |
| tokenized_query = query.split(" ") | |
| doc_scores = bm25.get_scores(tokenized_query) | |
| idx = [(ii, 0) for ii in (-doc_scores).argsort()] | |
| return idx | |
| def get_distance_gzip(corpus, query): | |
| def get_score(x, y): | |
| cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode())) | |
| cxy = len(gzip.compress(f"{x} {y}".encode())) | |
| return (cxy - min(cx, cy)) / max(cx, cy) | |
| import gzip | |
| doc_scores = [get_score(doc, query) for doc in corpus] | |
| idx = [(ii, 0) for ii in np.argsort(doc_scores)] | |
| return idx | |
| def get_distance_sentbert(corpus, query): | |
| from sentence_transformers import SentenceTransformer, util | |
| if self.retrieval_model is None or self.retrieval_model_name != rank_method: | |
| self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1") | |
| self.retrieval_model_name = rank_method | |
| doc_embeds = self.retrieval_model.encode(corpus) | |
| query = self.retrieval_model.encode(query) | |
| doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) | |
| idx = [(ii, 0) for ii in np.argsort(doc_scores)] | |
| return idx | |
| def get_distance_openai(corpus, query): | |
| import openai | |
| from sentence_transformers import util | |
| openai.api_key = self.open_api_config.get("api_key", "") | |
| openai.api_base = self.open_api_config.get( | |
| "api_base", "https://api.openai.com/v1" | |
| ) | |
| openai.api_type = self.open_api_config.get("api_type", "open_ai") | |
| openai.api_version = self.open_api_config.get("api_version", "2023-05-15") | |
| engine = self.open_api_config.get("engine", "text-embedding-ada-002") | |
| def get_embed(text): | |
| return openai.Embedding.create( | |
| input=[text.replace("\n", " ")], engine=engine | |
| )["data"][0]["embedding"] | |
| doc_embeds = [get_embed(i) for i in corpus] | |
| query = get_embed(query) | |
| doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) | |
| idx = [(ii, 0) for ii in np.argsort(doc_scores)] | |
| return idx | |
| def get_distance_sentbert_bge(corpus, query): | |
| from sentence_transformers import SentenceTransformer, util | |
| if self.retrieval_model is None or self.retrieval_model_name != rank_method: | |
| self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5") | |
| self.retrieval_model_name = rank_method | |
| doc_embeds = self.retrieval_model.encode( | |
| [i for i in corpus], normalize_embeddings=True | |
| ) | |
| query = self.retrieval_model.encode(query, normalize_embeddings=True) | |
| doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) | |
| idx = [(ii, 0) for ii in np.argsort(doc_scores)] | |
| return idx | |
| def get_distance_bge_ranker(corpus, query): | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| pairs = [[i, query] for i in corpus] | |
| if self.retrieval_model is None or self.retrieval_model_name != rank_method: | |
| tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") | |
| model = ( | |
| AutoModelForSequenceClassification.from_pretrained( | |
| "BAAI/bge-reranker-large" | |
| ) | |
| .eval() | |
| .to(self.device) | |
| ) | |
| self.retrieval_model = [tokenizer, model] | |
| self.retrieval_model_name = rank_method | |
| with torch.no_grad(): | |
| inputs = self.retrieval_model[0]( | |
| pairs, | |
| padding=True, | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=512, | |
| ).to(self.device) | |
| scores = ( | |
| self.retrieval_model[1](**inputs, return_dict=True) | |
| .logits.view( | |
| -1, | |
| ) | |
| .float() | |
| ) | |
| idx = [(ii, 0) for ii in np.argsort(-scores.cpu())] | |
| return idx | |
| def get_distance_bge_llmembedder(corpus, query): | |
| from transformers import AutoModel, AutoTokenizer | |
| if self.retrieval_model is None or self.retrieval_model_name != rank_method: | |
| tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder") | |
| model = ( | |
| AutoModel.from_pretrained("BAAI/llm-embedder") | |
| .eval() | |
| .to(self.device) | |
| ) | |
| self.retrieval_model = [tokenizer, model] | |
| self.retrieval_model_name = rank_method | |
| instruction_qa_query = ( | |
| "Represent this query for retrieving relevant documents: " | |
| ) | |
| instruction_qa_key = "Represent this document for retrieval: " | |
| queries = [instruction_qa_query + query for _ in corpus] | |
| keys = [instruction_qa_key + key for key in corpus] | |
| with torch.no_grad(): | |
| query_inputs = self.retrieval_model[0]( | |
| queries, | |
| padding=True, | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=512, | |
| ).to(self.device) | |
| key_inputs = self.retrieval_model[0]( | |
| keys, | |
| padding=True, | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=512, | |
| ).to(self.device) | |
| query_outputs = self.retrieval_model[1](**query_inputs) | |
| key_outputs = self.retrieval_model[1](**key_inputs) | |
| # CLS pooling | |
| query_embeddings = query_outputs.last_hidden_state[:, 0] | |
| key_embeddings = key_outputs.last_hidden_state[:, 0] | |
| # Normalize | |
| query_embeddings = torch.nn.functional.normalize( | |
| query_embeddings, p=2, dim=1 | |
| ) | |
| key_embeddings = torch.nn.functional.normalize( | |
| key_embeddings, p=2, dim=1 | |
| ) | |
| similarity = query_embeddings @ key_embeddings.T | |
| idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())] | |
| return idx | |
| def get_distance_jinza(corpus, query): | |
| from numpy.linalg import norm | |
| from transformers import AutoModel | |
| def cos_sim(a, b): | |
| return (a @ b.T) / (norm(a) * norm(b)) | |
| if self.retrieval_model is None or self.retrieval_model_name != rank_method: | |
| model = ( | |
| AutoModel.from_pretrained( | |
| "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True | |
| ) | |
| .eval() | |
| .to(self.device) | |
| ) | |
| self.retrieval_model = model | |
| self.retrieval_model_name = rank_method | |
| doc_embeds = self.retrieval_model.encode(corpus) | |
| query = self.retrieval_model.encode(query) | |
| doc_scores = cos_sim(doc_embeds, query) | |
| idx = [(ii, 0) for ii in np.argsort(-doc_scores)] | |
| return idx | |
| def get_distance_voyageai(corpus, query): | |
| import voyageai | |
| from sentence_transformers import util | |
| voyageai.api_key = self.open_api_config.get("voyageai_api_key", "") | |
| def get_embed(text): | |
| return voyageai.get_embedding(text, model="voyage-01") | |
| doc_embeds = [get_embed(i) for i in corpus] | |
| query = get_embed(query) | |
| doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) | |
| idx = [(ii, 0) for ii in np.argsort(doc_scores)] | |
| return idx | |
| def get_distance_cohere(corpus, query): | |
| import cohere | |
| api_key = self.open_api_config.get("cohere_api_key", "") | |
| co = cohere.Client(api_key) | |
| results = co.rerank( | |
| model="rerank-english-v2.0", query=query, documents=corpus, top_n=20 | |
| ) | |
| c_map = {jj: ii for ii, jj in enumerate(corpus)} | |
| doc_rank = [c_map[ii.document["text"]] for ii in results] | |
| idx = [(ii, 0) for ii in doc_rank] | |
| return idx | |
| def get_distance_longllmlingua(corpus, query): | |
| context_ppl = [ | |
| self.get_condition_ppl( | |
| d, | |
| query | |
| + " We can get the answer to this question in the given documents.", | |
| condition_in_question, | |
| ) | |
| - dl * 2 / 250 * 0 | |
| for d, dl in zip(corpus, context_tokens_length) | |
| ] | |
| sort_direct = -1 if condition_in_question == "none" else 1 | |
| ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1]) | |
| return ys | |
| method = None | |
| if rank_method == "bm25": | |
| method = get_distance_bm25 | |
| elif rank_method == "gzip": | |
| method = get_distance_gzip | |
| elif rank_method == "sentbert": | |
| method = get_distance_sentbert | |
| elif rank_method == "openai": | |
| method = get_distance_openai | |
| elif rank_method in ["longllmlingua", "llmlingua"]: | |
| method = get_distance_longllmlingua | |
| elif rank_method == "bge": | |
| method = get_distance_sentbert_bge | |
| elif rank_method == "bge_reranker": | |
| method = get_distance_bge_ranker | |
| elif rank_method == "bge_llmembedder": | |
| method = get_distance_bge_llmembedder | |
| elif rank_method == "jinza": | |
| method = get_distance_jinza | |
| elif rank_method == "voyageai": | |
| method = get_distance_voyageai | |
| elif rank_method == "cohere": | |
| method = get_distance_cohere | |
| return method(context, question) | |
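| # Usage sketch for get_rank_results (illustrative arguments; "bm25" additionally | |
| # requires the optional rank_bm25 package): | |
| #   order = compressor.get_rank_results( | |
| #       ["doc A ...", "doc B ..."], "What is discussed?", "bm25", "after", [0, 0] | |
| #   ) | |
| #   # -> e.g. [(1, 0), (0, 0)], context indices from most to least relevant | |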
| def segment_structured_context( | |
| self, | |
| context: List[str], | |
| global_rate: float, | |
| ): | |
| new_context, context_segs, context_segs_rate, context_segs_compress = ( | |
| [], | |
| [], | |
| [], | |
| [], | |
| ) | |
| for text in context: | |
| if not text.startswith("<llmlingua"): | |
| text = "<llmlingua>" + text | |
| if not text.endswith("</llmlingua>"): | |
| text = text + "</llmlingua>" | |
| # Regular expression to match <llmlingua, rate=x, compress=y>content</llmlingua>, allowing rate and compress in any order | |
| pattern = r"<llmlingua\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*>([^<]+)</llmlingua>" | |
| matches = re.findall(pattern, text) | |
| # Extracting segment contents | |
| segments = [match[4] for match in matches] | |
| # Extracting rate and compress, considering their possible positions | |
| segs_rate = [ | |
| float(match[0]) if match[0] else (float(match[2]) if match[2] else None) | |
| for match in matches | |
| ] | |
| segs_compress = [ | |
| ( | |
| match[1] == "True" | |
| if match[1] | |
| else (match[3] == "True" if match[3] else None) | |
| ) | |
| for match in matches | |
| ] | |
| segs_compress = [ | |
| compress if compress is not None else True for compress in segs_compress | |
| ] | |
| segs_rate = [ | |
| rate if rate else (global_rate if compress else 1.0) | |
| for rate, compress in zip(segs_rate, segs_compress) | |
| ] | |
| assert ( | |
| len(segments) == len(segs_rate) == len(segs_compress) | |
| ), "The number of segments, rates, and compress flags should be the same." | |
| assert all( | |
| seg_rate <= 1.0 for seg_rate in segs_rate | |
| ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]." | |
| new_context.append("".join(segments)) | |
| context_segs.append(segments) | |
| context_segs_rate.append(segs_rate) | |
| context_segs_compress.append(segs_compress) | |
| return new_context, context_segs, context_segs_rate, context_segs_compress | |
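| # Example of the structured markup parsed above (illustrative values): | |
| #   text = ('<llmlingua, compress=False>Keep this verbatim.</llmlingua>' | |
| #           '<llmlingua, rate=0.4>Background that may be compressed.</llmlingua>') | |
| #   ctx, segs, rates, flags = compressor.segment_structured_context([text], 0.5) | |
| #   # segs  -> [['Keep this verbatim.', 'Background that may be compressed.']] | |
| #   # rates -> [[1.0, 0.4]]   flags -> [[False, True]] | |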
| def concate_segment_info( | |
| self, | |
| segment_info: List[List[tuple]], | |
| ): | |
| new_segment_info = [] | |
| for i, (seg_len, seg_ratio, seg_compress) in enumerate(segment_info): | |
| if ( | |
| new_segment_info | |
| and new_segment_info[-1][1] == seg_ratio | |
| and new_segment_info[-1][2] == seg_compress | |
| ): | |
| new_segment_info[-1] = ( | |
| new_segment_info[-1][0] + seg_len, | |
| seg_ratio, | |
| seg_compress, | |
| ) | |
| else: | |
| new_segment_info.append((seg_len, seg_ratio, seg_compress)) | |
| return new_segment_info | |
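| # e.g. [(5, 0.5, True), (3, 0.5, True), (4, 1.0, False)] | |
| #   -> [(8, 0.5, True), (4, 1.0, False)] | |
| # Adjacent segments that share the same (rate, compress) pair are merged so later | |
| # passes see one contiguous segment per compression setting. | |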
| def __get_context_prob( | |
| self, | |
| context_list: list, | |
| token_to_word="mean", | |
| force_tokens: List[str] = [], | |
| token_map: dict = {}, | |
| force_reserve_digit: bool = False, | |
| ): | |
| chunk_list = [] | |
| for chunks in context_list: | |
| for c in chunks: | |
| chunk_list.append(c) | |
| dataset = TokenClfDataset( | |
| chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len | |
| ) | |
| dataloader = DataLoader( | |
| dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False | |
| ) | |
| chunk_probs = [] | |
| chunk_words = [] | |
| with torch.no_grad(): | |
| for batch in dataloader: | |
| ids = batch["ids"].to(self.device, dtype=torch.long) | |
| mask = batch["mask"].to(self.device, dtype=torch.long) == 1 | |
| outputs = self.model(input_ids=ids, attention_mask=mask) | |
| loss, logits = outputs.loss, outputs.logits | |
| probs = F.softmax(logits, dim=-1) | |
| for j in range(ids.shape[0]): | |
| _probs = probs[j, :, 1] | |
| _ids = ids[j] | |
| _mask = mask[j] | |
| active_probs = torch.masked_select(_probs, _mask) | |
| active_ids = torch.masked_select(_ids, _mask) | |
| tokens = self.tokenizer.convert_ids_to_tokens( | |
| active_ids.squeeze().tolist() | |
| ) | |
| token_probs = [prob for prob in active_probs.cpu().numpy()] | |
| ( | |
| words, | |
| valid_token_probs, | |
| valid_token_probs_no_force, | |
| ) = self.__merge_token_to_word( | |
| tokens, | |
| token_probs, | |
| force_tokens=force_tokens, | |
| token_map=token_map, | |
| force_reserve_digit=force_reserve_digit, | |
| ) | |
| word_probs_no_force = self.__token_prob_to_word_prob( | |
| valid_token_probs_no_force, convert_mode=token_to_word | |
| ) | |
| if "xlm-roberta-large" in self.model_name: | |
| for i in range(len(words)): | |
| words[i] = words[i].lstrip("▁") | |
| chunk_words.append(words) | |
| chunk_probs.append(word_probs_no_force) | |
| prev_idx = 0 | |
| context_probs = [] | |
| context_words = [] | |
| for chunk_list in context_list: | |
| n_chunk = len(chunk_list) | |
| context_probs.append([]) | |
| context_words.append([]) | |
| for i in range(n_chunk): | |
| context_probs[-1].extend(chunk_probs[prev_idx + i]) | |
| context_words[-1].extend(chunk_words[prev_idx + i]) | |
| prev_idx = prev_idx + n_chunk | |
| context_probs = [sum(probs) / len(probs) for probs in context_probs] | |
| return context_probs, context_words | |
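| # Note on __get_context_prob: every chunk of every context is scored by the token | |
| # classifier, token probabilities are merged into word probabilities (without the | |
| # force-token override), and each context's mean word probability is returned for | |
| # the context-level filter. | |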
| def __chunk_context(self, origin_text, chunk_end_tokens): | |
| origin_list = [] | |
| origin_tokens = self.tokenizer.tokenize(origin_text) | |
| n = len(origin_tokens) | |
| st = 0 | |
| while st < n: | |
| if st + self.max_seq_len > n - 1: | |
| chunk = self.tokenizer.convert_tokens_to_string(origin_tokens[st:n]) | |
| origin_list.append(chunk) | |
| break | |
| else: | |
| ed = st + self.max_seq_len | |
| for j in range(0, ed - st): | |
| if origin_tokens[ed - j] in chunk_end_tokens: | |
| ed = ed - j | |
| break | |
| chunk = self.tokenizer.convert_tokens_to_string( | |
| origin_tokens[st : ed + 1] | |
| ) | |
| origin_list.append(chunk) | |
| st = ed + 1 | |
| return origin_list | |
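| # Minimal usage sketch (chunk_end_tokens is a set of tokenizer tokens, assumed here): | |
| #   chunks = self.__chunk_context(long_text, chunk_end_tokens={".", "\n"}) | |
| # Each chunk holds roughly max_seq_len tokens and, when possible, ends on one of the | |
| # boundary tokens so that sentences are not split across chunks. | |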
| def __merge_token_to_word(self, tokens, token_probs, force_tokens, token_map, force_reserve_digit): | |
| words = [] | |
| word_probs = [] | |
| word_probs_no_force = [] | |
| for token, prob in zip(tokens, token_probs): | |
| if token in self.special_tokens: | |
| continue | |
| # add a new word | |
| elif is_begin_of_new_word(token, self.model_name, force_tokens, token_map): | |
| pure_token = get_pure_token(token, self.model_name) | |
| prob_no_force = prob | |
| if pure_token in force_tokens or pure_token in set(token_map.values()): | |
| prob = 1.0 | |
| token = replace_added_token(token, token_map) | |
| words.append(token) | |
| word_probs.append( | |
| [ | |
| 1.0 | |
| if force_reserve_digit | |
| and bool(re.search(r"\d", token)) | |
| else prob | |
| ] | |
| ) | |
| word_probs_no_force.append([prob_no_force]) | |
| # concatenate with previous token | |
| else: | |
| pure_token = get_pure_token(token, self.model_name) | |
| words[-1] += pure_token | |
| word_probs[-1].append( | |
| 1.0 | |
| if force_reserve_digit | |
| and bool(re.search(r"\d", token)) | |
| else prob | |
| ) | |
| word_probs_no_force[-1].append(prob)  # raw probability, without the force-token override | |
| return words, word_probs, word_probs_no_force | |
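| # Illustrative example (xlm-roberta-style subwords; values assumed): tokens | |
| # ["▁Micro", "soft", "▁2023"] with probs [0.2, 0.9, 0.4] merge into words | |
| # ["▁Microsoft", "▁2023"] with per-word prob lists [[0.2, 0.9], [0.4]]; with | |
| # force_reserve_digit=True the digit-bearing word's entries become 1.0, and begin | |
| # tokens listed in force_tokens are likewise pinned to probability 1.0. | |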
| def __token_prob_to_word_prob(self, token_probs, convert_mode="mean"): | |
| if convert_mode == "mean": | |
| word_probs = [sum(p) / len(p) for p in token_probs] | |
| elif convert_mode == "first": | |
| word_probs = [p[0] for p in token_probs] | |
| else: | |
| raise NotImplementedError() | |
| return word_probs | |
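| # e.g. [[0.2, 0.9], [0.4]] -> [0.55, 0.4] with convert_mode="mean", | |
| #      [[0.2, 0.9], [0.4]] -> [0.2, 0.4]  with convert_mode="first" | |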
| def __compress( | |
| self, | |
| context_list: list, | |
| reduce_rate: float = 0.5, | |
| token_to_word: str = "mean", | |
| force_tokens: List[str] = [], | |
| token_map: dict = {}, | |
| force_reserve_digit: bool = False, | |
| drop_consecutive: bool = False, | |
| ): | |
| def split_string_to_words(input_string): | |
| pattern = r'\b\w+\b|[<>=/!@#$%^&*()?":{}|\\`~;_+-]' | |
| result = re.findall(pattern, input_string) | |
| return result | |
| # print(force_tokens, token_map, force_reserve_digit, drop_consecutive) | |
| if reduce_rate <= 0: | |
| words, word_labels = [], [] | |
| for i in range(len(context_list)): | |
| chunk_list = context_list[i] | |
| chunk_words = [] | |
| chunk_word_labels = [] | |
| for j in range(len(chunk_list)): | |
| # map added tokens back to their original form | |
| for ori_token, new_token in token_map.items(): | |
| chunk_list[j] = chunk_list[j].replace(new_token, ori_token) | |
| ws = split_string_to_words(chunk_list[j]) | |
| chunk_words.extend(ws) | |
| chunk_word_labels.extend([1 for _ in range(len(ws))]) | |
| context_list[i] = "".join(chunk_list) | |
| words.append(chunk_words) | |
| word_labels.append(chunk_word_labels) | |
| return context_list, words, word_labels | |
| chunk_list = [] | |
| for chunks in context_list: | |
| for c in chunks: | |
| chunk_list.append(c) | |
| dataset = TokenClfDataset( | |
| chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len | |
| ) | |
| dataloader = DataLoader( | |
| dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False | |
| ) | |
| compressed_chunk_list = [] | |
| word_list = [] | |
| word_label_list = [] | |
| with torch.no_grad(): | |
| for batch in dataloader: | |
| ids = batch["ids"].to(self.device, dtype=torch.long) | |
| mask = batch["mask"].to(self.device, dtype=torch.long) == 1 | |
| outputs = self.model(input_ids=ids, attention_mask=mask) | |
| loss, logits = outputs.loss, outputs.logits | |
| probs = F.softmax(logits, dim=-1) | |
| for j in range(ids.shape[0]): | |
| chunk_probs = probs[j, :, 1] | |
| chunk_ids = ids[j] | |
| chunk_mask = mask[j] | |
| active_probs = torch.masked_select(chunk_probs, chunk_mask) | |
| active_ids = torch.masked_select(chunk_ids, chunk_mask) | |
| tokens = self.tokenizer.convert_ids_to_tokens( | |
| active_ids.squeeze().tolist() | |
| ) | |
| token_probs = [prob for prob in active_probs.cpu().numpy()] | |
| words, valid_token_probs, _ = self.__merge_token_to_word( | |
| tokens=tokens, | |
| token_probs=token_probs, | |
| force_tokens=force_tokens, | |
| token_map=token_map, | |
| force_reserve_digit=force_reserve_digit, | |
| ) | |
| word_probs = self.__token_prob_to_word_prob( | |
| valid_token_probs, convert_mode=token_to_word | |
| ) | |
| if drop_consecutive: | |
| threshold = np.percentile(word_probs, int(100 * reduce_rate)) | |
| is_token_between = False | |
| prev = None | |
| for i, (word, word_prob) in enumerate(zip(words, word_probs)): | |
| if word in force_tokens: | |
| if is_token_between: | |
| is_token_between = False | |
| elif not is_token_between and word == prev: | |
| word_probs[i] = 0.0 | |
| prev = word | |
| else: | |
| is_token_between |= word_prob > threshold | |
| # calculate compression ratio w.r.t. gpt-4 tokenizer | |
| new_token_probs = [] | |
| for word, word_prob in zip(words, word_probs): | |
| num_token = len(self.oai_tokenizer.encode(word)) | |
| new_token_probs.extend([word_prob for _ in range(num_token)]) | |
| threshold = np.percentile( | |
| new_token_probs, int(100 * reduce_rate + 1) | |
| ) | |
| keep_words = [] | |
| word_labels = [] | |
| assert len(words) == len(word_probs) | |
| for word, word_prob in zip(words, word_probs): | |
| if word_prob > threshold: | |
| if ( | |
| drop_consecutive | |
| and word in force_tokens | |
| and len(keep_words) > 0 | |
| and keep_words[-1] == word | |
| ): | |
| word_labels.append(0) | |
| else: | |
| keep_words.append(word) | |
| word_labels.append(1) | |
| else: | |
| word_labels.append(0) | |
| keep_str = self.tokenizer.convert_tokens_to_string(keep_words) | |
| if "xlm-roberta-large" in self.model_name: | |
| for i in range(len(words)): | |
| words[i] = words[i].lstrip("▁") | |
| compressed_chunk_list.append(keep_str) | |
| word_list.append(words[:]) | |
| word_label_list.append(word_labels[:]) | |
| compressed_context_list = [] | |
| original_word_list = [] | |
| original_word_label_list = [] | |
| prev_idx = 0 | |
| for chunk_list in context_list: | |
| n_chunk = len(chunk_list) | |
| compressed_context_list.append( | |
| "".join(compressed_chunk_list[prev_idx : prev_idx + n_chunk]) | |
| ) | |
| original_word_list.append([]) | |
| original_word_label_list.append([]) | |
| for i in range(n_chunk): | |
| original_word_list[-1].extend(word_list[prev_idx + i]) | |
| original_word_label_list[-1].extend(word_label_list[prev_idx + i]) | |
| prev_idx = prev_idx + n_chunk | |
| return compressed_context_list, original_word_list, original_word_label_list | |
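| # Note on the threshold above: each word's probability is repeated once per GPT-4 | |
| # token of that word (via self.oai_tokenizer), so np.percentile(..., int(100 * | |
| # reduce_rate + 1)) keeps roughly the top (1 - reduce_rate) share of GPT-4 tokens; | |
| # e.g. reduce_rate=0.6 keeps words strictly above the 61st percentile, i.e. about | |
| # the top 40% of tokens. | |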