import logging
from enum import Enum
import tiktoken
import re
from typing import Any, Optional, Union, Collection, AbstractSet, Literal, List
from langchain.text_splitter import TextSplitter
import random
import string
from itertools import chain
import json
from lpm_kernel.configs.logging import get_train_process_logger
logger = get_train_process_logger()
class IntentType(Enum):
Emotion = "Emotion"
Knowledge = "Knowledge"
def select_language_desc(
preferred_language,
default_desc="Identify the language of the provided Hint. Your response must be in the same language.",
):
custom_desc = "You must respond in {}."
if isinstance(preferred_language, str) and "/" in preferred_language:
        native, es = preferred_language.split("/", 1)
logging.info(f"Native: {native}, ES: {es}")
return custom_desc.format(es)
else:
        logging.warning(
            "preferred_language is not in the correct format. It should be 'native/es'. Falling back to the default description."
        )
return default_desc
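# Illustrative usage (a sketch; assumes the 'native/es' convention where the
# second field names the language the model should reply in):
#   >>> select_language_desc("Chinese/English")
#   'You must respond in English.'
#   >>> select_language_desc("English")  # no '/': falls back to default_desc
#   'Identify the language of the provided Hint. Your response must be in the same language.'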
def cal_upperbound(
model_limit: int = 4096,
generage_limit: int = 512,
tolerance: int = 500,
raw: str = "",
model_name: str = "gpt-3.5-turbo",
) -> int:
"""
:param model_limit: Maximum token count for the underlying model call
:param tolerance: Error tolerance buffer
:param raw: system prompt and raw content
:return:
"""
if model_name is not None:
if model_name in tiktoken.model.MODEL_TO_ENCODING:
enc = tiktoken.encoding_for_model(model_name)
logging.info(f"Successfully initialized tokenizer for model: {model_name}")
else:
enc = tiktoken.get_encoding("cl100k_base")
logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: cl100k_base")
    else:
        enc = tiktoken.get_encoding("cl100k_base")
        logging.info("No model specified, using default tokenizer: cl100k_base")
    raw_token = len(enc.encode(raw))
    upper_bound = model_limit - raw_token - tolerance - generage_limit
    if upper_bound < 0:
        logging.warning(f"raw content is too long: {raw_token} tokens")
        return 0
return upper_bound
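# Worked example (a sketch; assumes "hello" encodes to a single token under the
# gpt-3.5-turbo tokenizer): 4096 - 1 (raw) - 500 (tolerance) - 512 (generation) = 3083
#   >>> cal_upperbound(model_limit=4096, raw="hello")
#   3083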
def equidistant_filter(chunks, separator, filtered_chunks_n=6):
    # Sample chunks at equal intervals from the front, then append the last two chunks (joined by the separator) as one trailing element
gap = (len(chunks) - 2) / (filtered_chunks_n - 2)
indexes = [
int(gap * i)
for i in range(int(len(chunks) / gap) + 1)
if (gap * i < len(chunks) - 2)
]
filtered_chunks = [chunks[i] for i in indexes]
filtered_chunks.append(separator.join(chunks[-2:]))
return filtered_chunks
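# Illustrative usage (a sketch): with 10 chunks and filtered_chunks_n=6,
# gap = (10 - 2) / 4 = 2.0, so chunks 0, 2, 4, 6 are sampled and the last two
# chunks are joined by the separator and appended as one trailing element.
#   >>> equidistant_filter([str(i) for i in range(10)], "-")
#   ['0', '2', '4', '6', '8-9']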
def tab_or_space_replacement(match):
# If there is a tab character in the matched string, replace it with a single tab, otherwise replace it with a single space
return "\t" if "\t" in match.group() else " "
def text_filter(text: str) -> str:
pattern_tab_space = "[ \t]{3,}"
pattern_wordwrap = "[\n\f\r\v]{3,}"
# Replace when encountering three or more spaces or tabs
replaced_text = re.sub(pattern_tab_space, tab_or_space_replacement, text)
# When there are multiple consecutive \n (newline), \f (form feed), \r (carriage return), \v (vertical tab), replace them with 2 original newlines
replaced_text = re.sub(pattern_wordwrap, "\n\n", replaced_text)
return replaced_text
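# Illustrative usage (a sketch): runs of 3+ spaces/tabs collapse to one space
# (or one tab), and runs of 3+ newline-like characters collapse to two newlines.
#   >>> text_filter("a    b\n\n\n\nc")
#   'a b\n\nc'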
ALLOW_SPECIAL_TOKEN = {"<|endofprompt|>", "<|endoftext|>"}
def find_sublist_indices(main_list, sublist):
indices = []
length = len(sublist)
for i in range(len(main_list) - length + 1):
if main_list[i : i + length] == sublist:
indices.append((i, i + length))
return indices
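# Illustrative usage (a sketch): returns end-exclusive (start, end) pairs for
# every occurrence of the sublist.
#   >>> find_sublist_indices([1, 2, 3, 1, 2], [1, 2])
#   [(0, 2), (3, 5)]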
class TokenTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at tokens."""
def __init__(
self,
encoding_name: str = "cl100k_base",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = ALLOW_SPECIAL_TOKEN,
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # create a tiktoken encoder instance
if model_name is not None:
if model_name in tiktoken.model.MODEL_TO_ENCODING:
enc = tiktoken.encoding_for_model(model_name)
logging.info(f"Successfully initialized tokenizer for model: {model_name}")
else:
enc = tiktoken.get_encoding(encoding_name)
logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: {encoding_name}")
else:
enc = tiktoken.get_encoding(encoding_name)
logging.info(f"No model specified, using default tokenizer: {encoding_name}")
self._tokenizer = enc
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
# Filter content with a large number of whitespace characters in the input text to increase the proportion of effective content within chunks
text = text_filter(text)
splits = []
input_ids = self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
start_idx = 0
while start_idx < len(input_ids):
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
s = self._tokenizer.decode(chunk_ids).strip()
if s:
s = self._cut_meaningless_head_tail(s)
if s:
splits.append(s)
start_idx += self._chunk_size - self._chunk_overlap
logging.debug("finished split_text(): %s splits", len(splits))
return splits
def _cut_meaningless_head_tail(self, text: str) -> str:
# Only split when there are multiple newlines, as parsing of PDF/Word often contains false newlines
sentences = re.split(r"\. |! |\? |。|!|?|\n+ *\n+", text)
if len(sentences) < 2:
return text
head = sentences[0]
body = ". ".join(sentences[1:-1])
tail = sentences[-1]
        head_len = len(
            self._tokenizer.encode(
                head,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )
        )
body_len = len(
self._tokenizer.encode(
body,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
)
tail_len = len(
self._tokenizer.encode(
tail,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
)
        parts = []
        # Use length to roughly estimate whether the head/tail carries meaningful content,
        # and discard it if not; rough thresholds: about 20 tokens or 30 characters
if head_len >= 20 or len(head) >= 30:
parts.append(head)
if body_len > 0:
parts.append(body)
if tail_len >= 20 or len(tail) >= 30:
parts.append(tail)
res = "\n".join(parts)
logger.info(
"_cut_meaningless_tail() removes redundant sentence tails from chunks, before cut: %s characters, after cut: %s characters",
len(text),
len(res),
)
return res
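# Illustrative usage (a sketch; chunk_size and chunk_overlap are inherited from
# langchain's TextSplitter, and `long_document_text` is a placeholder):
#   splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=64)
#   chunks = splitter.split_text(long_document_text)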
def chunk_filter(
chunks, filter, filtered_chunks_n=6, separator="\n", spacer="\n……\n……\n……\n"
):
if len(chunks) <= filtered_chunks_n:
return separator.join(chunks)
return spacer.join(filter(chunks, separator, filtered_chunks_n))
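# Illustrative usage (a sketch): few chunks are joined as-is; many chunks are
# sampled by the filter and joined with the spacer to mark elided content.
#   >>> chunk_filter(["a", "b"], equidistant_filter)
#   'a\nb'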
def get_safe_content_turncate(content, model_name="gpt-3.5-turbo", max_tokens=3300):
    if model_name is not None:
        if model_name in tiktoken.model.MODEL_TO_ENCODING:
            enc = tiktoken.encoding_for_model(model_name)
            logging.info(f"Successfully initialized tokenizer for model: {model_name}")
        else:
            enc = tiktoken.get_encoding("cl100k_base")
            logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: cl100k_base")
    else:
        enc = tiktoken.get_encoding("cl100k_base")
        logging.info("No model specified, using default tokenizer: cl100k_base")
    # Encode once and reuse the token list
    tokens = enc.encode(content)
    logging.warning(
        "get_safe_content_turncate(): current model maximum input length is %s, current input length is %s",
        max_tokens,
        len(tokens),
    )
    if len(tokens) > max_tokens:
        content = enc.decode(tokens[:max_tokens])
    return content
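# Illustrative usage (a sketch; `very_long_prompt` is a placeholder and the
# exact cut point depends on the tokenizer):
#   safe = get_safe_content_turncate(very_long_prompt, max_tokens=3300)
#   # `safe` decodes to at most 3300 tokens of the original content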
class DataType(Enum):
DOCUMENT = "DOCUMENT"
WEBSITE = "WEBSITE"
IMAGE = "IMAGE"
TABLE = "TABLE"
AUDIO = "AUDIO"
TEXT = "TEXT"
@staticmethod
def extra_values_map():
return {
"SHORT_AUDIO": "AUDIO",
}
@classmethod
def _missing_(cls, value):
# Try to find the corresponding primary key value from the extra value mapping
extra_map = cls.extra_values_map()
if value in extra_map:
value = extra_map[value]
return cls.__members__.get(value)
# If not found, return DOCUMENT by default
logging.error("DataType._missing_(): Could not find corresponding DataType enum value: %s", value)
return cls.DOCUMENT
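# Illustrative usage (a sketch): unknown values are resolved through _missing_.
#   >>> DataType("SHORT_AUDIO") is DataType.AUDIO  # mapped via extra_values_map
#   True
#   >>> DataType("UNKNOWN") is DataType.DOCUMENT   # unmapped values fall back
#   True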
def get_urls(text):
    url_arr = []
    if not text:
        return url_arr
    pattern = re.compile(
        r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;\u4e00-\u9fa5]+[-A-Za-z0-9+&@#/%=~_|]"
    )
    for match in pattern.finditer(text):
        url_arr.append(match.group())
    sorted_url_arr = sorted(set(url_arr), key=len, reverse=True)
    return sorted_url_arr
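# Illustrative usage (a sketch): results are deduplicated and sorted longest
# first, so encode_urls replaces longer URLs before any URL that is their prefix.
#   >>> get_urls("see https://example.com/a and http://example.com")
#   ['https://example.com/a', 'http://example.com']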
def get_random_string(s_length: int) -> str:
    # Generate a random string of ASCII letters and digits
    letters = string.ascii_letters + string.digits
    return "".join(random.choice(letters) for _ in range(s_length))
def get_random_strings(n: int, s_length: int) -> List[str]:
unique_strings = set()
while len(unique_strings) < n:
unique_strings.add(get_random_string(s_length))
return list(unique_strings)
def encode_urls(text, random_string_len: int = 16):
urls = get_urls(text)
random_strings = get_random_strings(len(urls), random_string_len)
url2string_dict = dict(zip(urls, random_strings))
string2url_dict = dict(zip(random_strings, urls))
for url, random_string in url2string_dict.items():
text = text.replace(url, random_string)
return text, string2url_dict
def decode_urls(text, string2url_dict):
for random_string, url in string2url_dict.items():
text = text.replace(random_string, url)
return text
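# Illustrative round trip (a sketch; the placeholder strings are random, so only
# the round-trip property is stable):
#   encoded, mapping = encode_urls("read https://example.com/doc first")
#   assert decode_urls(encoded, mapping) == "read https://example.com/doc first"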
class TokenParagraphSplitter(TextSplitter):
"""For business data characteristics, perform some additional processing. This includes:
1. Complete fragments as independent chunks help improve information focus in each chunk. Complete fragments are mainly determined by period+newline.
2. When complete fragments are too long, split them into sentences and combine sentences into chunks that meet window size limits
3. If a sentence is too long, split it directly by token granularity
"""
line_break_characters = ["\n", "\f", "\r", "\v"]
whitespace_characters = [" ", "\t"]
sentence_terminators = [
".",
"!",
"?",
"。",
"!",
"?",
"……",
"...",
] + line_break_characters
paired_punctuation = [
("(", ")"),
("[", "]"),
("{", "}"),
("<", ">"),
("“", "”"),
("‘", "’"),
("《", "》"),
("【", "】"),
]
intra_sentence_delimiters = [",", ",", ";", ";"] + whitespace_characters
def __init__(
self,
encoding_name: str = "cl100k_base",
allowed_special: Union[Literal["all"], AbstractSet[str]] = ALLOW_SPECIAL_TOKEN,
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed for TokenParagraphSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # create a tiktoken encoder instance
self._tokenizer = tiktoken.get_encoding(encoding_name)
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
def split_text(self, text: str) -> List[str]:
chunks = []
# Clean up abnormal whitespace characters in the text, such as replacing 3 or more consecutive \n with \n\n
text = text_filter(text)
# Replace URLs in the text to avoid symbols like ./?/ in URLs interfering with sentence splitting
text, string2url_dict = encode_urls(text)
url_strings = list(string2url_dict.keys())
# Split by paragraphs according to rules
paragraphs = self._split_to_paragraphs(
text, min_paragraph_length=self._chunk_size // 2
)
for i, paragraph in enumerate(paragraphs):
splits = self._split_to_chunks(paragraph, url_strings)
logging.debug(
"paragraph %s/%s %s characters: %s",
i + 1,
len(paragraphs),
len(paragraph),
paragraph,
)
logging.debug(
"paragraph %s/%s split into %s chunks: %s",
i + 1,
len(paragraphs),
len(splits),
splits,
)
chunks.extend(splits)
chunks = [decode_urls(chunk, string2url_dict) for chunk in chunks]
return chunks
def _split_to_chunks(self, text: str, url_strings: List[str] = []) -> List[str]:
sentences = self._split_to_sentences(text, url_strings)
chunks = self._merge_sentences_into_chunks(
sentences, min_chunk_size=self._chunk_size // 2
)
return chunks
def _split_to_paragraphs(
self, text: str, min_paragraph_length: int = 0
) -> List[str]:
"""Currently split the original document into paragraphs directly based on the \n[any space]\n rule."""
line_break_characters = "".join(self.line_break_characters)
whitespace_characters = "".join(self.whitespace_characters)
paragraphs = re.split(
f"([{line_break_characters}]+[{whitespace_characters}]*[{line_break_characters}])+",
text,
)
if len(paragraphs) % 2 == 1:
paragraphs = [""] + paragraphs
paragraphs = [
(paragraphs[i], paragraphs[i + 1])
for i in range(0, len(paragraphs), 2)
if (paragraphs[i] + paragraphs[i + 1]).strip()
]
if not paragraphs:
return []
new_paragraphs = []
cur_paragraph, cur_paragraph_len = "", 0
# merge short or broken paragraphs
for sep, paragraph in paragraphs:
if cur_paragraph_len >= min_paragraph_length and any(
cur_paragraph.endswith(sym) for sym in self.sentence_terminators
):
new_paragraphs.append(cur_paragraph.strip())
cur_paragraph, cur_paragraph_len = "", 0
cur_paragraph_len += len(self._tokenizer.encode(sep + paragraph))
cur_paragraph += sep + paragraph
if cur_paragraph:
new_paragraphs.append(cur_paragraph.strip())
return new_paragraphs
def _split_to_sentences(self, text: str, url_strings: List[str] = []) -> List[str]:
# Use capture groups to preserve sentence separators
pattern = (
f"({'|'.join(re.escape(symbol) for symbol in self.sentence_terminators)})+"
)
        parts = re.split(pattern, text)
        # Merge parts in pairs (stride 2) so each terminator is re-attached to the end of its sentence
        if len(parts) % 2 == 1:
            parts.append("")
        sentences = ["".join(parts[i : i + 2]) for i in range(0, len(parts), 2)]
sentences = [s for s in sentences if s.strip()]
if not sentences:
return []
# Fix fragmented sentences, mainly for special cases such as numeric indices, floating-point numbers, etc., which may be separated
sentences = self.recombine_broken_sentences(sentences)
# Split sentences that are too long; in the short term, split directly by character length; future optimizations could consider splitting by punctuation within sentences
sentences_list = [
self._force_split_to_chunks(s, url_strings) for s in sentences
]
sentences = list(chain.from_iterable(sentences_list))
return sentences
def recombine_broken_sentences(self, sentences: List[str]) -> List[str]:
"""Fix fragmented sentences, mainly for special cases such as numeric indices, floating-point numbers, etc., which may be separated。"""
if len(sentences) < 2:
return sentences
open_symbols_dict = {
open_sym: close_sym for open_sym, close_sym in self.paired_punctuation
}
close_symbols_dict = {
close_sym: open_sym for open_sym, close_sym in self.paired_punctuation
}
new_sentences = []
cur_sentences = ""
unmatched_symbol = []
for sent in sentences:
# If the current sentence is not empty, doesn't meet predefined merge conditions, and has no pending matching punctuation ([, (, {, etc.), then consider the sentence complete
if cur_sentences.strip() and not (
self.check_merge(cur_sentences, sent) or unmatched_symbol
):
new_sentences.append(cur_sentences)
cur_sentences = ""
for c in sent:
if c in open_symbols_dict:
unmatched_symbol.append(c)
elif c in close_symbols_dict:
if (
unmatched_symbol
and unmatched_symbol[-1] == close_symbols_dict[c]
):
unmatched_symbol.pop()
# By default, the current sentence ends when a newline-like character appears
if c in self.line_break_characters:
unmatched_symbol = []
if cur_sentences.strip():
new_sentences.append(cur_sentences)
cur_sentences = ""
cur_sentences += c
if cur_sentences:
new_sentences.append(cur_sentences)
return new_sentences
def check_merge(self, pre_sen, cur_sen):
if len(pre_sen) > 1 and len(cur_sen) > 0:
# If it's a decimal point in the middle of a floating-point number
if pre_sen[-1] == "." and pre_sen[-2].isdigit() and cur_sen[0].isdigit():
return True
# If it's a numeric index at the beginning of a sentence, such as 1. *****\n2. *****
if (
pre_sen[-1] == "."
and pre_sen[-2].isdigit()
and cur_sen[0] not in self.line_break_characters
):
return True
# In markdown format, ! followed by [ may be an image link
if (
pre_sen[-1] == "!"
and pre_sen[-2] in self.line_break_characters
and cur_sen[0] == "["
):
return True
return False
def _merge_sentences_into_chunks(
self, sentences: List[str], min_chunk_size: int = 200
) -> List[str]:
"""Assemble into chunks according to chunk_size and overlap. Note that external guarantees ensure that the length of a single sentence does not exceed chunk_size"""
if not sentences:
return []
n_tokens = [
len(
self._tokenizer.encode(
sentence,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
)
for sentence in sentences
]
chunks = []
start_idx = 0
end_idx = start_idx + 1
cur_token_num = n_tokens[start_idx]
while start_idx < len(n_tokens):
            # The tail has reached the end: emit the final chunk
if end_idx >= len(n_tokens):
chunk = "".join(sentences[start_idx:end_idx])
logging.debug(
"sentences[%s:%s] merged into chunk, current num_tokens: %s(%s)",
start_idx,
end_idx,
sum(n_tokens[start_idx:end_idx]),
cur_token_num,
)
chunks.append(chunk)
break
else:
                # Adding the next sentence stays within chunk_size: keep extending the chunk
if cur_token_num + n_tokens[end_idx] <= self._chunk_size:
cur_token_num += n_tokens[end_idx]
end_idx += 1
                # Adding the next sentence would exceed chunk_size: assemble the current chunk and move on to the next one
else:
chunk = "".join(sentences[start_idx:end_idx])
logging.debug(
"sentences[%s:%s] merged into chunk, current num_tokens: %s(%s)",
start_idx,
end_idx,
sum(n_tokens[start_idx:end_idx]),
cur_token_num,
)
chunks.append(chunk)
                    # Next chunk: end_idx moves at least one position forward; start_idx may move back to create overlap
end_idx = end_idx + 1
# Find a new starting point for start_idx that doesn't exceed the overlap
new_start_idx = end_idx - 1
overlap = 0
new_cur_token_num = n_tokens[new_start_idx]
while new_start_idx > start_idx + 1:
if (
overlap + n_tokens[new_start_idx - 1] >= self._chunk_overlap
or new_cur_token_num >= self._chunk_size
):
break
new_start_idx -= 1
overlap += n_tokens[new_start_idx]
new_cur_token_num += n_tokens[new_start_idx]
start_idx = new_start_idx
cur_token_num = new_cur_token_num
if len(chunks) > 1 and len(chunks[-1]) < min_chunk_size:
logging.warning(
"The last chunk length %s is less than %s, merge with the previous chunk",
len(chunks[-1]),
min_chunk_size,
)
last_chunk = chunks.pop()
chunks[-1] += last_chunk
chunks = [chunk for chunk in chunks if chunk.strip()]
return chunks
    def _force_split_to_chunks(
        self, text: str, url_strings: List[str] = []
    ) -> List[str]:
        """Force-split an over-long sentence at punctuation within the sentence, trying to keep URLs and other data that must stay intact in one piece."""
        # TODO: Refine the forced splitting logic further in the future
splits = []
input_ids = self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
if len(input_ids) < self._chunk_size:
return [text]
if text[-1] not in self.sentence_terminators + self.intra_sentence_delimiters:
text += self.sentence_terminators[0]
cur_sentence, cur_sentence_len = "", 0
sub_sentence = ""
for c in text:
sub_sentence += c
if c in self.intra_sentence_delimiters + self.sentence_terminators:
sub_sentence_len = len(self._tokenizer.encode(sub_sentence))
if (
cur_sentence_len + sub_sentence_len
> self._chunk_size - self._chunk_overlap
):
if cur_sentence:
splits.append(cur_sentence)
cur_sentence, cur_sentence_len = sub_sentence, sub_sentence_len
                    else:
                        # sub_sentence itself is too long; fall back to forced token-level splitting
                        _splits = self.safe_split(sub_sentence, url_strings)
                        splits.extend(_splits[:-1])
                        cur_sentence = _splits[-1]
                        cur_sentence_len = len(self._tokenizer.encode(cur_sentence))
else:
cur_sentence += sub_sentence
cur_sentence_len += sub_sentence_len
sub_sentence = ""
if cur_sentence:
splits.append(cur_sentence)
return splits
def safe_split(self, sub_sentence: str, url_strings: List[str] = []) -> List[str]:
sub_sentence_tokens = self._tokenizer.encode(sub_sentence)
# Find the position intervals of all strings in url_strings
url_string_intervals = []
for url_string in url_strings:
encoded_url_string = self._tokenizer.encode(url_string)
# Use find_sublist_indices to find all position intervals
url_string_intervals.extend(
find_sublist_indices(sub_sentence_tokens, encoded_url_string)
)
_splits = []
i = 0
while i < len(sub_sentence_tokens):
if i + self._chunk_size >= len(sub_sentence_tokens):
slice_end = len(sub_sentence_tokens)
else:
slice_end = i + self._chunk_size - self._chunk_overlap
# Determine if the split interval overlaps with any important string intervals
for s_begin, s_end in url_string_intervals:
if i < s_end <= slice_end or i < s_begin < slice_end:
slice_end = max(slice_end, s_end)
# Split and record the current chunk
_splits.append(self._tokenizer.decode(sub_sentence_tokens[i:slice_end]))
# Move to the starting point of the next chunk
i = slice_end
return _splits
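# Illustrative usage (a sketch; `long_document_text` is a placeholder and the
# constructor arguments are inherited from langchain's TextSplitter):
#   splitter = TokenParagraphSplitter(chunk_size=512, chunk_overlap=64)
#   chunks = splitter.split_text(long_document_text)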
def get_summarize_title_keywords(responses):
# Clean LLM generated content to obtain summarized text titles, abstracts, and keywords
pattern = re.compile(r"\{.*(\}|\]|\,)", re.DOTALL)
gen_texts = [each.choices[0].message.content for each in responses]
logging.info("gen_texts: %s", gen_texts)
results = []
for res in gen_texts:
try:
# Match against the pattern
matches = list(pattern.finditer(res))
if not matches:
results.append(("", "", []))
else:
answer = matches[0].group(0)
content = answer.strip().strip(",")
content += "]" * (content.count("[") - content.count("]"))
content += "}" * (content.count("{") - content.count("}"))
                    d = json.loads(content)
                    results.append(
                        (d.get("title", ""), d.get("summary", ""), d.get("keywords", []))
                    )
        except json.JSONDecodeError:
            logging.warning("JSON parsing failed, appending an empty result")
            results.append(("", "", []))
return results
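# Illustrative usage (a sketch; the stand-in objects below only mimic the shape
# of an OpenAI chat-completion response, i.e. response.choices[0].message.content):
#   from types import SimpleNamespace
#   msg = SimpleNamespace(content='{"title": "T", "summary": "S", "keywords": ["k"]}')
#   resp = SimpleNamespace(choices=[SimpleNamespace(message=msg)])
#   get_summarize_title_keywords([resp])  # -> [('T', 'S', ['k'])]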