WonderWaffle4 committed on
Commit
59c6d5c
·
1 Parent(s): 3328eb8

GPT-2 bot with transformers and bot server

Browse files
.gitignore CHANGED
@@ -206,5 +206,5 @@ marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
208
 
209
- models/seq2seq/data/train/*
210
- models/seq2seq/checkpoint/*
 
206
  marimo/_lsp/
207
  __marimo__/
208
 
209
+ models/*/data/*
210
+ models/*/checkpoint/*
main.py → bot_local.py RENAMED
@@ -1,90 +1,85 @@
1
- import time
2
- from typing import Final
3
- import requests
4
- import re
5
- from telegram import Update
6
- from telegram.ext import Application, MessageHandler, filters, ContextTypes
7
- from typing import Optional
8
- import random
9
- import os
10
- from dotenv import load_dotenv
11
- from models.seq2seq.model import Seq2SeqChatbot
12
- import torch
13
-
14
- load_dotenv()
15
-
16
- TOKEN: Final = os.environ.get("TOKEN")
17
- BOT_USERNAME: Final = os.environ.get("BOT_USERNAME")
18
- CHAT_ID: Final = int(os.environ.get("CHAT_ID"))
19
-
20
- CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
21
-
22
- romantiki_gif_id = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
23
- bezumtsi_gif_id = "CgACAgIAAxkBAAE4zMtojLmiH_CGW5cT7G0QVXHR7D4g6wAC53UAApkBmEmM-VxqunRc6zYE"
24
-
25
- last_gif_sent = 1.0
26
- gif_sent_cooldown = 180.0
27
-
28
- torch.manual_seed(0)
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
-
31
- chatbot = Seq2SeqChatbot(500, 8856, 2, 2, 0.1, device)
32
- chatbot.load_checkpoint(CHECKPOINT_PATH)
33
- chatbot.eval_mode()
34
-
35
-
36
- def handle_response(text: str) -> Optional[str]:
37
- response_chance = 0.02
38
- if random.random() < response_chance:
39
- return chatbot(text)
40
- return None
41
-
42
-
43
- def edit_response(text: Optional[str]) -> Optional[str]:
44
- if text is None:
45
- return None
46
- text = re.sub(r'\s+([,.!?;])\s+', r'\1 ', text)
47
- return text
48
-
49
-
50
- async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
51
- if update.message.chat_id == CHAT_ID:
52
- # response: Optional[str] = ""
53
- global last_gif_sent
54
- if "роман" in update.message.text.lower() and \
55
- time.time() - last_gif_sent >= gif_sent_cooldown:
56
- await context.bot.send_animation( chat_id=update.message.chat_id, animation=romantiki_gif_id)
57
- last_gif_sent = time.time()
58
- elif "безу" in update.message.text.lower() and \
59
- time.time() - last_gif_sent >= gif_sent_cooldown:
60
- await context.bot.send_animation(chat_id=update.message.chat_id, animation=bezumtsi_gif_id)
61
- last_gif_sent = time.time()
62
- else:
63
- text = update.message.text.replace(BOT_USERNAME, '').strip().lower()
64
- response = edit_response(handle_response(text))
65
- if response:
66
- await context.bot.sendMessage(update.message.chat_id, response, reply_to_message_id=update.message.id)
67
-
68
-
69
- async def error(update: Update, context: ContextTypes.DEFAULT_TYPE):
70
- print(f"{update.message.from_user.username} in {update.message.chat.type} "
71
- f"chat caused error \"{context.error}\"\n"
72
- f"{update}\"")
73
-
74
-
75
- def main() -> None:
76
- """Run the bot."""
77
- requests.post(f"https://api.telegram.org/bot{TOKEN}/getUpdates?offset=-1")
78
-
79
- application = Application.builder().token(TOKEN).build()
80
-
81
- application.add_handler(MessageHandler(filters.TEXT, handle_message))
82
- application.add_error_handler(error)
83
-
84
- application.run_polling(allowed_updates=Update.ALL_TYPES)
85
-
86
-
87
- if __name__ == '__main__':
88
- print("Running main...")
89
- # print(chatbot("test"))
90
- main()
 
1
+ import time
2
+ from typing import Final
3
+ import re
4
+ from telegram import Update
5
+ from telegram.ext import Application, MessageHandler, filters, ContextTypes
6
+ from typing import Optional
7
+ import random
8
+ import os
9
+ from dotenv import load_dotenv
10
+ from models.seq2seq.model import Seq2SeqChatbot
11
+ import torch
12
+
13
+ load_dotenv()
14
+
15
+ TOKEN: Final = os.environ.get("TOKEN")
16
+ BOT_USERNAME: Final = os.environ.get("BOT_USERNAME")
17
+ CHAT_ID: Final = int(os.environ.get("CHAT_ID"))
18
+
19
+ CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
20
+
21
# Telegram file ids of the animations sent by the keyword triggers in handle_message().
ROMANTIKI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
# Restored to the id used before the constants were renamed; it had been
# overwritten with a copy of ROMANTIKI_GIF_ID, so both triggers sent the same GIF.
BEZUMTSI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMtojLmiH_CGW5cT7G0QVXHR7D4g6wAC53UAApkBmEmM-VxqunRc6zYE"
23
+
24
+ last_gif_sent = 1.0
25
+ gif_sent_cooldown = 180.0
26
+ response_chance = 1.0
27
+
28
+ torch.manual_seed(0)
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ chatbot = Seq2SeqChatbot(500, 8856, 2, 2, 0.1, device)
32
+ chatbot.load_checkpoint(CHECKPOINT_PATH)
33
+ chatbot.eval_mode()
34
+
35
+
36
def handle_response(text: str) -> Optional[str]:
    """Return the seq2seq model's reply to ``text``, or ``None`` when the bot stays silent.

    A reply is produced with probability ``response_chance`` (module-level setting).
    """
    should_reply = random.random() < response_chance
    return chatbot(text) if should_reply else None
40
+
41
+
42
def edit_response(text: Optional[str]) -> Optional[str]:
    """Attach detached punctuation to the preceding word.

    The seq2seq model emits space-separated tokens (``"привет , как дела ?"``).
    Whitespace before ``, . ! ? ;`` is removed and any whitespace after the mark
    is collapsed to one space. Punctuation at the very end of the text is
    handled too (the previous pattern required trailing whitespace, so the
    final mark of a reply was never fixed).

    Args:
        text: Raw model output, or ``None`` when no reply was generated.

    Returns:
        The cleaned text, or ``None`` if ``text`` is ``None``.
    """
    if text is None:
        return None

    def _attach(match):
        # Keep exactly one separating space, except when the mark ends the text.
        return match.group(1) + (" " if match.group(2) else "")

    return re.sub(r'\s+([,.!?;])(\s+|$)', _attach, text)
47
+
48
+
49
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """React to a text message posted in the configured chat.

    If the GIF cooldown has elapsed and the text contains one of the trigger
    substrings, the matching animation is sent. Otherwise the text (with the
    bot's username removed) is passed to the chatbot and a generated reply,
    if any, is sent as a reply to the original message.

    Args:
        update: Incoming Telegram update.
        context: Handler context; ``context.bot`` is used to send responses.
    """
    global last_gif_sent

    message = update.message
    # filters.TEXT is checked against the *effective* message, so edited
    # messages and channel posts arrive here with ``update.message`` unset.
    if message is None or message.text is None or message.chat_id != CHAT_ID:
        return

    lowered = message.text.lower()
    gif_id = None
    # Keyword triggers only fire outside the cooldown window; during the
    # cooldown the message falls through to the chatbot, as before.
    if time.time() - last_gif_sent >= gif_sent_cooldown:
        if "роман" in lowered:
            gif_id = ROMANTIKI_GIF_ID
        elif "безу" in lowered:
            gif_id = BEZUMTSI_GIF_ID

    if gif_id is not None:
        await context.bot.send_animation(chat_id=message.chat_id, animation=gif_id)
        last_gif_sent = time.time()
        return

    # BOT_USERNAME comes from the environment and may be unset (None).
    text = message.text.replace(BOT_USERNAME or '', '').strip().lower()
    response = edit_response(handle_response(text))
    if response:
        await context.bot.send_message(message.chat_id, response, reply_to_message_id=message.message_id)
65
+
66
+
67
async def error(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Print a diagnostic line for an exception raised while handling an update.

    The error handler may be invoked with ``update`` set to ``None`` or with
    an update that carries no message or no sender, so every attribute is
    looked up defensively; otherwise the handler itself would raise and hide
    the original error.

    Args:
        update: The update being processed when the error occurred (may be None).
        context: Handler context; ``context.error`` holds the exception.
    """
    message = getattr(update, "effective_message", None)
    username = getattr(getattr(message, "from_user", None), "username", None)
    chat_type = getattr(getattr(message, "chat", None), "type", None)
    print(f"{username} in {chat_type} "
          f"chat caused error \"{context.error}\"\n"
          f"{update}")
71
+
72
+
73
def main() -> None:
    """Build the Telegram application, register the handlers and start polling."""
    builder = Application.builder()
    application = builder.token(TOKEN).build()

    text_handler = MessageHandler(filters.TEXT, handle_message)
    application.add_handler(text_handler)
    application.add_error_handler(error)

    # Updates that queued up while the bot was offline are discarded on start.
    application.run_polling(drop_pending_updates=True, allowed_updates=Update.ALL_TYPES)


if __name__ == '__main__':
    print("Running main...")
    main()
 
 
 
 
 
bot_server.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Final
3
+ import re
4
+ from telegram import Update
5
+ from telegram.ext import Application, MessageHandler, filters, ContextTypes
6
+ from typing import Optional
7
+ import random
8
+ import os
9
+ import requests
10
+ from dotenv import load_dotenv
11
+ import requests
12
+
13
+ load_dotenv()
14
+
15
+ TOKEN: Final = os.environ.get("TOKEN")
16
+ BOT_USERNAME: Final = os.environ.get("BOT_USERNAME")
17
+ CHAT_ID: Final = int(os.environ.get("CHAT_ID"))
18
+
19
+ CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
20
+
21
# Telegram file ids of the animations sent by the keyword triggers in handle_message().
ROMANTIKI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
# Restored to the id used before the constants were renamed; it had been
# overwritten with a copy of ROMANTIKI_GIF_ID, so both triggers sent the same GIF.
BEZUMTSI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMtojLmiH_CGW5cT7G0QVXHR7D4g6wAC53UAApkBmEmM-VxqunRc6zYE"
23
+
24
+ last_gif_sent = 1.0
25
+ gif_sent_cooldown = 180.0
26
+ response_chance = 1.0
27
+
28
def handle_response(author: str, content: str) -> Optional[str]:
    """Ask the local generation server for a reply to a chat message.

    A reply is requested with probability ``response_chance``; otherwise
    ``None`` is returned and the bot stays silent.

    Args:
        author: Display name of the message sender.
        content: Message text.

    Returns:
        The ``"response"`` field of the server's JSON answer, or ``None``.

    Raises:
        requests.RequestException: On connection failure, timeout or a
            non-2xx status from the server.
    """
    if random.random() >= response_chance:
        return None
    reply = requests.post(
        "http://localhost:8000/generate",
        json={"author": author, "content": content + " "},
        # (connect, read) seconds: without a timeout a stalled server would
        # block the bot indefinitely.
        timeout=(5, 120),
    )
    # Surface HTTP errors explicitly instead of failing later on a missing key.
    reply.raise_for_status()
    return reply.json()["response"]
32
+
33
+
34
def edit_response(text: Optional[str]) -> Optional[str]:
    """Post-processing hook for generated replies; currently returns ``text`` unchanged.

    The punctuation re-spacing used by the seq2seq bot is disabled here:
    # text = re.sub(r'\\s+([,.!?;])\\s+', r'\\1 ', text)
    """
    return None if text is None else text
39
+
40
+
41
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """React to a text message posted in the configured chat.

    If the GIF cooldown has elapsed and the text contains one of the trigger
    substrings, the matching animation is sent. Otherwise the sender's name
    and the text (with the bot's username removed) are sent to the generation
    server and the returned reply, if any, is posted as a reply to the
    original message.

    Args:
        update: Incoming Telegram update.
        context: Handler context; ``context.bot`` is used to send responses.
    """
    import asyncio  # local import so this handler does not depend on the file's import block

    global last_gif_sent

    message = update.message
    # filters.TEXT is checked against the *effective* message, so edited
    # messages and channel posts arrive here with ``update.message`` unset.
    if message is None or message.text is None or message.chat_id != CHAT_ID:
        return

    lowered = message.text.lower()
    gif_id = None
    # Keyword triggers only fire outside the cooldown window; during the
    # cooldown the message falls through to the text reply, as before.
    if time.time() - last_gif_sent >= gif_sent_cooldown:
        if "роман" in lowered:
            gif_id = ROMANTIKI_GIF_ID
        elif "безу" in lowered:
            gif_id = BEZUMTSI_GIF_ID

    if gif_id is not None:
        await context.bot.send_animation(chat_id=message.chat_id, animation=gif_id)
        last_gif_sent = time.time()
        return

    # The sender can be missing (e.g. channel posts) and last_name is optional.
    sender = message.from_user
    name_parts = (sender.first_name, sender.last_name) if sender else ()
    author = " ".join(part for part in name_parts if part)
    # BOT_USERNAME comes from the environment and may be unset (None).
    content = message.text.replace(BOT_USERNAME or '', '').strip().lower()

    # handle_response() performs a blocking HTTP request; run it in a worker
    # thread so the event loop keeps processing other updates meanwhile.
    response = await asyncio.to_thread(handle_response, author, content)
    print(response)
    if response:
        await context.bot.send_message(message.chat_id, response, reply_to_message_id=message.message_id)
67
+
68
+
69
async def error(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Print a diagnostic line for an exception raised while handling an update.

    The error handler may be invoked with ``update`` set to ``None`` or with
    an update that carries no message or no sender, so every attribute is
    looked up defensively; otherwise the handler itself would raise and hide
    the original error.

    Args:
        update: The update being processed when the error occurred (may be None).
        context: Handler context; ``context.error`` holds the exception.
    """
    message = getattr(update, "effective_message", None)
    username = getattr(getattr(message, "from_user", None), "username", None)
    chat_type = getattr(getattr(message, "chat", None), "type", None)
    print(f"{username} in {chat_type} "
          f"chat caused error \"{context.error}\"\n"
          f"{update}")
73
+
74
+
75
def main() -> None:
    """Build the Telegram application, register the handlers and start polling."""
    builder = Application.builder()
    application = builder.token(TOKEN).build()

    text_handler = MessageHandler(filters.TEXT, handle_message)
    application.add_handler(text_handler)
    application.add_error_handler(error)

    # Updates that queued up while the bot was offline are discarded on start.
    application.run_polling(drop_pending_updates=True, allowed_updates=Update.ALL_TYPES)


if __name__ == '__main__':
    print("Running main...")
    main()
models/seq2seq/attention.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn.functional as F
3
  import torch.nn as nn
4
  from .custom_types import Method
5
 
 
6
  class LuongAttention(nn.Module):
7
  def __init__(self, method: Method, hidden_size: int):
8
  super().__init__()
 
3
  import torch.nn as nn
4
  from .custom_types import Method
5
 
6
+
7
  class LuongAttention(nn.Module):
8
  def __init__(self, method: Method, hidden_size: int):
9
  super().__init__()
models/seq2seq/chat_dataset.py CHANGED
@@ -5,10 +5,11 @@ from collections import OrderedDict
5
  from .vocab import Vocab
6
  from .custom_types import Message, MessageId, Conversation
7
  from torch.nn.utils.rnn import pad_sequence
8
- from .custom_types import Token
9
  import re
10
  import json
11
 
 
12
  class ChatDataset(data.Dataset):
13
  def __init__(self, path: str, max_message_count: int = None, batch_size=5):
14
  super().__init__()
@@ -32,7 +33,7 @@ class ChatDataset(data.Dataset):
32
  batches_X.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][0] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
33
  batches_y.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][1] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
34
  lengths.append(torch.tensor([len(conversations[i+j][0]) for j in range(self.batch_size) if i+j < len(conversations)]))
35
- mask.append(batches_y[-1] != Token.PAD_TOKEN.value)
36
  return batches_X, batches_y, lengths, mask
37
 
38
  @classmethod
 
5
  from .vocab import Vocab
6
  from .custom_types import Message, MessageId, Conversation
7
  from torch.nn.utils.rnn import pad_sequence
8
+ from .constants import PAD_TOKEN
9
  import re
10
  import json
11
 
12
+
13
  class ChatDataset(data.Dataset):
14
  def __init__(self, path: str, max_message_count: int = None, batch_size=5):
15
  super().__init__()
 
33
  batches_X.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][0] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
34
  batches_y.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][1] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
35
  lengths.append(torch.tensor([len(conversations[i+j][0]) for j in range(self.batch_size) if i+j < len(conversations)]))
36
+ mask.append(batches_y[-1] != PAD_TOKEN)
37
  return batches_X, batches_y, lengths, mask
38
 
39
  @classmethod
models/seq2seq/constants.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# Reserved vocabulary indices for the seq2seq special tokens.
# Vocab maps them to "<pad>", "<bos>", "<eos>" and "<unk>" respectively.
PAD_TOKEN = 0  # padding index; ChatDataset pads batches with 0 and masks on this value
BOS_TOKEN = 1  # first decoder input during training and greedy search
EOS_TOKEN = 2  # index of the "<eos>" marker appended to every sentence
UNK_TOKEN = 3  # substituted for words missing from the vocabulary
models/seq2seq/custom_types.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import List, TypedDict, NotRequired, Tuple
2
  from enum import Enum, auto
3
 
 
4
  MessageId = int
5
  MessageText = List[str]
6
  Conversation = Tuple[MessageText]
@@ -14,10 +15,4 @@ class Message(TypedDict):
14
  class Method(Enum):
15
  DOT = auto()
16
  GENERAL = auto()
17
- CONCAT = auto()
18
-
19
- class Token(Enum):
20
- PAD_TOKEN = 0
21
- BOS_TOKEN = 1
22
- EOS_TOKEN = 2
23
- UNK_TOKEN = 3
 
1
  from typing import List, TypedDict, NotRequired, Tuple
2
  from enum import Enum, auto
3
 
4
+
5
  MessageId = int
6
  MessageText = List[str]
7
  Conversation = Tuple[MessageText]
 
15
  class Method(Enum):
16
  DOT = auto()
17
  GENERAL = auto()
18
+ CONCAT = auto()
 
 
 
 
 
 
models/seq2seq/model.py CHANGED
@@ -5,7 +5,8 @@ import torch.nn as nn
5
  import torch.optim as optim
6
  from .chat_dataset import ChatDataset
7
  from .attention import LuongAttention
8
- from .custom_types import Method, Token
 
9
  from .vocab import Vocab
10
  from .searchers import GreedySearch
11
  import os
@@ -108,7 +109,7 @@ class Seq2SeqChatbot(nn.Module):
108
  encoder_outputs, hidden = self.encoder(x_train, x_lengths) # Output shape: (batch_size, max_len_in_batch, hidden_size)
109
  hidden = hidden[:self.decoder_num_layers]
110
  loss = 0
111
- decoder_input = torch.LongTensor([[Token.BOS_TOKEN.value] for _ in range(y_train.shape[0])])
112
  decoder_input = decoder_input.to(device)
113
  use_teacher_forcing = random.random() < teacher_forcing_ratio
114
  if use_teacher_forcing:
 
5
  import torch.optim as optim
6
  from .chat_dataset import ChatDataset
7
  from .attention import LuongAttention
8
+ from .custom_types import Method
9
+ from .constants import BOS_TOKEN
10
  from .vocab import Vocab
11
  from .searchers import GreedySearch
12
  import os
 
109
  encoder_outputs, hidden = self.encoder(x_train, x_lengths) # Output shape: (batch_size, max_len_in_batch, hidden_size)
110
  hidden = hidden[:self.decoder_num_layers]
111
  loss = 0
112
+ decoder_input = torch.LongTensor([[BOS_TOKEN] for _ in range(y_train.shape[0])])
113
  decoder_input = decoder_input.to(device)
114
  use_teacher_forcing = random.random() < teacher_forcing_ratio
115
  if use_teacher_forcing:
models/seq2seq/requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ colorama==0.4.6
2
+ filelock==3.18.0
3
+ fsspec==2025.7.0
4
+ Jinja2==3.1.6
5
+ MarkupSafe==3.0.2
6
+ mpmath==1.3.0
7
+ networkx==3.5
8
+ numpy==2.3.2
9
+ setuptools==80.9.0
10
+ sympy==1.14.0
11
+ torch==2.7.1
12
+ tqdm==4.67.1
13
+ typing_extensions==4.14.1
models/seq2seq/searchers.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
- from .custom_types import Token
 
4
 
5
  class GreedySearch(nn.Module):
6
  def __init__(self, encoder, decoder, embedding, device):
@@ -13,7 +14,7 @@ class GreedySearch(nn.Module):
13
  def forward(self, x, input_length, max_length):
14
  encoder_outputs, hidden = self.encoder(x, input_length)
15
  decoder_hidden = hidden[:self.decoder.num_layers]
16
- decoder_input = torch.ones(1, 1, device=self.device, dtype=torch.long) * Token.BOS_TOKEN.value
17
  all_tokens = torch.zeros([0], device=self.device, dtype=torch.long)
18
  all_scores = torch.zeros([0], device=self.device)
19
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from .constants import BOS_TOKEN
4
+
5
 
6
  class GreedySearch(nn.Module):
7
  def __init__(self, encoder, decoder, embedding, device):
 
14
  def forward(self, x, input_length, max_length):
15
  encoder_outputs, hidden = self.encoder(x, input_length)
16
  decoder_hidden = hidden[:self.decoder.num_layers]
17
+ decoder_input = torch.ones(1, 1, device=self.device, dtype=torch.long) * BOS_TOKEN
18
  all_tokens = torch.zeros([0], device=self.device, dtype=torch.long)
19
  all_scores = torch.zeros([0], device=self.device)
20
 
models/seq2seq/vocab.py CHANGED
@@ -1,14 +1,14 @@
1
  import torch
2
  import torch.nn as nn
3
  from typing import List, Dict
4
- from .custom_types import Token
5
 
6
 
7
  class Vocab(nn.Module):
8
  def __init__(self, messages: List[Dict]):
9
  super().__init__()
10
- self.word2index: Dict[str, int] = {"<pad>": Token.PAD_TOKEN.value, "<bos>": Token.BOS_TOKEN.value, "<eos>": Token.EOS_TOKEN.value, "<unk>": Token.UNK_TOKEN.value}
11
- self.index2word: Dict[int, str] = {Token.PAD_TOKEN.value: "<pad>", Token.BOS_TOKEN.value: "<bos>", Token.EOS_TOKEN.value: "<eos>", Token.UNK_TOKEN.value: "<unk>"}
12
  self.word_count: Dict[str, int] = dict()
13
  self.size = 4
14
 
@@ -33,7 +33,7 @@ class Vocab(nn.Module):
33
  def sentence_indices(self, sentence: List[str]) -> torch.LongTensor:
34
  indices = torch.LongTensor(len(sentence))
35
  for i, word in enumerate(sentence):
36
- indices[i] = self.word2index[word] if word in self.word2index else Token.UNK_TOKEN.value
37
  return indices
38
 
39
  def forward(self, indices: torch.LongTensor):
 
1
  import torch
2
  import torch.nn as nn
3
  from typing import List, Dict
4
+ from .constants import PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN
5
 
6
 
7
  class Vocab(nn.Module):
8
  def __init__(self, messages: List[Dict]):
9
  super().__init__()
10
+ self.word2index: Dict[str, int] = {"<pad>": PAD_TOKEN, "<bos>": BOS_TOKEN, "<eos>": EOS_TOKEN, "<unk>": UNK_TOKEN}
11
+ self.index2word: Dict[int, str] = {PAD_TOKEN: "<pad>", BOS_TOKEN: "<bos>", EOS_TOKEN: "<eos>", UNK_TOKEN: "<unk>"}
12
  self.word_count: Dict[str, int] = dict()
13
  self.size = 4
14
 
 
33
  def sentence_indices(self, sentence: List[str]) -> torch.LongTensor:
34
  indices = torch.LongTensor(len(sentence))
35
  for i, word in enumerate(sentence):
36
+ indices[i] = self.word2index[word] if word in self.word2index else UNK_TOKEN
37
  return indices
38
 
39
  def forward(self, indices: torch.LongTensor):
models/transformer/__init__.py ADDED
File without changes
models/transformer/constants.py ADDED
@@ -0,0 +1 @@
 
 
1
# Default base directory for the transformer model: used as the default
# ``data_path`` by FineTuner, TextGenerator and modified_tokenizer.
CHECKPOINT_PATH = "models/transformer/checkpoint/"
models/transformer/custom_types.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, NotRequired, Tuple, TypedDict, Union
2
+
3
+
4
+ MessageId = int
5
+ MessageText = Union[List[str], Tuple[str]]
6
+ Conversation = Tuple[MessageText]
7
+
8
+
9
+ class Message(TypedDict):
10
+ id: MessageId
11
+ text: MessageText
12
+ reply_to_id: NotRequired[int]
models/transformer/fine_tuner.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from .utils import modified_tokenizer
3
+ from .telegram_data_extractor import TelegramDataExtractor
4
+ from transformers import GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
5
+ from datasets import load_dataset
6
+ from .constants import CHECKPOINT_PATH
7
+
8
+
9
class FineTuner:
    """Fine-tune a GPT-2 causal language model on conversations from a Telegram export."""

    def __init__(self,
                 model_name="ai-forever/rugpt3small_based_on_gpt2",
                 cache_dir="model_cache",
                 data_path=CHECKPOINT_PATH):
        """Load the tokenizer (with the extra chat tokens) and the base model.

        Args:
            model_name: Hub id or local path of the base model.
            cache_dir: Download cache directory, relative to ``data_path``.
            data_path: Base directory for the cache, dataset, logs and outputs.
        """
        self.data_path = Path(data_path)

        # Initialize the tokenizer and the model.
        self.tokenizer = modified_tokenizer(model_name, cache_dir, self.data_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_name, cache_dir=str(self.data_path / cache_dir))
        # modified_tokenizer() adds tokens; every token id must have an
        # embedding row. Only grow the matrix, never shrink it.
        if len(self.tokenizer) > self.model.get_input_embeddings().num_embeddings:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def prepare_data(self, messages_path="/kaggle/input/chat-history/chat_history_small.json"):
        """Build the training text file from a Telegram JSON export.

        Args:
            messages_path: Path of the exported chat history. The default is
                the previously hard-coded Kaggle input location.

        Returns:
            Path of the dataset file written under ``self.data_path``.
        """
        messages = TelegramDataExtractor.load_messages_from_json(messages_path)
        dataset_path = TelegramDataExtractor.conversations_from_messages(self.data_path, self.tokenizer, messages)
        return dataset_path

    def fine_tune(self,
                  dataset_path,
                  output_name='fine_tuned_model',
                  num_train_epochs=10,
                  per_device_train_batch_size=8,
                  learning_rate=5e-5,
                  save_steps=10_000):
        """Fine-tune the model on a line-per-sample text dataset and save the result.

        Args:
            dataset_path: Text file with one training sample per line, e.g. the
                value returned by ``prepare_data``. A falsy value falls back to
                ``"train_dataset.txt"`` in the working directory, which is the
                file this method used to read unconditionally.
            output_name: Output sub-directory of ``self.data_path``.
            num_train_epochs: Number of training epochs.
            per_device_train_batch_size: Batch size per device.
            learning_rate: Optimizer learning rate.
            save_steps: Checkpoint interval in optimizer steps.
        """
        train_file = str(dataset_path) if dataset_path else "train_dataset.txt"
        dataset = load_dataset("text", data_files={"train": train_file})

        def preprocess(example):
            # Tokenize while preserving structure
            return self.tokenizer(example["text"], truncation=True, max_length=300)

        train_dataset = dataset.map(preprocess, batched=True)["train"]

        # mlm=False selects causal language modelling.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer, mlm=False
        )

        output_dir = str(self.data_path / output_name)
        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            # fp16=True,
            # gradient_accumulation_steps=2,
            save_steps=save_steps,
            learning_rate=learning_rate,
            torch_compile=True,
            save_total_limit=2,
            logging_dir=str(self.data_path / 'logs'),
            report_to="none"
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )

        trainer.train()
        # Save the trained model together with its tokenizer.
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
models/transformer/requirements.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.6.1
2
+ aiohttp==3.12.15
3
+ aiosignal==1.4.0
4
+ attrs==25.3.0
5
+ certifi==2025.8.3
6
+ charset-normalizer==3.4.2
7
+ colorama==0.4.6
8
+ datasets==4.0.0
9
+ dill==0.3.8
10
+ filelock==3.18.0
11
+ frozenlist==1.7.0
12
+ fsspec==2025.3.0
13
+ huggingface-hub==0.34.3
14
+ idna==3.10
15
+ multidict==6.6.3
16
+ multiprocess==0.70.16
17
+ numpy==2.3.2
18
+ packaging==25.0
19
+ pandas==2.3.1
20
+ propcache==0.3.2
21
+ pyarrow==21.0.0
22
+ python-dateutil==2.9.0.post0
23
+ pytz==2025.2
24
+ PyYAML==6.0.2
25
+ regex==2025.7.34
26
+ requests==2.32.4
27
+ safetensors==0.5.3
28
+ six==1.17.0
29
+ tokenizers==0.21.4
30
+ tqdm==4.67.1
31
+ transformers==4.54.1
32
+ typing_extensions==4.14.1
33
+ tzdata==2025.2
34
+ urllib3==2.5.0
35
+ xxhash==3.5.0
36
+ yarl==1.20.1
models/transformer/telegram_data_extractor.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict, OrderedDict, Tuple, Union
3
+ from .custom_types import MessageId, Message, Conversation, MessageText
4
+ from typing import List
5
+ import re
6
+ import json
7
+
8
+
9
class TelegramDataExtractor:
    """Convert a Telegram JSON chat export into a question/answer training file."""

    # Pairs whose question or answer has more tokens than this are dropped.
    _MAX_MESSAGE_LEN = 150
    # Largest allowed difference between question and answer token counts.
    _MAX_QA_LEN_DIFF = 20

    @classmethod
    def load_messages_from_json(cls, path: str, max_message_count: Union[int, None] = None) -> "OrderedDict[MessageId, Message]":
        """Load and normalize the messages of a Telegram export.

        Entries whose ``type`` is not ``"message"`` and messages whose
        normalized text is empty are skipped.

        Args:
            path: Path of the exported JSON file.
            max_message_count: Stop after this many export entries have been
                examined (``None`` or 0 means no limit).

        Returns:
            Messages keyed by id, in export order.
        """
        messages: OrderedDict[MessageId, Message] = OrderedDict()
        with open(path, "r", encoding="utf-8") as file:
            chat_json = json.load(file)
        for i, message in enumerate(chat_json["messages"]):
            if max_message_count and i == max_message_count:
                break
            if message["type"] != "message":
                continue
            text = cls.normalize(message["text"])
            if not text:  # Check for empty message
                continue
            new_message = {
                # "from" can be missing or null in the export; do not crash on it.
                "from": cls.normalize_username(message.get("from")),
                "id": message["id"],
                "text": text
            }
            if "reply_to_message_id" in message:
                new_message["reply_to_id"] = message["reply_to_message_id"]

            messages[new_message["id"]] = new_message
        return messages

    @staticmethod
    def conversations_from_messages(save_to: Path, tokenizer, messages: "OrderedDict[MessageId, Message]") -> Path:
        """Pair messages into question/answer samples and write them to disk.

        Each message is treated as the answer to the message it replies to, or
        to the message right before it when there is no (known) reply target.
        When one question gets several answers, a longer answer replaces the
        stored one if it is at most ``_MAX_QA_LEN_DIFF`` tokens longer.
        Duplicates, answers without Russian letters, over-long pairs and
        length-unbalanced pairs are removed before writing.

        Args:
            save_to: Directory for ``train_dataset.txt`` (created if missing).
            tokenizer: Object exposing ``eos_token``.
            messages: Normalized messages keyed by id, in chat order.

        Returns:
            Path of the written dataset file.
        """
        max_len = TelegramDataExtractor._MAX_MESSAGE_LEN
        max_diff = TelegramDataExtractor._MAX_QA_LEN_DIFF

        conversations: List[Conversation] = []
        questions: Dict[MessageText, int] = dict()  # question text -> index in conversations

        ordered = list(messages.values())
        # The last message has no answer, hence the pairwise walk.
        for previous, current in zip(ordered, ordered[1:]):
            try:  # Message is answer for message with `id` of `reply_to_id`
                question = messages[current["reply_to_id"]]
            except KeyError:  # no reply target, or the target was filtered out
                question = previous
            qa = (question["text"], current["text"], question["from"])
            slot = questions.get(qa[0])
            if slot is None:
                questions[qa[0]] = len(conversations)
                conversations.append(qa)
                continue
            known_answer = conversations[slot][1]
            if len(known_answer) < len(qa[1]) and abs(len(known_answer) - len(qa[1])) <= max_diff:
                conversations[slot] = qa

        # Remove duplicates and keep order (the triples are hashable tuples).
        conversations = list(dict.fromkeys(conversations))
        conversations = [
            qa for qa in conversations
            if re.search(r"[а-я]", " ".join(qa[1]))           # answer has real text
            and len(qa[0]) <= max_len and len(qa[1]) <= max_len
            and abs(len(qa[0]) - len(qa[1])) <= max_diff
        ]

        save_dir = Path(save_to)
        save_dir.mkdir(parents=True, exist_ok=True)
        output_path = save_dir / "train_dataset.txt"
        eos = tokenizer.eos_token
        with open(output_path, "w", encoding="utf-8") as file:
            file.writelines(
                f"<user> {author} <says> {' '.join(question)} {eos} <response> {' '.join(answer)} {eos}\n"
                for question, answer, author in conversations
            )
        return output_path

    @staticmethod
    def normalize(text: Union[str, List]) -> Tuple[str, ...]:
        """Lower-case, clean and tokenize a message text.

        Args:
            text: Message text; a list is joined from its ``str`` items only
                (non-string items are ignored).

        Returns:
            The whitespace-separated tokens of the cleaned text (possibly empty).
        """
        if isinstance(text, list):
            text = " ".join([word for word in text if isinstance(word, str)])
        text = text.lower().strip()
        text = re.sub(r"[^а-яё.!?:\d]+", r" ", text)  # Leave only russian and special characters
        text = re.sub(r'\.(\s*\.)+', '... ', text)  # Replace any sequence of 2+ dots with '...'
        text = re.sub(r'([?!])(\s*\1)+', r'\1 ', text)  # Collapse repeating ? or !
        text = re.sub(r"([!?]|\.+)", r"\1 ", text)  # Add a space after ! ? and dot runs
        text = re.sub(r"ё", r"е", text)
        # Laugh token for strings such as `ахах` etc.
        # NOTE(review): the first alternative is wrapped in `.*`, so a single
        # 6+ letter run replaces the whole text with <laugh> - confirm intended.
        text = re.sub(r"(.*[ауспэиычвекьхъз]{6,}.*|\b[апх][аеписх]{2,3}\b|\b[ах]{2,}\b)", r" <laugh> ", text)
        text = re.sub(r"(<laugh>)(\s*\1)+", r" <laugh> ", text)  # Collapse repeating <laugh> tokens
        text = re.sub(r"\s+", r" ", text).strip()  # Leave only one space between each word
        return tuple(text.split())

    @staticmethod
    def normalize_username(text: str) -> str:
        """Lower-case a sender name and keep only Cyrillic/Latin letters and whitespace.

        Args:
            text: Raw sender name; ``None`` or an empty value yields ``""``.

        Returns:
            The cleaned, stripped name.
        """
        if not text:
            return ""
        text = text.lower()
        return re.sub(r"[^а-яa-z\s]+", "", text).strip()
models/transformer/text_generator.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2LMHeadModel
2
+ from pathlib import Path
3
+ from .utils import modified_tokenizer
4
+ from .constants import CHECKPOINT_PATH
5
+
6
+
7
class TextGenerator:
    """Generate chat replies with a fine-tuned GPT-2 model."""

    def __init__(self, model_name='fine_tuned_model', data_path=CHECKPOINT_PATH):
        """Load the model and tokenizer from ``data_path / model_name``.

        Args:
            model_name: Directory name of the fine-tuned model.
            data_path: Base directory that contains the model directory.
        """
        model_path = Path(data_path) / model_name
        self.tokenizer = modified_tokenizer(model_path, None, data_path)
        self.model = GPT2LMHeadModel.from_pretrained(str(model_path), device_map="auto")
        self.model.eval()

    def generate_text(self,
                      author: str,
                      input_str: str,
                      max_length=100,
                      num_return_sequences=1,
                      temperature=1.0,
                      top_k=0,
                      top_p=1.0,
                      do_sample=False):
        """Generate continuations for a chat message.

        Args:
            author: Name of the message author, inserted into the prompt.
            input_str: Message text to respond to.
            max_length: Maximum number of generated tokens (the prompt length
                is added on top).
            num_return_sequences: Number of sequences to return.
            temperature: Controls the diversity of the output.
            top_k: If greater than 0, sample only from the k most likely tokens.
            top_p: If less than 1.0, nucleus sampling is applied.
            do_sample: If True, enables random sampling for more diversity.

        Returns:
            Dict with ``"full_texts"`` (prompt + generation) and
            ``"generated_texts"`` (the same texts with the prompt cut off).
        """
        # Build the prompt.
        prompt_text = f"<user> {author} <says> {input_str} {self.tokenizer.eos_token} <response>"
        print(prompt_text)

        # Encode the prompt and move it to the model's device: with
        # device_map="auto" the model may not be on the CPU.
        encoded = self.tokenizer(prompt_text, return_tensors='pt').to(self.model.device)
        input_ids = encoded["input_ids"]

        # Generate the texts.
        outputs = self.model.generate(
            input_ids,
            attention_mask=encoded["attention_mask"],
            max_length=max_length + input_ids.shape[1],
            num_return_sequences=num_return_sequences,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            no_repeat_ngram_size=2
        )

        # Decode the results.
        all_texts = [self.tokenizer.decode(output, skip_special_tokens=False) for output in outputs]

        # Cut the prompt off by the length of its decoded text.
        prompt_length = len(self.tokenizer.decode(input_ids[0], skip_special_tokens=False))
        trimmed_texts = [text[prompt_length:] for text in all_texts]

        # Return the results as a dictionary.
        return {
            "full_texts": all_texts,
            "generated_texts": trimmed_texts
        }
70
+
71
+ if __name__ == "__main__":
72
+ print("OK")
models/transformer/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Tokenizer
2
+ from pathlib import Path
3
+ from .constants import CHECKPOINT_PATH
4
+
5
+
6
def modified_tokenizer(model_name="ai-forever/rugpt3small_based_on_gpt2", cache_dir="model_cache", data_path=Path(CHECKPOINT_PATH)):
    """Load a GPT-2 tokenizer and register the chat-specific tokens.

    Args:
        model_name: Hub id or local path of the tokenizer.
        cache_dir: Download cache directory relative to ``data_path``; a falsy
            value loads without an explicit cache directory.
        data_path: Base directory, as ``str`` or ``Path`` (a plain string used
            to fail on the ``/`` operator when ``cache_dir`` was set).

    Returns:
        The tokenizer with ``<user>``, ``<says>`` and ``<response>`` added as
        special tokens and ``<laugh>`` added as a regular token.
    """
    if cache_dir:
        tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=str(Path(data_path) / cache_dir))
    else:
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    special_tokens_dict = {
        "additional_special_tokens": [
            "<user>",
            "<says>",
            "<response>"
        ]
    }
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.add_tokens(["<laugh>"])
    return tokenizer
requirements.txt CHANGED
@@ -1,22 +1,22 @@
1
- anyio==4.9.0
2
- certifi==2025.7.14
3
- colorama==0.4.6
4
- filelock==3.18.0
5
- fsspec==2025.7.0
6
  h11==0.16.0
7
  httpcore==1.0.9
8
  httpx==0.28.1
9
  idna==3.10
10
- Jinja2==3.1.6
11
- MarkupSafe==3.0.2
12
- mpmath==1.3.0
13
- networkx==3.5
14
- numpy==2.3.1
15
  python-dotenv==1.1.1
16
  python-telegram-bot==22.3
17
- setuptools==80.9.0
18
  sniffio==1.3.1
19
- sympy==1.14.0
20
- torch==2.7.1
21
- tqdm==4.67.1
22
  typing_extensions==4.14.1
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ anyio==4.10.0
2
+ certifi==2025.8.3
 
 
 
3
  h11==0.16.0
4
  httpcore==1.0.9
5
  httpx==0.28.1
6
  idna==3.10
 
 
 
 
 
7
  python-dotenv==1.1.1
8
  python-telegram-bot==22.3
 
9
  sniffio==1.3.1
 
 
 
10
  typing_extensions==4.14.1
11
+ typing-inspection==0.4.1
12
+ annotated-types==0.7.0
13
+ fastapi==0.116.1
14
+ pydantic==2.11.7
15
+ pydantic_core==2.33.2
16
+ starlette==0.47.2
17
+ charset-normalizer==3.4.2
18
+ requests==2.32.4
19
+ urllib3==2.5.0
20
+ click==8.2.1
21
+ colorama==0.4.6
22
+ uvicorn==0.35.0
server.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from models.transformer.text_generator import TextGenerator
4
+
5
+
6
+ app = FastAPI()
7
+
8
+ generator = TextGenerator(
9
+ model_name='fine_tuned_model_gpt_2',
10
+ )
11
+
12
+
13
class Message(BaseModel):
    """Request body of ``POST /generate``."""
    author: str   # passed to the generator as the message author
    content: str  # message text the generator should respond to
16
+
17
+
18
def _trim_at_eos(text: str, eos_token: str = "</s>") -> str:
    """Return ``text`` up to, but excluding, the first ``eos_token`` (all of it if absent)."""
    return text.split(eos_token, 1)[0]


@app.post("/generate")
def generate_response(message: Message):
    """Generate a reply to a chat message.

    Args:
        message: Author and content of the message to respond to.

    Returns:
        ``{"response": <text>}`` with the generated text cut at the first
        end-of-sequence token.
    """
    response = generator.generate_text(
        author=message.author,
        input_str=message.content,
        max_length=100,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.8,  # slightly sharpen the sampling distribution
        top_k=100,        # sample only from the 100 most likely tokens
        top_p=0.95        # nucleus sampling threshold
    )["generated_texts"][0]
    # str.find() returns -1 when no EOS was generated, which used to slice off
    # the last character of the reply; splitting leaves the text intact.
    response = _trim_at_eos(response, generator.tokenizer.eos_token or "</s>")
    return { "response": response }