import logging
import json
import random
import re
import string
from enum import Enum
from itertools import chain
from typing import Any, Optional, Union, Collection, AbstractSet, Literal, List

import tiktoken
from langchain.text_splitter import TextSplitter

from lpm_kernel.configs.logging import get_train_process_logger

logger = get_train_process_logger()


class IntentType(Enum):
    Emotion = "Emotion"
    Knowledge = "Knowledge"


def select_language_desc(
    preferred_language,
    default_desc="Identify the language of the provided Hint. Your response must be in the same language.",
):
    custom_desc = "You must respond in {}."
    if isinstance(preferred_language, str) and "/" in preferred_language:
        native, es = preferred_language.split("/")
        logging.info(f"Native: {native}, ES: {es}")
        return custom_desc.format(es)
    else:
        logging.info(
            "Error: preferred_language is not in the correct format. It should be 'native/es'."
        )
        return default_desc
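
# Example usage (a minimal sketch; the "native/es" value below is a made-up sample):
#   select_language_desc("Chinese/English")  # -> "You must respond in English."
#   select_language_desc("English")          # -> default_desc, since no "/" is present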


def cal_upperbound(
    model_limit: int = 4096,
    generage_limit: int = 512,
    tolerance: int = 500,
    raw: str = "",
    model_name: str = "gpt-3.5-turbo",
) -> int:
    """
    :param model_limit: Maximum token count for the underlying model call
    :param generage_limit: Token budget reserved for the model's generated output
    :param tolerance: Error tolerance buffer
    :param raw: System prompt and raw content
    :param model_name: Model name used to select the tiktoken tokenizer
    :return: Remaining token budget for additional content, floored at 0
    """
    if model_name is not None:
        if model_name in tiktoken.model.MODEL_TO_ENCODING:
            enc = tiktoken.encoding_for_model(model_name)
            logging.info(f"Successfully initialized tokenizer for model: {model_name}")
        else:
            enc = tiktoken.get_encoding("cl100k_base")
            logging.warning(
                f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: cl100k_base"
            )
    else:
        enc = tiktoken.get_encoding("cl100k_base")
        logging.info("No model specified, using default tokenizer: cl100k_base")
    raw_token = len(enc.encode(raw))
    upper_bound = model_limit - raw_token - tolerance - generage_limit
    if upper_bound < 0:
        logging.info(f"raw content is too long: {raw_token}")
        return 0
    return upper_bound
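
# Example (a rough sketch; the exact result depends on the tokenizer's token count for `raw`):
#   upper = cal_upperbound(model_limit=4096, generage_limit=512, tolerance=500,
#                          raw="system prompt plus raw content", model_name="gpt-3.5-turbo")
#   # upper == 4096 - len(enc.encode(raw)) - 500 - 512, floored at 0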


def equidistant_filter(chunks, separator, filtered_chunks_n=6):
    # Keep the first chunk and the last two chunks, and sample the remaining chunks evenly from the middle
    gap = (len(chunks) - 2) / (filtered_chunks_n - 2)
    indexes = [
        int(gap * i)
        for i in range(int(len(chunks) / gap) + 1)
        if (gap * i < len(chunks) - 2)
    ]
    filtered_chunks = [chunks[i] for i in indexes]
    filtered_chunks.append(separator.join(chunks[-2:]))
    return filtered_chunks
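
# Example (illustrative; chunks are short strings for readability):
#   equidistant_filter([str(i) for i in range(10)], separator="\n", filtered_chunks_n=6)
#   # -> evenly spaced chunks from the front/middle, plus the last two chunks joined with "\n"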


def tab_or_space_replacement(match):
    # If there is a tab character in the matched string, replace it with a single tab, otherwise with a single space
    return "\t" if "\t" in match.group() else " "


def text_filter(text: str) -> str:
    pattern_tab_space = "[ \t]{3,}"
    pattern_wordwrap = "[\n\f\r\v]{3,}"
    # Collapse runs of three or more spaces/tabs into a single space or tab
    replaced_text = re.sub(pattern_tab_space, tab_or_space_replacement, text)
    # Collapse runs of three or more \n (newline), \f (form feed), \r (carriage return), \v (vertical tab) into two newlines
    replaced_text = re.sub(pattern_wordwrap, "\n\n", replaced_text)
    return replaced_text
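
# Example (illustrative):
#   text_filter("a    b\n\n\n\nc")  # -> "a b\n\nc"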

ALLOW_SPECIAL_TOKEN = {"<|endofprompt|>", "<|endoftext|>"}


def find_sublist_indices(main_list, sublist):
    indices = []
    length = len(sublist)
    for i in range(len(main_list) - length + 1):
        if main_list[i : i + length] == sublist:
            indices.append((i, i + length))
    return indices
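
# Example (illustrative):
#   find_sublist_indices([1, 2, 3, 2, 3], [2, 3])  # -> [(1, 3), (3, 5)]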


class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at tokens."""

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = ALLOW_SPECIAL_TOKEN,
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "It is required by TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # Create a tiktoken encoder instance
        if model_name is not None:
            if model_name in tiktoken.model.MODEL_TO_ENCODING:
                enc = tiktoken.encoding_for_model(model_name)
                logging.info(f"Successfully initialized tokenizer for model: {model_name}")
            else:
                enc = tiktoken.get_encoding(encoding_name)
                logging.warning(
                    f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: {encoding_name}"
                )
        else:
            enc = tiktoken.get_encoding(encoding_name)
            logging.info(f"No model specified, using default tokenizer: {encoding_name}")
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # Filter content with a large number of whitespace characters in the input text to increase the proportion of effective content within chunks
        text = text_filter(text)
        splits = []
        input_ids = self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )
        start_idx = 0
        while start_idx < len(input_ids):
            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
            chunk_ids = input_ids[start_idx:cur_idx]
            s = self._tokenizer.decode(chunk_ids).strip()
            if s:
                s = self._cut_meaningless_head_tail(s)
                if s:
                    splits.append(s)
            start_idx += self._chunk_size - self._chunk_overlap
        logging.debug("finished split_text(): %s splits", len(splits))
        return splits

    def _cut_meaningless_head_tail(self, text: str) -> str:
        # Only split on multiple newlines, as parsing of PDF/Word often introduces spurious newlines
        sentences = re.split(r"\. |! |\? |。|!|?|\n+ *\n+", text)
        if len(sentences) < 2:
            return text
        head = sentences[0]
        body = ". ".join(sentences[1:-1])
        tail = sentences[-1]
        head_len = len(
            self._tokenizer.encode(
                head,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )
        )
        body_len = len(
            self._tokenizer.encode(
                body,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )
        )
        tail_len = len(
            self._tokenizer.encode(
                tail,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )
        )
        parts = []
        # Use length to roughly estimate the impact of discarding the head/tail; if the impact is small, discard it
        # Rough estimate: about 20 tokens ≈ 8 Chinese characters; about 10 tokens ≈ 30 English characters
        if head_len >= 20 or len(head) >= 30:
            parts.append(head)
        if body_len > 0:
            parts.append(body)
        if tail_len >= 20 or len(tail) >= 30:
            parts.append(tail)
        res = "\n".join(parts)
        logger.info(
            "_cut_meaningless_head_tail() removes redundant sentence heads/tails from chunks, before cut: %s characters, after cut: %s characters",
            len(text),
            len(res),
        )
        return res
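
# Example usage of TokenTextSplitter (a minimal sketch; chunk_size/chunk_overlap values are arbitrary
# and long_document_text is a placeholder):
#   splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=64)
#   chunks = splitter.split_text(long_document_text)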


def chunk_filter(
    chunks, filter, filtered_chunks_n=6, separator="\n", spacer="\n……\n……\n……\n"
):
    if len(chunks) <= filtered_chunks_n:
        return separator.join(chunks)
    return spacer.join(filter(chunks, separator, filtered_chunks_n))
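
# Example (illustrative; uses equidistant_filter defined above):
#   chunk_filter(["c%d" % i for i in range(20)], equidistant_filter, filtered_chunks_n=6)
#   # joins the sampled chunks with the "\n……\n……\n……\n" spacer instead of returning all 20 chunks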


def get_safe_content_turncate(content, model_name="gpt-3.5-turbo", max_tokens=3300):
    if model_name is not None:
        if model_name in tiktoken.model.MODEL_TO_ENCODING:
            enc = tiktoken.encoding_for_model(model_name)
            logging.info(f"Successfully initialized tokenizer for model: {model_name}")
        else:
            enc = tiktoken.get_encoding("cl100k_base")
            logging.warning(
                f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: cl100k_base"
            )
    else:
        enc = tiktoken.get_encoding("cl100k_base")
        logging.info("No model specified, using default tokenizer: cl100k_base")
    logging.warning(
        "get_safe_content_turncate(): current model maximum input length is %s, current input length is %s",
        max_tokens,
        len(enc.encode(content)),
    )
    if len(enc.encode(content)) > max_tokens:
        content = enc.decode(enc.encode(content)[:max_tokens])
    return content
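
# Example (a sketch; the exact cut point depends on the tokenizer):
#   safe = get_safe_content_turncate(some_long_prompt, max_tokens=3300)
#   # returns the content truncated to at most 3300 tokens; shorter content is returned unchanged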


class DataType(Enum):
    DOCUMENT = "DOCUMENT"
    WEBSITE = "WEBSITE"
    IMAGE = "IMAGE"
    TABLE = "TABLE"
    AUDIO = "AUDIO"
    TEXT = "TEXT"

    @staticmethod
    def extra_values_map():
        return {
            "SHORT_AUDIO": "AUDIO",
        }

    @classmethod
    def _missing_(cls, value):
        # Try to map the value to a primary member name via the extra value mapping
        extra_map = cls.extra_values_map()
        if value in extra_map:
            value = extra_map[value]
            return cls.__members__.get(value)
        # If not found, return DOCUMENT by default
        logging.error("DataType._missing_(): Could not find corresponding DataType enum value: %s", value)
        return cls.DOCUMENT
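
# Example (illustrative):
#   DataType("SHORT_AUDIO")   # resolved via _missing_ -> DataType.AUDIO
#   DataType("UNKNOWN_TYPE")  # logged as an error, falls back to DataType.DOCUMENT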


def get_urls(string):
    url_arr = []
    if not string:
        return url_arr
    pattern = re.compile(
        r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;\u4e00-\u9fa5]+[-A-Za-z0-9+&@#/%=~_|]"
    )
    matcher = pattern.finditer(string)
    for match in matcher:
        url_arr.append(match.group())
    sorted_url_arr = sorted(set(url_arr), key=len, reverse=True)
    return sorted_url_arr
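
# Example (illustrative; the URLs are placeholders):
#   get_urls("see https://example.com/a and https://example.com/a/b")
#   # -> ["https://example.com/a/b", "https://example.com/a"]  (deduplicated, longest first)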


def get_random_string(s_length: int) -> str:
    # Generate a random alphanumeric string
    letters = string.ascii_letters + string.digits
    return "".join(random.choice(letters) for _ in range(s_length))


def get_random_strings(n: int, s_length: int) -> List[str]:
    unique_strings = set()
    while len(unique_strings) < n:
        unique_strings.add(get_random_string(s_length))
    return list(unique_strings)


def encode_urls(text, random_string_len: int = 16):
    urls = get_urls(text)
    random_strings = get_random_strings(len(urls), random_string_len)
    url2string_dict = dict(zip(urls, random_strings))
    string2url_dict = dict(zip(random_strings, urls))
    for url, random_string in url2string_dict.items():
        text = text.replace(url, random_string)
    return text, string2url_dict


def decode_urls(text, string2url_dict):
    for random_string, url in string2url_dict.items():
        text = text.replace(random_string, url)
    return text
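
# Example round trip (illustrative; the URL is a placeholder):
#   masked, mapping = encode_urls("docs at https://example.com/x?a=1 for details")
#   # the URL is replaced by a random 16-character token so '.' and '?' inside it do not
#   # confuse sentence splitting; decode_urls(masked, mapping) restores the original text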


class TokenParagraphSplitter(TextSplitter):
    """Additional processing tailored to the characteristics of business data:
    1. Keeping complete fragments as independent chunks helps each chunk stay focused; complete fragments are mainly delimited by a period plus a newline.
    2. When a complete fragment is too long, split it into sentences and combine sentences into chunks that fit the window size limit.
    3. If a single sentence is still too long, split it directly at token granularity.
    """

    line_break_characters = ["\n", "\f", "\r", "\v"]
    whitespace_characters = [" ", "\t"]
    sentence_terminators = [
        ".",
        "!",
        "?",
        "。",
        "!",
        "?",
        "……",
        "...",
    ] + line_break_characters
    paired_punctuation = [
        ("(", ")"),
        ("[", "]"),
        ("{", "}"),
        ("<", ">"),
        ("“", "”"),
        ("‘", "’"),
        ("《", "》"),
        ("【", "】"),
    ]
    intra_sentence_delimiters = [",", ",", ";", ";"] + whitespace_characters

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        allowed_special: Union[Literal["all"], AbstractSet[str]] = ALLOW_SPECIAL_TOKEN,
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "It is required by TokenParagraphSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # Create a tiktoken encoder instance
        self._tokenizer = tiktoken.get_encoding(encoding_name)
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        chunks = []
        # Clean up abnormal whitespace characters in the text, such as replacing 3 or more consecutive \n with \n\n
        text = text_filter(text)
        # Replace URLs in the text to avoid symbols like . / ? in URLs interfering with sentence splitting
        text, string2url_dict = encode_urls(text)
        url_strings = list(string2url_dict.keys())
        # Split by paragraphs according to rules
        paragraphs = self._split_to_paragraphs(
            text, min_paragraph_length=self._chunk_size // 2
        )
        for i, paragraph in enumerate(paragraphs):
            splits = self._split_to_chunks(paragraph, url_strings)
            logging.debug(
                "paragraph %s/%s %s characters: %s",
                i + 1,
                len(paragraphs),
                len(paragraph),
                paragraph,
            )
            logging.debug(
                "paragraph %s/%s split into %s chunks: %s",
                i + 1,
                len(paragraphs),
                len(splits),
                splits,
            )
            chunks.extend(splits)
        chunks = [decode_urls(chunk, string2url_dict) for chunk in chunks]
        return chunks

    def _split_to_chunks(self, text: str, url_strings: List[str] = []) -> List[str]:
        sentences = self._split_to_sentences(text, url_strings)
        chunks = self._merge_sentences_into_chunks(
            sentences, min_chunk_size=self._chunk_size // 2
        )
        return chunks

    def _split_to_paragraphs(
        self, text: str, min_paragraph_length: int = 0
    ) -> List[str]:
        r"""Currently splits the original document into paragraphs directly based on the \n[any space]\n rule."""
        line_break_characters = "".join(self.line_break_characters)
        whitespace_characters = "".join(self.whitespace_characters)
        paragraphs = re.split(
            f"([{line_break_characters}]+[{whitespace_characters}]*[{line_break_characters}])+",
            text,
        )
        if len(paragraphs) % 2 == 1:
            paragraphs = [""] + paragraphs
        paragraphs = [
            (paragraphs[i], paragraphs[i + 1])
            for i in range(0, len(paragraphs), 2)
            if (paragraphs[i] + paragraphs[i + 1]).strip()
        ]
        if not paragraphs:
            return []
        new_paragraphs = []
        cur_paragraph, cur_paragraph_len = "", 0
        # Merge short or broken paragraphs
        for sep, paragraph in paragraphs:
            if cur_paragraph_len >= min_paragraph_length and any(
                cur_paragraph.endswith(sym) for sym in self.sentence_terminators
            ):
                new_paragraphs.append(cur_paragraph.strip())
                cur_paragraph, cur_paragraph_len = "", 0
            cur_paragraph_len += len(self._tokenizer.encode(sep + paragraph))
            cur_paragraph += sep + paragraph
        if cur_paragraph:
            new_paragraphs.append(cur_paragraph.strip())
        return new_paragraphs

    def _split_to_sentences(self, text: str, url_strings: List[str] = []) -> List[str]:
        # Use capture groups to preserve sentence separators
        pattern = (
            f"({'|'.join(re.escape(symbol) for symbol in self.sentence_terminators)})+"
        )
        parts = re.split(pattern, text)
        # Merge in steps of two so that punctuation is appended to the end of the corresponding sentence
        if len(parts) % 2 == 1:
            parts.append("")
        sentences = ["".join(parts[i : i + 2]) for i in range(0, len(parts), 2)]
        sentences = [s for s in sentences if s.strip()]
        if not sentences:
            return []
        # Fix fragmented sentences, mainly for special cases such as numeric indices and floating-point numbers that may get separated
        sentences = self.recombine_broken_sentences(sentences)
        # Split sentences that are too long; for now split directly by character length; a future optimization could split by punctuation within sentences
        sentences_list = [
            self._force_split_to_chunks(s, url_strings) for s in sentences
        ]
        sentences = list(chain.from_iterable(sentences_list))
        return sentences

    def recombine_broken_sentences(self, sentences: List[str]) -> List[str]:
        """Fix fragmented sentences, mainly for special cases such as numeric indices and floating-point numbers that may get separated."""
        if len(sentences) < 2:
            return sentences
        open_symbols_dict = {
            open_sym: close_sym for open_sym, close_sym in self.paired_punctuation
        }
        close_symbols_dict = {
            close_sym: open_sym for open_sym, close_sym in self.paired_punctuation
        }
        new_sentences = []
        cur_sentences = ""
        unmatched_symbol = []
        for sent in sentences:
            # If the current sentence is not empty, doesn't meet the predefined merge conditions, and has no pending
            # matching punctuation ([, (, {, etc.), then consider the sentence complete
            if cur_sentences.strip() and not (
                self.check_merge(cur_sentences, sent) or unmatched_symbol
            ):
                new_sentences.append(cur_sentences)
                cur_sentences = ""
            for c in sent:
                if c in open_symbols_dict:
                    unmatched_symbol.append(c)
                elif c in close_symbols_dict:
                    if (
                        unmatched_symbol
                        and unmatched_symbol[-1] == close_symbols_dict[c]
                    ):
                        unmatched_symbol.pop()
                # By default, the current sentence ends when a newline-like character appears
                if c in self.line_break_characters:
                    unmatched_symbol = []
                    if cur_sentences.strip():
                        new_sentences.append(cur_sentences)
                        cur_sentences = ""
                cur_sentences += c
        if cur_sentences:
            new_sentences.append(cur_sentences)
        return new_sentences

    def check_merge(self, pre_sen, cur_sen):
        if len(pre_sen) > 1 and len(cur_sen) > 0:
            # If it's a decimal point in the middle of a floating-point number
            if pre_sen[-1] == "." and pre_sen[-2].isdigit() and cur_sen[0].isdigit():
                return True
            # If it's a numeric index at the beginning of a sentence, such as 1. *****\n2. *****
            if (
                pre_sen[-1] == "."
                and pre_sen[-2].isdigit()
                and cur_sen[0] not in self.line_break_characters
            ):
                return True
            # In markdown format, ! followed by [ may be an image link
            if (
                pre_sen[-1] == "!"
                and pre_sen[-2] in self.line_break_characters
                and cur_sen[0] == "["
            ):
                return True
        return False

    def _merge_sentences_into_chunks(
        self, sentences: List[str], min_chunk_size: int = 200
    ) -> List[str]:
        """Assemble sentences into chunks according to chunk_size and overlap. Note that the caller guarantees that a single sentence does not exceed chunk_size."""
        if not sentences:
            return []
        n_tokens = [
            len(
                self._tokenizer.encode(
                    sentence,
                    allowed_special=self._allowed_special,
                    disallowed_special=self._disallowed_special,
                )
            )
            for sentence in sentences
        ]
        chunks = []
        start_idx = 0
        end_idx = start_idx + 1
        cur_token_num = n_tokens[start_idx]
        while start_idx < len(n_tokens):
            # Tail reaches the end point
            if end_idx >= len(n_tokens):
                chunk = "".join(sentences[start_idx:end_idx])
                logging.debug(
                    "sentences[%s:%s] merged into chunk, current num_tokens: %s(%s)",
                    start_idx,
                    end_idx,
                    sum(n_tokens[start_idx:end_idx]),
                    cur_token_num,
                )
                chunks.append(chunk)
                break
            else:
                # Adding the next sentence will not exceed chunk_size, continue to include new sentences
                if cur_token_num + n_tokens[end_idx] <= self._chunk_size:
                    cur_token_num += n_tokens[end_idx]
                    end_idx += 1
                # Adding the next sentence would exceed chunk_size, assemble the current chunk and move to the next chunk
                else:
                    chunk = "".join(sentences[start_idx:end_idx])
                    logging.debug(
                        "sentences[%s:%s] merged into chunk, current num_tokens: %s(%s)",
                        start_idx,
                        end_idx,
                        sum(n_tokens[start_idx:end_idx]),
                        cur_token_num,
                    )
                    chunks.append(chunk)
                    # Next chunk: end_idx moves at least one position forward, start_idx allows overlap
                    end_idx = end_idx + 1
                    # Find a new starting point for start_idx that doesn't exceed the overlap
                    new_start_idx = end_idx - 1
                    overlap = 0
                    new_cur_token_num = n_tokens[new_start_idx]
                    while new_start_idx > start_idx + 1:
                        if (
                            overlap + n_tokens[new_start_idx - 1] >= self._chunk_overlap
                            or new_cur_token_num >= self._chunk_size
                        ):
                            break
                        new_start_idx -= 1
                        overlap += n_tokens[new_start_idx]
                        new_cur_token_num += n_tokens[new_start_idx]
                    start_idx = new_start_idx
                    cur_token_num = new_cur_token_num
        if len(chunks) > 1 and len(chunks[-1]) < min_chunk_size:
            logging.warning(
                "The last chunk length %s is less than %s, merging it into the previous chunk",
                len(chunks[-1]),
                min_chunk_size,
            )
            last_chunk = chunks.pop()
            chunks[-1] += last_chunk
        chunks = [chunk for chunk in chunks if chunk.strip()]
        return chunks

    def _force_split_to_chunks(
        self, text: str, url_strings: List[str] = []
    ) -> List[str]:
        """If a single sentence is too long it has to be split forcibly: split at punctuation within the
        sentence, trying to keep links and other data that need complete information in one piece."""
        splits = []
        input_ids = self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )
        if len(input_ids) < self._chunk_size:
            return [text]
        if text[-1] not in self.sentence_terminators + self.intra_sentence_delimiters:
            text += self.sentence_terminators[0]
        cur_sentence, cur_sentence_len = "", 0
        sub_sentence = ""
        for c in text:
            sub_sentence += c
            if c in self.intra_sentence_delimiters + self.sentence_terminators:
                sub_sentence_len = len(self._tokenizer.encode(sub_sentence))
                if (
                    cur_sentence_len + sub_sentence_len
                    > self._chunk_size - self._chunk_overlap
                ):
                    if cur_sentence:
                        splits.append(cur_sentence)
                        cur_sentence, cur_sentence_len = sub_sentence, sub_sentence_len
                    else:
                        # sub_sentence itself is too long; fall back to forced token-level splitting
                        _splits = self.safe_split(sub_sentence, url_strings)
                        splits.extend(_splits[:-1])
                        cur_sentence = _splits[-1]
                        cur_sentence_len = len(self._tokenizer.encode(cur_sentence))
                else:
                    cur_sentence += sub_sentence
                    cur_sentence_len += sub_sentence_len
                sub_sentence = ""
        if cur_sentence:
            splits.append(cur_sentence)
        return splits

    def safe_split(self, sub_sentence: str, url_strings: List[str] = []) -> List[str]:
        sub_sentence_tokens = self._tokenizer.encode(sub_sentence)
        # Find the position intervals of all strings in url_strings
        url_string_intervals = []
        for url_string in url_strings:
            encoded_url_string = self._tokenizer.encode(url_string)
            # Use find_sublist_indices to find all position intervals
            url_string_intervals.extend(
                find_sublist_indices(sub_sentence_tokens, encoded_url_string)
            )
        _splits = []
        i = 0
        while i < len(sub_sentence_tokens):
            if i + self._chunk_size >= len(sub_sentence_tokens):
                slice_end = len(sub_sentence_tokens)
            else:
                slice_end = i + self._chunk_size - self._chunk_overlap
            # Determine if the split interval overlaps with any important string intervals
            for s_begin, s_end in url_string_intervals:
                if i < s_end <= slice_end or i < s_begin < slice_end:
                    slice_end = max(slice_end, s_end)
            # Split and record the current chunk
            _splits.append(self._tokenizer.decode(sub_sentence_tokens[i:slice_end]))
            # Move to the starting point of the next chunk
            i = slice_end
        return _splits
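
# Example usage of TokenParagraphSplitter (a minimal sketch; parameter values are arbitrary and
# raw_document_text is a placeholder):
#   paragraph_splitter = TokenParagraphSplitter(chunk_size=512, chunk_overlap=64)
#   chunks = paragraph_splitter.split_text(raw_document_text)
#   # paragraphs are kept intact where possible; over-long sentences fall back to token-level splits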


def get_summarize_title_keywords(responses):
    # Clean up the LLM-generated content to obtain the summarized title, abstract, and keywords
    pattern = re.compile(r"\{.*(\}|\]|,)", re.DOTALL)
    gen_texts = [each.choices[0].message.content for each in responses]
    logging.info("gen_texts: %s", gen_texts)
    results = []
    for res in gen_texts:
        try:
            # Match against the pattern
            matches = list(pattern.finditer(res))
            if not matches:
                results.append(("", "", []))
            else:
                answer = matches[0].group(0)
                content = answer.strip().strip(",")
                # Balance any unclosed brackets before parsing
                content += "]" * (content.count("[") - content.count("]"))
                content += "}" * (content.count("{") - content.count("}"))
                d = json.loads(content)
                results.append(
                    (d.get("title", ""), d.get("summary", ""), d.get("keywords", []))
                )
        except json.JSONDecodeError:
            logging.warning("JSON parsing failed, returning an empty result for this response")
            results.append(("", "", []))
    return results
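
# Example (a sketch; `responses` are assumed to be OpenAI-style chat completion objects whose
# message content contains a JSON object like {"title": ..., "summary": ..., "keywords": [...]}):
#   results = get_summarize_title_keywords(responses)
#   # -> [("Some Title", "One-line summary", ["kw1", "kw2"]), ...]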