import json
import re
from pathlib import Path
from typing import Dict, List, Optional, OrderedDict, Tuple, Union

from .custom_types import Conversation, Message, MessageId, MessageText


class TelegramDataExtractor:
    """Convert a Telegram chat export (JSON) into a question/answer text dataset."""

    @classmethod
    def load_messages_from_json(
        cls, path: str, max_message_count: Optional[int] = None
    ) -> OrderedDict[MessageId, Message]:
        """Load and normalize the messages stored in a chat export file.

        Args:
            path: Path to a JSON file that has a top-level ``"messages"`` list.
            max_message_count: Stop after this many raw entries of the
                ``"messages"`` list have been visited (skipped entries count
                too). ``None`` or ``0`` means no limit.

        Returns:
            An ordered mapping ``id -> message`` in file order. Each message is
            a dict with the keys ``"from"`` (normalized username), ``"id"``,
            ``"text"`` (tuple of tokens) and, when the source entry has
            ``"reply_to_message_id"``, ``"reply_to_id"``. Entries whose type is
            not ``"message"`` or whose normalized text is empty are skipped.
        """
        messages: OrderedDict[MessageId, Message] = OrderedDict()
        with open(path, "r", encoding="utf-8") as file:
            chat_json = json.load(file)

        for i, message in enumerate(chat_json["messages"]):
            if max_message_count and i == max_message_count:
                break
            if message["type"] != "message":
                continue

            new_message = {
                # The sender may be missing or null; normalize_username maps
                # that to an empty string instead of raising.
                "from": cls.normalize_username(message.get("from")),
                "id": message["id"],
                "text": cls.normalize(message["text"]),
            }
            if not new_message["text"]:  # Nothing left after normalization
                continue
            if "reply_to_message_id" in message:
                new_message["reply_to_id"] = message["reply_to_message_id"]
            messages[new_message["id"]] = new_message

        return messages

    @staticmethod
    def conversations_from_messages(
        save_to: Path, tokenizer, messages: OrderedDict[MessageId, Message]
    ) -> Path:
        """Build question/answer pairs and write them to ``train_dataset.txt``.

        Every message except the first is treated as an answer. Its question is
        the message it replies to (``"reply_to_id"``) when that message is
        present in ``messages``, otherwise the message right before it.

        Each output line has the form
        ``" <from> <question tokens> <eos> <answer tokens> <eos>"`` where
        ``<from>`` is the ``"from"`` field of the question message and
        ``<eos>`` is ``tokenizer.eos_token``.

        Args:
            save_to: Directory in which ``train_dataset.txt`` is written.
            tokenizer: Any object exposing an ``eos_token`` attribute.
            messages: Mapping as produced by ``load_messages_from_json``
                (``"text"`` values must be hashable, e.g. tuples).

        Returns:
            Path of the written dataset file.
        """
        _MAX_MESSAGE_LEN = 150  # Maximum number of tokens in a question or an answer
        _MAX_QA_LEN_DIFF = 20  # Maximum token-count difference between question and answer

        def remove_duplicates_keep_order(lst: List[Conversation]) -> List[Conversation]:
            # dict.fromkeys drops duplicates while keeping first-seen order;
            # the tuples were only needed for hashability, so turn them into lists.
            unique = list(dict.fromkeys(lst))
            return [(list(x[0]), list(x[1]), x[2]) for x in unique]

        def remove_answers_with_only_special_symbols(lst: List[Conversation]) -> List[Conversation]:
            # Keep a pair only if its answer has at least one character in the range а-я.
            return [i for i in lst if re.search(r"[а-я]", " ".join(i[1]))]

        def remove_long_qa(lst: List[Conversation]) -> List[Conversation]:
            return [
                i for i in lst
                if len(i[0]) <= _MAX_MESSAGE_LEN and len(i[1]) <= _MAX_MESSAGE_LEN
            ]

        def remove_unbalanced_qa(lst: List[Conversation]) -> List[Conversation]:
            return [i for i in lst if abs(len(i[0]) - len(i[1])) <= _MAX_QA_LEN_DIFF]

        def normalize_conversations(lst: List[Conversation]) -> List[Conversation]:
            lst = remove_duplicates_keep_order(lst)
            lst = remove_answers_with_only_special_symbols(lst)
            lst = remove_long_qa(lst)
            lst = remove_unbalanced_qa(lst)
            return lst

        conversations: List[Conversation] = []
        # Question text -> index of its pair in `conversations`
        questions: Dict[MessageText, int] = {}
        messages_values = list(messages.values())

        # Walk consecutive pairs; the last message has no answer after it.
        for previous, current in zip(messages_values, messages_values[1:]):
            # Prefer the explicitly replied-to message; fall back to the
            # preceding one when there is no reply or its target is unknown.
            prev_message = messages.get(current.get("reply_to_id"), previous)
            qa = (prev_message["text"], current["text"], prev_message["from"])

            known_index = questions.get(qa[0])
            if known_index is None:
                questions[qa[0]] = len(conversations)
                conversations.append(qa)
                continue

            # Several answers to the same question: keep the longer one, but
            # only when it is at most _MAX_QA_LEN_DIFF tokens longer.
            kept_answer = conversations[known_index][1]
            if 0 < len(qa[1]) - len(kept_answer) <= _MAX_QA_LEN_DIFF:
                conversations[known_index] = qa

        conversations = normalize_conversations(conversations)

        # Accept plain strings as well as Path objects.
        output_path = Path(save_to) / "train_dataset.txt"
        with open(output_path, "w", encoding="utf-8") as file:
            eos = tokenizer.eos_token
            file.writelines(
                " " + conversation[2] + " " + " ".join(conversation[0])
                + f" {eos} " + " ".join(conversation[1]) + f" {eos}" + "\n"
                for conversation in conversations
            )
        return output_path

    @staticmethod
    def normalize(text: Union[str, List]) -> Tuple[str, ...]:
        """Normalize a message text and split it into whitespace-separated tokens.

        Args:
            text: Either a string, or a list whose ``str`` items are joined
                with spaces (non-string items are ignored).

        Returns:
            Tuple of tokens; empty when nothing is left after normalization.
        """
        if isinstance(text, list):
            text = " ".join(word for word in text if isinstance(word, str))
        text = text.lower().strip()
        text = re.sub(r"[^а-яё.!?:\d]+", r" ", text)  # Leave only russian and special characters
        text = re.sub(r'\.(\s*\.)+', '... ', text)  # Replace any sequence of 2+ dots with '...'
        text = re.sub(r'([?!])(\s*\1)+', r'\1 ', text)  # Collapse repeating ? or !
        text = re.sub(r"([!?]|\.+)", r"\1 ", text)  # Separate special symbols by whitespaces
        text = re.sub(r"ё", r"е", text)
        # Laugh-like strings such as `ахах` are replaced by a space. The first
        # alternative (`.*…{6,}.*`) blanks the whole line that contains the run.
        # NOTE(review): the replacement is a plain space, so laughter is dropped;
        # confirm whether a dedicated laugh token was intended here.
        text = re.sub(
            r"(.*[ауспэиычвекьхъз]{6,}.*|\b[апх][аеписх]{2,3}\b|\b[ах]{2,}\b)",
            r" ",
            text,
        )
        # NOTE: do not add a `()(\s*\1)+` "collapse repeating tokens" step here:
        # with an empty capture group it matches at every position and would
        # split every word into single characters.
        text = re.sub(r"\s+", r" ", text).strip()  # Leave only one space between each word
        return tuple(text.split())

    @staticmethod
    def normalize_username(text: Optional[str]) -> str:
        """Lower-case a username and keep only ``а-я``, ``a-z`` and whitespace.

        Args:
            text: Raw username; ``None`` or an empty string is accepted.

        Returns:
            The cleaned username with surrounding whitespace stripped, or an
            empty string when ``text`` is ``None``/empty.
        """
        if not text:
            return ""
        return re.sub(r"[^а-яa-z\s]+", "", text.lower()).strip()