Spaces:
Running
Running
from pathlib import Path | |
from typing import Dict, OrderedDict, Tuple, Union | |
from .custom_types import MessageId, Message, Conversation, MessageText | |
from typing import List | |
import re | |
import json | |
class TelegramDataExtractor: | |
def load_messages_from_json(cls, path: str, max_message_count: int = None) -> OrderedDict[MessageId, Message]: | |
messages: OrderedDict[MessageId, Message] = OrderedDict() | |
with open(path, "r", encoding="utf-8") as file: | |
chat_json = json.load(file) | |
for i, message in enumerate(chat_json["messages"]): | |
if max_message_count and i == max_message_count: | |
break | |
if message["type"] != "message": | |
continue | |
new_message = { | |
"from": cls.normalize_username(message["from"]), | |
"id": message["id"], | |
"text": cls.normalize(message["text"]) | |
} | |
if not new_message["text"]: # Check for empty message | |
continue | |
if "reply_to_message_id" in message.keys(): | |
new_message["reply_to_id"] = message["reply_to_message_id"] | |
messages[new_message["id"]] = new_message | |
return messages | |
def conversations_from_messages(save_to: Path, tokenizer, messages: OrderedDict[MessageId, Message]) -> List[Conversation]: | |
_MAX_MESSAGE_LEN = 150 | |
_MAX_QA_LEN_DIFF = 20 | |
def remove_duplicates_keep_order(lst: List[Conversation]) -> List[Conversation]: | |
lst = list(dict.fromkeys(lst)) # Remove duplicates and keep order | |
return [(list(x[0]), list(x[1]), x[2]) for x in lst] # Tuples are only needed for hashability | |
def remove_answers_with_only_special_symbols(lst: List[Conversation]) -> List[Conversation]: | |
return [i for i in lst if re.findall(r"[а-я]", " ".join(i[1]))] | |
def remove_long_qa(lst: List[Conversation]) -> List[Conversation]: | |
return [i for i in lst if len(i[0]) <= _MAX_MESSAGE_LEN and len(i[1]) <= _MAX_MESSAGE_LEN] | |
def remove_unbalanced_qa(lst: List[Conversation]) -> List[Conversation]: | |
return [i for i in lst if abs(len(i[0]) - len(i[1])) <= _MAX_QA_LEN_DIFF] | |
def normalize_conversations(lst: List[Conversation]) -> List[Conversation]: | |
lst = remove_duplicates_keep_order(lst) | |
lst = remove_answers_with_only_special_symbols(lst) | |
lst = remove_long_qa(lst) | |
lst = remove_unbalanced_qa(lst) | |
return lst | |
conversations: List[Conversation] = [] | |
questions: Dict[MessageText, int] = dict() | |
messages_values = list(messages.values()) # TODO: try changing this cast to something more applicable | |
for i in range(len(messages) - 1): # There's no answer for last message so add -1 | |
try: # Message is answer for message with `id` of `reply_to_id` | |
prev_message = messages[messages_values[i+1]["reply_to_id"]] | |
except KeyError: | |
prev_message = messages_values[i] | |
qa = (prev_message["text"], messages_values[i+1]["text"], prev_message["from"]) | |
if qa[0] in questions.keys(): # If there are multiple answers for same message, choose the longest one | |
if len(conversations[questions[qa[0]]][1]) < len(qa[1]) and abs(len(conversations[questions[qa[0]]][1]) - len(qa[1])) <= _MAX_QA_LEN_DIFF: | |
conversations[questions[qa[0]]] = (qa[0], qa[1], qa[2]) | |
continue | |
else: | |
questions[qa[0]] = len(conversations) | |
conversations.append(qa) | |
conversations = normalize_conversations(conversations) | |
output_path = save_to / "train_dataset.txt" | |
with open(output_path, "w", encoding="utf-8") as file: | |
for conversation in conversations: | |
line = "<user> " + conversation[2] + " <says> " + " ".join(conversation[0]) + f" {tokenizer.eos_token} <response> " + " ".join(conversation[1]) + f" {tokenizer.eos_token}" + "\n" | |
file.write(line) | |
return output_path | |
def normalize(text: Union[str, List]) -> Tuple[str]: | |
if isinstance(text, List): | |
text = " ".join([word for word in text if isinstance(word, str)]) | |
text = text.lower().strip() | |
text = re.sub(r"[^а-яё.!?:\d]+", r" ", text) # Leave only russian and special characters | |
text = re.sub(r'\.(\s*\.)+', '... ', text) # Replace any sequence of 2+ dots with '...' | |
text = re.sub(r'([?!])(\s*\1)+', r'\1 ', text) # Collapse repeating ? or ! | |
text = re.sub(r"([!?]|\.+)", r"\1 ", text) # Separate special symbols by whitespaces | |
text = re.sub(r"ё", r"е", text) | |
text = re.sub(r"(.*[ауспэиычвекьхъз]{6,}.*|\b[апх][аеписх]{2,3}\b|\b[ах]{2,}\b)", r" <laugh> ", text) # Laugh token for strings such as `ахах` etc. | |
text = re.sub(r"(<laugh>)(\s*\1)+", r" <laugh> ", text) # Collapse repeating <laugh> tokens | |
text = re.sub(r"\s+", r" ", text).strip() # Leave only one space between each word | |
return tuple(text.split()) | |
def normalize_username(text: str) -> Tuple[str]: | |
text = text.lower() | |
text = re.sub(r"[^а-яa-z\s]+", "", text).strip() | |
return text |