Spaces:
Running
Running
WonderWaffle4
commited on
Commit
·
59c6d5c
1
Parent(s):
3328eb8
GPT-2 bot with transformers and bot server
Browse files- .gitignore +2 -2
- main.py → bot_local.py +85 -90
- bot_server.py +88 -0
- models/seq2seq/attention.py +1 -0
- models/seq2seq/chat_dataset.py +3 -2
- models/seq2seq/constants.py +4 -0
- models/seq2seq/custom_types.py +2 -7
- models/seq2seq/model.py +3 -2
- models/seq2seq/requirements.txt +13 -0
- models/seq2seq/searchers.py +3 -2
- models/seq2seq/vocab.py +4 -4
- models/transformer/__init__.py +0 -0
- models/transformer/constants.py +1 -0
- models/transformer/custom_types.py +12 -0
- models/transformer/fine_tuner.py +75 -0
- models/transformer/requirements.txt +36 -0
- models/transformer/telegram_data_extractor.py +98 -0
- models/transformer/text_generator.py +72 -0
- models/transformer/utils.py +20 -0
- requirements.txt +14 -14
- server.py +31 -0
.gitignore
CHANGED
@@ -206,5 +206,5 @@ marimo/_static/
|
|
206 |
marimo/_lsp/
|
207 |
__marimo__/
|
208 |
|
209 |
-
models
|
210 |
-
models
|
|
|
206 |
marimo/_lsp/
|
207 |
__marimo__/
|
208 |
|
209 |
+
models/*/data/*
|
210 |
+
models/*/checkpoint/*
|
main.py → bot_local.py
RENAMED
@@ -1,90 +1,85 @@
|
|
1 |
-
import time
|
2 |
-
from typing import Final
|
3 |
-
import
|
4 |
-
import
|
5 |
-
from telegram import
|
6 |
-
from
|
7 |
-
|
8 |
-
import
|
9 |
-
import
|
10 |
-
from
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
torch.manual_seed(0)
|
29 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
-
|
31 |
-
chatbot = Seq2SeqChatbot(500, 8856, 2, 2, 0.1, device)
|
32 |
-
chatbot.load_checkpoint(CHECKPOINT_PATH)
|
33 |
-
chatbot.eval_mode()
|
34 |
-
|
35 |
-
|
36 |
-
def handle_response(text: str) -> Optional[str]:
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
if __name__ == '__main__':
|
88 |
-
print("Running main...")
|
89 |
-
# print(chatbot("test"))
|
90 |
-
main()
|
|
|
1 |
+
import time
|
2 |
+
from typing import Final
|
3 |
+
import re
|
4 |
+
from telegram import Update
|
5 |
+
from telegram.ext import Application, MessageHandler, filters, ContextTypes
|
6 |
+
from typing import Optional
|
7 |
+
import random
|
8 |
+
import os
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
from models.seq2seq.model import Seq2SeqChatbot
|
11 |
+
import torch
|
12 |
+
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
TOKEN: Final = os.environ.get("TOKEN")
|
16 |
+
BOT_USERNAME: Final = os.environ.get("BOT_USERNAME")
|
17 |
+
CHAT_ID: Final = int(os.environ.get("CHAT_ID"))
|
18 |
+
|
19 |
+
CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
|
20 |
+
|
21 |
+
ROMANTIKI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
|
22 |
+
BEZUMTSI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
|
23 |
+
|
24 |
+
last_gif_sent = 1.0
|
25 |
+
gif_sent_cooldown = 180.0
|
26 |
+
response_chance = 1.0
|
27 |
+
|
28 |
+
torch.manual_seed(0)
|
29 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
+
|
31 |
+
chatbot = Seq2SeqChatbot(500, 8856, 2, 2, 0.1, device)
|
32 |
+
chatbot.load_checkpoint(CHECKPOINT_PATH)
|
33 |
+
chatbot.eval_mode()
|
34 |
+
|
35 |
+
|
36 |
+
def handle_response(text: str) -> Optional[str]:
|
37 |
+
if random.random() < response_chance:
|
38 |
+
return chatbot(text)
|
39 |
+
return None
|
40 |
+
|
41 |
+
|
42 |
+
def edit_response(text: Optional[str]) -> Optional[str]:
|
43 |
+
if text is None:
|
44 |
+
return None
|
45 |
+
text = re.sub(r'\s+([,.!?;])\s+', r'\1 ', text)
|
46 |
+
return text
|
47 |
+
|
48 |
+
|
49 |
+
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
50 |
+
if update.message.chat_id == CHAT_ID:
|
51 |
+
global last_gif_sent
|
52 |
+
if "роман" in update.message.text.lower() and \
|
53 |
+
time.time() - last_gif_sent >= gif_sent_cooldown:
|
54 |
+
await context.bot.send_animation( chat_id=update.message.chat_id, animation=ROMANTIKI_GIF_ID)
|
55 |
+
last_gif_sent = time.time()
|
56 |
+
elif "безу" in update.message.text.lower() and \
|
57 |
+
time.time() - last_gif_sent >= gif_sent_cooldown:
|
58 |
+
await context.bot.send_animation(chat_id=update.message.chat_id, animation=BEZUMTSI_GIF_ID)
|
59 |
+
last_gif_sent = time.time()
|
60 |
+
else:
|
61 |
+
text = update.message.text.replace(BOT_USERNAME, '').strip().lower()
|
62 |
+
response = edit_response(handle_response(text))
|
63 |
+
if response:
|
64 |
+
await context.bot.sendMessage(update.message.chat_id, response, reply_to_message_id=update.message.id)
|
65 |
+
|
66 |
+
|
67 |
+
async def error(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
68 |
+
print(f"{update.message.from_user.username} in {update.message.chat.type} "
|
69 |
+
f"chat caused error \"{context.error}\"\n"
|
70 |
+
f"{update}\"")
|
71 |
+
|
72 |
+
|
73 |
+
def main() -> None:
|
74 |
+
"""Run the bot."""
|
75 |
+
application = Application.builder().token(TOKEN).build()
|
76 |
+
|
77 |
+
application.add_handler(MessageHandler(filters.TEXT, handle_message))
|
78 |
+
application.add_error_handler(error)
|
79 |
+
|
80 |
+
application.run_polling(allowed_updates=Update.ALL_TYPES, drop_pending_updates=True)
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == '__main__':
|
84 |
+
print("Running main...")
|
85 |
+
main()
|
|
|
|
|
|
|
|
|
|
bot_server.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import Final
|
3 |
+
import re
|
4 |
+
from telegram import Update
|
5 |
+
from telegram.ext import Application, MessageHandler, filters, ContextTypes
|
6 |
+
from typing import Optional
|
7 |
+
import random
|
8 |
+
import os
|
9 |
+
import requests
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
import requests
|
12 |
+
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
TOKEN: Final = os.environ.get("TOKEN")
|
16 |
+
BOT_USERNAME: Final = os.environ.get("BOT_USERNAME")
|
17 |
+
CHAT_ID: Final = int(os.environ.get("CHAT_ID"))
|
18 |
+
|
19 |
+
CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
|
20 |
+
|
21 |
+
ROMANTIKI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
|
22 |
+
BEZUMTSI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
|
23 |
+
|
24 |
+
last_gif_sent = 1.0
|
25 |
+
gif_sent_cooldown = 180.0
|
26 |
+
response_chance = 1.0
|
27 |
+
|
28 |
+
def handle_response(author: str, content: str) -> Optional[str]:
|
29 |
+
if random.random() < response_chance:
|
30 |
+
return requests.post("http://localhost:8000/generate", json={"author": author, "content": content + " "}).json()["response"]
|
31 |
+
return None
|
32 |
+
|
33 |
+
|
34 |
+
def edit_response(text: Optional[str]) -> Optional[str]:
|
35 |
+
if text is None:
|
36 |
+
return None
|
37 |
+
# text = re.sub(r'\s+([,.!?;])\s+', r'\1 ', text)
|
38 |
+
return text
|
39 |
+
|
40 |
+
|
41 |
+
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
42 |
+
if update.message.chat_id == CHAT_ID:
|
43 |
+
global last_gif_sent
|
44 |
+
if "роман" in update.message.text.lower() and \
|
45 |
+
time.time() - last_gif_sent >= gif_sent_cooldown:
|
46 |
+
await context.bot.send_animation( chat_id=update.message.chat_id, animation=ROMANTIKI_GIF_ID)
|
47 |
+
last_gif_sent = time.time()
|
48 |
+
elif "безу" in update.message.text.lower() and \
|
49 |
+
time.time() - last_gif_sent >= gif_sent_cooldown:
|
50 |
+
await context.bot.send_animation(chat_id=update.message.chat_id, animation=BEZUMTSI_GIF_ID)
|
51 |
+
last_gif_sent = time.time()
|
52 |
+
else:
|
53 |
+
author = ""
|
54 |
+
first_name = update.message.from_user.first_name
|
55 |
+
last_name = update.message.from_user.last_name
|
56 |
+
if first_name:
|
57 |
+
author += first_name
|
58 |
+
if last_name:
|
59 |
+
author += f" {last_name}"
|
60 |
+
content = update.message.text.replace(BOT_USERNAME, '').strip().lower()
|
61 |
+
|
62 |
+
# response = edit_response(handle_response(author, content))
|
63 |
+
response = handle_response(author, content)
|
64 |
+
print(response)
|
65 |
+
if response:
|
66 |
+
await context.bot.sendMessage(update.message.chat_id, response, reply_to_message_id=update.message.id)
|
67 |
+
|
68 |
+
|
69 |
+
async def error(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
70 |
+
print(f"{update.message.from_user.username} in {update.message.chat.type} "
|
71 |
+
f"chat caused error \"{context.error}\"\n"
|
72 |
+
f"{update}\"")
|
73 |
+
|
74 |
+
|
75 |
+
def main() -> None:
|
76 |
+
"""Run the bot."""
|
77 |
+
application = Application.builder().token(TOKEN).build()
|
78 |
+
|
79 |
+
application.add_handler(MessageHandler(filters.TEXT, handle_message))
|
80 |
+
application.add_error_handler(error)
|
81 |
+
|
82 |
+
application.run_polling(allowed_updates=Update.ALL_TYPES, drop_pending_updates=True)
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == '__main__':
|
86 |
+
print("Running main...")
|
87 |
+
# print(chatbot("test"))
|
88 |
+
main()
|
models/seq2seq/attention.py
CHANGED
@@ -3,6 +3,7 @@ import torch.nn.functional as F
|
|
3 |
import torch.nn as nn
|
4 |
from .custom_types import Method
|
5 |
|
|
|
6 |
class LuongAttention(nn.Module):
|
7 |
def __init__(self, method: Method, hidden_size: int):
|
8 |
super().__init__()
|
|
|
3 |
import torch.nn as nn
|
4 |
from .custom_types import Method
|
5 |
|
6 |
+
|
7 |
class LuongAttention(nn.Module):
|
8 |
def __init__(self, method: Method, hidden_size: int):
|
9 |
super().__init__()
|
models/seq2seq/chat_dataset.py
CHANGED
@@ -5,10 +5,11 @@ from collections import OrderedDict
|
|
5 |
from .vocab import Vocab
|
6 |
from .custom_types import Message, MessageId, Conversation
|
7 |
from torch.nn.utils.rnn import pad_sequence
|
8 |
-
from .
|
9 |
import re
|
10 |
import json
|
11 |
|
|
|
12 |
class ChatDataset(data.Dataset):
|
13 |
def __init__(self, path: str, max_message_count: int = None, batch_size=5):
|
14 |
super().__init__()
|
@@ -32,7 +33,7 @@ class ChatDataset(data.Dataset):
|
|
32 |
batches_X.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][0] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
|
33 |
batches_y.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][1] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
|
34 |
lengths.append(torch.tensor([len(conversations[i+j][0]) for j in range(self.batch_size) if i+j < len(conversations)]))
|
35 |
-
mask.append(batches_y[-1] !=
|
36 |
return batches_X, batches_y, lengths, mask
|
37 |
|
38 |
@classmethod
|
|
|
5 |
from .vocab import Vocab
|
6 |
from .custom_types import Message, MessageId, Conversation
|
7 |
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
from .constants import PAD_TOKEN
|
9 |
import re
|
10 |
import json
|
11 |
|
12 |
+
|
13 |
class ChatDataset(data.Dataset):
|
14 |
def __init__(self, path: str, max_message_count: int = None, batch_size=5):
|
15 |
super().__init__()
|
|
|
33 |
batches_X.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][0] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
|
34 |
batches_y.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][1] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
|
35 |
lengths.append(torch.tensor([len(conversations[i+j][0]) for j in range(self.batch_size) if i+j < len(conversations)]))
|
36 |
+
mask.append(batches_y[-1] != PAD_TOKEN)
|
37 |
return batches_X, batches_y, lengths, mask
|
38 |
|
39 |
@classmethod
|
models/seq2seq/constants.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PAD_TOKEN = 0
|
2 |
+
BOS_TOKEN = 1
|
3 |
+
EOS_TOKEN = 2
|
4 |
+
UNK_TOKEN = 3
|
models/seq2seq/custom_types.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from typing import List, TypedDict, NotRequired, Tuple
|
2 |
from enum import Enum, auto
|
3 |
|
|
|
4 |
MessageId = int
|
5 |
MessageText = List[str]
|
6 |
Conversation = Tuple[MessageText]
|
@@ -14,10 +15,4 @@ class Message(TypedDict):
|
|
14 |
class Method(Enum):
|
15 |
DOT = auto()
|
16 |
GENERAL = auto()
|
17 |
-
CONCAT = auto()
|
18 |
-
|
19 |
-
class Token(Enum):
|
20 |
-
PAD_TOKEN = 0
|
21 |
-
BOS_TOKEN = 1
|
22 |
-
EOS_TOKEN = 2
|
23 |
-
UNK_TOKEN = 3
|
|
|
1 |
from typing import List, TypedDict, NotRequired, Tuple
|
2 |
from enum import Enum, auto
|
3 |
|
4 |
+
|
5 |
MessageId = int
|
6 |
MessageText = List[str]
|
7 |
Conversation = Tuple[MessageText]
|
|
|
15 |
class Method(Enum):
|
16 |
DOT = auto()
|
17 |
GENERAL = auto()
|
18 |
+
CONCAT = auto()
|
|
|
|
|
|
|
|
|
|
|
|
models/seq2seq/model.py
CHANGED
@@ -5,7 +5,8 @@ import torch.nn as nn
|
|
5 |
import torch.optim as optim
|
6 |
from .chat_dataset import ChatDataset
|
7 |
from .attention import LuongAttention
|
8 |
-
from .custom_types import Method
|
|
|
9 |
from .vocab import Vocab
|
10 |
from .searchers import GreedySearch
|
11 |
import os
|
@@ -108,7 +109,7 @@ class Seq2SeqChatbot(nn.Module):
|
|
108 |
encoder_outputs, hidden = self.encoder(x_train, x_lengths) # Output shape: (batch_size, max_len_in_batch, hidden_size)
|
109 |
hidden = hidden[:self.decoder_num_layers]
|
110 |
loss = 0
|
111 |
-
decoder_input = torch.LongTensor([[
|
112 |
decoder_input = decoder_input.to(device)
|
113 |
use_teacher_forcing = random.random() < teacher_forcing_ratio
|
114 |
if use_teacher_forcing:
|
|
|
5 |
import torch.optim as optim
|
6 |
from .chat_dataset import ChatDataset
|
7 |
from .attention import LuongAttention
|
8 |
+
from .custom_types import Method
|
9 |
+
from .constants import BOS_TOKEN
|
10 |
from .vocab import Vocab
|
11 |
from .searchers import GreedySearch
|
12 |
import os
|
|
|
109 |
encoder_outputs, hidden = self.encoder(x_train, x_lengths) # Output shape: (batch_size, max_len_in_batch, hidden_size)
|
110 |
hidden = hidden[:self.decoder_num_layers]
|
111 |
loss = 0
|
112 |
+
decoder_input = torch.LongTensor([[BOS_TOKEN] for _ in range(y_train.shape[0])])
|
113 |
decoder_input = decoder_input.to(device)
|
114 |
use_teacher_forcing = random.random() < teacher_forcing_ratio
|
115 |
if use_teacher_forcing:
|
models/seq2seq/requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
colorama==0.4.6
|
2 |
+
filelock==3.18.0
|
3 |
+
fsspec==2025.7.0
|
4 |
+
Jinja2==3.1.6
|
5 |
+
MarkupSafe==3.0.2
|
6 |
+
mpmath==1.3.0
|
7 |
+
networkx==3.5
|
8 |
+
numpy==2.3.2
|
9 |
+
setuptools==80.9.0
|
10 |
+
sympy==1.14.0
|
11 |
+
torch==2.7.1
|
12 |
+
tqdm==4.67.1
|
13 |
+
typing_extensions==4.14.1
|
models/seq2seq/searchers.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
-
from .
|
|
|
4 |
|
5 |
class GreedySearch(nn.Module):
|
6 |
def __init__(self, encoder, decoder, embedding, device):
|
@@ -13,7 +14,7 @@ class GreedySearch(nn.Module):
|
|
13 |
def forward(self, x, input_length, max_length):
|
14 |
encoder_outputs, hidden = self.encoder(x, input_length)
|
15 |
decoder_hidden = hidden[:self.decoder.num_layers]
|
16 |
-
decoder_input = torch.ones(1, 1, device=self.device, dtype=torch.long) *
|
17 |
all_tokens = torch.zeros([0], device=self.device, dtype=torch.long)
|
18 |
all_scores = torch.zeros([0], device=self.device)
|
19 |
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
+
from .constants import BOS_TOKEN
|
4 |
+
|
5 |
|
6 |
class GreedySearch(nn.Module):
|
7 |
def __init__(self, encoder, decoder, embedding, device):
|
|
|
14 |
def forward(self, x, input_length, max_length):
|
15 |
encoder_outputs, hidden = self.encoder(x, input_length)
|
16 |
decoder_hidden = hidden[:self.decoder.num_layers]
|
17 |
+
decoder_input = torch.ones(1, 1, device=self.device, dtype=torch.long) * BOS_TOKEN
|
18 |
all_tokens = torch.zeros([0], device=self.device, dtype=torch.long)
|
19 |
all_scores = torch.zeros([0], device=self.device)
|
20 |
|
models/seq2seq/vocab.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
from typing import List, Dict
|
4 |
-
from .
|
5 |
|
6 |
|
7 |
class Vocab(nn.Module):
|
8 |
def __init__(self, messages: List[Dict]):
|
9 |
super().__init__()
|
10 |
-
self.word2index: Dict[str, int] = {"<pad>":
|
11 |
-
self.index2word: Dict[int, str] = {
|
12 |
self.word_count: Dict[str, int] = dict()
|
13 |
self.size = 4
|
14 |
|
@@ -33,7 +33,7 @@ class Vocab(nn.Module):
|
|
33 |
def sentence_indices(self, sentence: List[str]) -> torch.LongTensor:
|
34 |
indices = torch.LongTensor(len(sentence))
|
35 |
for i, word in enumerate(sentence):
|
36 |
-
indices[i] = self.word2index[word] if word in self.word2index else
|
37 |
return indices
|
38 |
|
39 |
def forward(self, indices: torch.LongTensor):
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
from typing import List, Dict
|
4 |
+
from .constants import PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN
|
5 |
|
6 |
|
7 |
class Vocab(nn.Module):
|
8 |
def __init__(self, messages: List[Dict]):
|
9 |
super().__init__()
|
10 |
+
self.word2index: Dict[str, int] = {"<pad>": PAD_TOKEN, "<bos>": BOS_TOKEN, "<eos>": EOS_TOKEN, "<unk>": UNK_TOKEN}
|
11 |
+
self.index2word: Dict[int, str] = {PAD_TOKEN: "<pad>", BOS_TOKEN: "<bos>", EOS_TOKEN: "<eos>", UNK_TOKEN: "<unk>"}
|
12 |
self.word_count: Dict[str, int] = dict()
|
13 |
self.size = 4
|
14 |
|
|
|
33 |
def sentence_indices(self, sentence: List[str]) -> torch.LongTensor:
|
34 |
indices = torch.LongTensor(len(sentence))
|
35 |
for i, word in enumerate(sentence):
|
36 |
+
indices[i] = self.word2index[word] if word in self.word2index else UNK_TOKEN
|
37 |
return indices
|
38 |
|
39 |
def forward(self, indices: torch.LongTensor):
|
models/transformer/__init__.py
ADDED
File without changes
|
models/transformer/constants.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
CHECKPOINT_PATH = "models/transformer/checkpoint/"
|
models/transformer/custom_types.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, NotRequired, Tuple, TypedDict, Union
|
2 |
+
|
3 |
+
|
4 |
+
MessageId = int
|
5 |
+
MessageText = Union[List[str], Tuple[str]]
|
6 |
+
Conversation = Tuple[MessageText]
|
7 |
+
|
8 |
+
|
9 |
+
class Message(TypedDict):
|
10 |
+
id: MessageId
|
11 |
+
text: MessageText
|
12 |
+
reply_to_id: NotRequired[int]
|
models/transformer/fine_tuner.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from .utils import modified_tokenizer
|
3 |
+
from .telegram_data_extractor import TelegramDataExtractor
|
4 |
+
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
|
5 |
+
from datasets import load_dataset
|
6 |
+
from .constants import CHECKPOINT_PATH
|
7 |
+
|
8 |
+
|
9 |
+
class FineTuner:
|
10 |
+
def __init__(self,
|
11 |
+
model_name="ai-forever/rugpt3small_based_on_gpt2",
|
12 |
+
cache_dir="model_cache",
|
13 |
+
data_path=CHECKPOINT_PATH):
|
14 |
+
self.data_path = Path(data_path)
|
15 |
+
|
16 |
+
# Инициализация токенизатора и модели
|
17 |
+
self.tokenizer = modified_tokenizer(model_name, cache_dir, self.data_path)
|
18 |
+
self.model = GPT2LMHeadModel.from_pretrained(model_name, cache_dir=str(self.data_path / cache_dir))
|
19 |
+
|
20 |
+
def prepare_data(self):
|
21 |
+
"""
|
22 |
+
Подготовка данных для обучения
|
23 |
+
"""
|
24 |
+
messages = TelegramDataExtractor.load_messages_from_json("/kaggle/input/chat-history/chat_history_small.json")
|
25 |
+
dataset_path = TelegramDataExtractor.conversations_from_messages(self.data_path, self.tokenizer, messages)
|
26 |
+
return dataset_path
|
27 |
+
|
28 |
+
def fine_tune(self,
|
29 |
+
dataset_path,
|
30 |
+
output_name='fine_tuned_model',
|
31 |
+
num_train_epochs=10,
|
32 |
+
per_device_train_batch_size=8,
|
33 |
+
learning_rate=5e-5,
|
34 |
+
save_steps=10_000):
|
35 |
+
"""
|
36 |
+
Дообучение модели на заданном датасете.
|
37 |
+
"""
|
38 |
+
dataset = load_dataset("text", data_files={"train": "train_dataset.txt"})
|
39 |
+
|
40 |
+
def preprocess(example):
|
41 |
+
# Tokenize while preserving structure
|
42 |
+
return self.tokenizer(example["text"], truncation=True, max_length=300)
|
43 |
+
|
44 |
+
train_dataset = dataset.map(preprocess, batched=True)["train"]
|
45 |
+
|
46 |
+
data_collator = DataCollatorForLanguageModeling(
|
47 |
+
tokenizer=self.tokenizer, mlm=False
|
48 |
+
)
|
49 |
+
|
50 |
+
training_args = TrainingArguments(
|
51 |
+
output_dir=str(self.data_path / output_name),
|
52 |
+
overwrite_output_dir=True,
|
53 |
+
num_train_epochs=num_train_epochs,
|
54 |
+
per_device_train_batch_size=per_device_train_batch_size,
|
55 |
+
# fp16=True,
|
56 |
+
# gradient_accumulation_steps=2,
|
57 |
+
save_steps=save_steps,
|
58 |
+
learning_rate=learning_rate,
|
59 |
+
torch_compile=True,
|
60 |
+
save_total_limit=2,
|
61 |
+
logging_dir=str(self.data_path / 'logs'),
|
62 |
+
report_to="none"
|
63 |
+
)
|
64 |
+
|
65 |
+
trainer = Trainer(
|
66 |
+
model=self.model,
|
67 |
+
args=training_args,
|
68 |
+
data_collator=data_collator,
|
69 |
+
train_dataset=train_dataset,
|
70 |
+
)
|
71 |
+
|
72 |
+
trainer.train()
|
73 |
+
# Сохранение обученной модели и токенизатора
|
74 |
+
self.model.save_pretrained(str(self.data_path / output_name))
|
75 |
+
self.tokenizer.save_pretrained(str(self.data_path / output_name))
|
models/transformer/requirements.txt
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohappyeyeballs==2.6.1
|
2 |
+
aiohttp==3.12.15
|
3 |
+
aiosignal==1.4.0
|
4 |
+
attrs==25.3.0
|
5 |
+
certifi==2025.8.3
|
6 |
+
charset-normalizer==3.4.2
|
7 |
+
colorama==0.4.6
|
8 |
+
datasets==4.0.0
|
9 |
+
dill==0.3.8
|
10 |
+
filelock==3.18.0
|
11 |
+
frozenlist==1.7.0
|
12 |
+
fsspec==2025.3.0
|
13 |
+
huggingface-hub==0.34.3
|
14 |
+
idna==3.10
|
15 |
+
multidict==6.6.3
|
16 |
+
multiprocess==0.70.16
|
17 |
+
numpy==2.3.2
|
18 |
+
packaging==25.0
|
19 |
+
pandas==2.3.1
|
20 |
+
propcache==0.3.2
|
21 |
+
pyarrow==21.0.0
|
22 |
+
python-dateutil==2.9.0.post0
|
23 |
+
pytz==2025.2
|
24 |
+
PyYAML==6.0.2
|
25 |
+
regex==2025.7.34
|
26 |
+
requests==2.32.4
|
27 |
+
safetensors==0.5.3
|
28 |
+
six==1.17.0
|
29 |
+
tokenizers==0.21.4
|
30 |
+
tqdm==4.67.1
|
31 |
+
transformers==4.54.1
|
32 |
+
typing_extensions==4.14.1
|
33 |
+
tzdata==2025.2
|
34 |
+
urllib3==2.5.0
|
35 |
+
xxhash==3.5.0
|
36 |
+
yarl==1.20.1
|
models/transformer/telegram_data_extractor.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Dict, OrderedDict, Tuple, Union
|
3 |
+
from .custom_types import MessageId, Message, Conversation, MessageText
|
4 |
+
from typing import List
|
5 |
+
import re
|
6 |
+
import json
|
7 |
+
|
8 |
+
|
9 |
+
class TelegramDataExtractor:
|
10 |
+
@classmethod
|
11 |
+
def load_messages_from_json(cls, path: str, max_message_count: int = None) -> OrderedDict[MessageId, Message]:
|
12 |
+
messages: OrderedDict[MessageId, Message] = OrderedDict()
|
13 |
+
with open(path, "r", encoding="utf-8") as file:
|
14 |
+
chat_json = json.load(file)
|
15 |
+
for i, message in enumerate(chat_json["messages"]):
|
16 |
+
if max_message_count and i == max_message_count:
|
17 |
+
break
|
18 |
+
if message["type"] != "message":
|
19 |
+
continue
|
20 |
+
new_message = {
|
21 |
+
"from": cls.normalize_username(message["from"]),
|
22 |
+
"id": message["id"],
|
23 |
+
"text": cls.normalize(message["text"])
|
24 |
+
}
|
25 |
+
if not new_message["text"]: # Check for empty message
|
26 |
+
continue
|
27 |
+
if "reply_to_message_id" in message.keys():
|
28 |
+
new_message["reply_to_id"] = message["reply_to_message_id"]
|
29 |
+
|
30 |
+
messages[new_message["id"]] = new_message
|
31 |
+
return messages
|
32 |
+
|
33 |
+
@staticmethod
|
34 |
+
def conversations_from_messages(save_to: Path, tokenizer, messages: OrderedDict[MessageId, Message]) -> List[Conversation]:
|
35 |
+
_MAX_MESSAGE_LEN = 150
|
36 |
+
_MAX_QA_LEN_DIFF = 20
|
37 |
+
def remove_duplicates_keep_order(lst: List[Conversation]) -> List[Conversation]:
|
38 |
+
lst = list(dict.fromkeys(lst)) # Remove duplicates and keep order
|
39 |
+
return [(list(x[0]), list(x[1]), x[2]) for x in lst] # Tuples are only needed for hashability
|
40 |
+
def remove_answers_with_only_special_symbols(lst: List[Conversation]) -> List[Conversation]:
|
41 |
+
return [i for i in lst if re.findall(r"[а-я]", " ".join(i[1]))]
|
42 |
+
def remove_long_qa(lst: List[Conversation]) -> List[Conversation]:
|
43 |
+
return [i for i in lst if len(i[0]) <= _MAX_MESSAGE_LEN and len(i[1]) <= _MAX_MESSAGE_LEN]
|
44 |
+
def remove_unbalanced_qa(lst: List[Conversation]) -> List[Conversation]:
|
45 |
+
return [i for i in lst if abs(len(i[0]) - len(i[1])) <= _MAX_QA_LEN_DIFF]
|
46 |
+
def normalize_conversations(lst: List[Conversation]) -> List[Conversation]:
|
47 |
+
lst = remove_duplicates_keep_order(lst)
|
48 |
+
lst = remove_answers_with_only_special_symbols(lst)
|
49 |
+
lst = remove_long_qa(lst)
|
50 |
+
lst = remove_unbalanced_qa(lst)
|
51 |
+
return lst
|
52 |
+
|
53 |
+
conversations: List[Conversation] = []
|
54 |
+
questions: Dict[MessageText, int] = dict()
|
55 |
+
|
56 |
+
messages_values = list(messages.values()) # TODO: try changing this cast to something more applicable
|
57 |
+
for i in range(len(messages) - 1): # There's no answer for last message so add -1
|
58 |
+
try: # Message is answer for message with `id` of `reply_to_id`
|
59 |
+
prev_message = messages[messages_values[i+1]["reply_to_id"]]
|
60 |
+
except KeyError:
|
61 |
+
prev_message = messages_values[i]
|
62 |
+
qa = (prev_message["text"], messages_values[i+1]["text"], prev_message["from"])
|
63 |
+
if qa[0] in questions.keys(): # If there are multiple answers for same message, choose the longest one
|
64 |
+
if len(conversations[questions[qa[0]]][1]) < len(qa[1]) and abs(len(conversations[questions[qa[0]]][1]) - len(qa[1])) <= _MAX_QA_LEN_DIFF:
|
65 |
+
conversations[questions[qa[0]]] = (qa[0], qa[1], qa[2])
|
66 |
+
continue
|
67 |
+
else:
|
68 |
+
questions[qa[0]] = len(conversations)
|
69 |
+
conversations.append(qa)
|
70 |
+
|
71 |
+
conversations = normalize_conversations(conversations)
|
72 |
+
output_path = save_to / "train_dataset.txt"
|
73 |
+
with open(output_path, "w", encoding="utf-8") as file:
|
74 |
+
for conversation in conversations:
|
75 |
+
line = "<user> " + conversation[2] + " <says> " + " ".join(conversation[0]) + f" {tokenizer.eos_token} <response> " + " ".join(conversation[1]) + f" {tokenizer.eos_token}" + "\n"
|
76 |
+
file.write(line)
|
77 |
+
return output_path
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def normalize(text: Union[str, List]) -> Tuple[str]:
|
81 |
+
if isinstance(text, List):
|
82 |
+
text = " ".join([word for word in text if isinstance(word, str)])
|
83 |
+
text = text.lower().strip()
|
84 |
+
text = re.sub(r"[^а-яё.!?:\d]+", r" ", text) # Leave only russian and special characters
|
85 |
+
text = re.sub(r'\.(\s*\.)+', '... ', text) # Replace any sequence of 2+ dots with '...'
|
86 |
+
text = re.sub(r'([?!])(\s*\1)+', r'\1 ', text) # Collapse repeating ? or !
|
87 |
+
text = re.sub(r"([!?]|\.+)", r"\1 ", text) # Separate special symbols by whitespaces
|
88 |
+
text = re.sub(r"ё", r"е", text)
|
89 |
+
text = re.sub(r"(.*[ауспэиычвекьхъз]{6,}.*|\b[апх][аеписх]{2,3}\b|\b[ах]{2,}\b)", r" <laugh> ", text) # Laugh token for strings such as `ахах` etc.
|
90 |
+
text = re.sub(r"(<laugh>)(\s*\1)+", r" <laugh> ", text) # Collapse repeating <laugh> tokens
|
91 |
+
text = re.sub(r"\s+", r" ", text).strip() # Leave only one space between each word
|
92 |
+
return tuple(text.split())
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def normalize_username(text: str) -> Tuple[str]:
|
96 |
+
text = text.lower()
|
97 |
+
text = re.sub(r"[^а-яa-z\s]+", "", text).strip()
|
98 |
+
return text
|
models/transformer/text_generator.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GPT2LMHeadModel
|
2 |
+
from pathlib import Path
|
3 |
+
from .utils import modified_tokenizer
|
4 |
+
from .constants import CHECKPOINT_PATH
|
5 |
+
|
6 |
+
|
7 |
+
class TextGenerator:
|
8 |
+
def __init__(self, model_name='fine_tuned_model', data_path=CHECKPOINT_PATH):
|
9 |
+
"""
|
10 |
+
Инициализация модели и токенизатора.
|
11 |
+
Загружаем модель и токенизатор из указанного пути.
|
12 |
+
"""
|
13 |
+
model_path = Path(data_path) / model_name
|
14 |
+
self.tokenizer = modified_tokenizer(model_path, None, data_path)
|
15 |
+
self.model = GPT2LMHeadModel.from_pretrained(str(model_path), device_map="auto")
|
16 |
+
self.model.eval()
|
17 |
+
|
18 |
+
def generate_text(self,
|
19 |
+
author: str,
|
20 |
+
input_str: str,
|
21 |
+
max_length=100,
|
22 |
+
num_return_sequences=1,
|
23 |
+
temperature=1.0,
|
24 |
+
top_k=0,
|
25 |
+
top_p=1.0,
|
26 |
+
do_sample=False):
|
27 |
+
"""
|
28 |
+
Генерация текста на основе заданного начального текста (prompt) и параметров.
|
29 |
+
|
30 |
+
Параметры:
|
31 |
+
- input: Входная последовательность.
|
32 |
+
- max_length: Максимальная длина сгенерированного текста.
|
33 |
+
- num_return_sequences: Количество возвращаемых последовательностей.
|
34 |
+
- temperature: Контролирует разнообразие вывода.
|
35 |
+
- top_k: Если больше 0, ограничивает количество слов для выборки только k наиболее вероятными словами.
|
36 |
+
- top_p: Если меньше 1.0, применяется nucleus sampling.
|
37 |
+
- do_sample: Если True, включает случайную выборку для увеличения разнообразия.
|
38 |
+
"""
|
39 |
+
# Формирование prompt
|
40 |
+
prompt_text = f"<user> {author} <says> {input_str} {self.tokenizer.eos_token} <response>"
|
41 |
+
print(prompt_text)
|
42 |
+
|
43 |
+
# Кодирование текста в формате, пригодном для модели
|
44 |
+
encoded_input = self.tokenizer.encode(prompt_text, return_tensors='pt')
|
45 |
+
|
46 |
+
# Генерация текстов
|
47 |
+
outputs = self.model.generate(
|
48 |
+
encoded_input,
|
49 |
+
max_length=max_length + len(encoded_input[0]),
|
50 |
+
num_return_sequences=num_return_sequences,
|
51 |
+
temperature=temperature,
|
52 |
+
top_k=top_k,
|
53 |
+
top_p=top_p,
|
54 |
+
do_sample=do_sample,
|
55 |
+
no_repeat_ngram_size=2
|
56 |
+
)
|
57 |
+
|
58 |
+
# Декодирование результатов
|
59 |
+
all_texts = [self.tokenizer.decode(output, skip_special_tokens=False) for output in outputs]
|
60 |
+
|
61 |
+
# Удаление входных данных из текстов
|
62 |
+
prompt_length = len(self.tokenizer.decode(encoded_input[0], skip_special_tokens=False))
|
63 |
+
trimmed_texts = [text[prompt_length:] for text in all_texts]
|
64 |
+
|
65 |
+
# Возврат результатов в виде словаря
|
66 |
+
return {
|
67 |
+
"full_texts": all_texts,
|
68 |
+
"generated_texts": trimmed_texts
|
69 |
+
}
|
70 |
+
|
71 |
+
if __name__ == "__main__":
|
72 |
+
print("OK")
|
models/transformer/utils.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GPT2Tokenizer
|
2 |
+
from pathlib import Path
|
3 |
+
from .constants import CHECKPOINT_PATH
|
4 |
+
|
5 |
+
|
6 |
+
def modified_tokenizer(model_name="ai-forever/rugpt3small_based_on_gpt2", cache_dir="model_cache", data_path=Path(CHECKPOINT_PATH)):
|
7 |
+
if cache_dir:
|
8 |
+
tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=str(data_path / cache_dir))
|
9 |
+
else:
|
10 |
+
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
11 |
+
special_tokens_dict = {
|
12 |
+
"additional_special_tokens": [
|
13 |
+
"<user>",
|
14 |
+
"<says>",
|
15 |
+
"<response>"
|
16 |
+
]
|
17 |
+
}
|
18 |
+
tokenizer.add_special_tokens(special_tokens_dict)
|
19 |
+
tokenizer.add_tokens(["<laugh>"])
|
20 |
+
return tokenizer
|
requirements.txt
CHANGED
@@ -1,22 +1,22 @@
|
|
1 |
-
anyio==4.
|
2 |
-
certifi==2025.
|
3 |
-
colorama==0.4.6
|
4 |
-
filelock==3.18.0
|
5 |
-
fsspec==2025.7.0
|
6 |
h11==0.16.0
|
7 |
httpcore==1.0.9
|
8 |
httpx==0.28.1
|
9 |
idna==3.10
|
10 |
-
Jinja2==3.1.6
|
11 |
-
MarkupSafe==3.0.2
|
12 |
-
mpmath==1.3.0
|
13 |
-
networkx==3.5
|
14 |
-
numpy==2.3.1
|
15 |
python-dotenv==1.1.1
|
16 |
python-telegram-bot==22.3
|
17 |
-
setuptools==80.9.0
|
18 |
sniffio==1.3.1
|
19 |
-
sympy==1.14.0
|
20 |
-
torch==2.7.1
|
21 |
-
tqdm==4.67.1
|
22 |
typing_extensions==4.14.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
anyio==4.10.0
|
2 |
+
certifi==2025.8.3
|
|
|
|
|
|
|
3 |
h11==0.16.0
|
4 |
httpcore==1.0.9
|
5 |
httpx==0.28.1
|
6 |
idna==3.10
|
|
|
|
|
|
|
|
|
|
|
7 |
python-dotenv==1.1.1
|
8 |
python-telegram-bot==22.3
|
|
|
9 |
sniffio==1.3.1
|
|
|
|
|
|
|
10 |
typing_extensions==4.14.1
|
11 |
+
typing-inspection==0.4.1
|
12 |
+
annotated-types==0.7.0
|
13 |
+
fastapi==0.116.1
|
14 |
+
pydantic==2.11.7
|
15 |
+
pydantic_core==2.33.2
|
16 |
+
starlette==0.47.2
|
17 |
+
charset-normalizer==3.4.2
|
18 |
+
requests==2.32.4
|
19 |
+
urllib3==2.5.0
|
20 |
+
click==8.2.1
|
21 |
+
colorama==0.4.6
|
22 |
+
uvicorn==0.35.0
|
server.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from models.transformer.text_generator import TextGenerator
|
4 |
+
|
5 |
+
|
6 |
+
app = FastAPI()
|
7 |
+
|
8 |
+
generator = TextGenerator(
|
9 |
+
model_name='fine_tuned_model_gpt_2',
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
class Message(BaseModel):
|
14 |
+
author: str
|
15 |
+
content: str
|
16 |
+
|
17 |
+
|
18 |
+
@app.post("/generate")
|
19 |
+
def generate_response(message: Message):
|
20 |
+
response = generator.generate_text(
|
21 |
+
author=message.author,
|
22 |
+
input_str=message.content,
|
23 |
+
max_length=100,
|
24 |
+
num_return_sequences=1,
|
25 |
+
do_sample=True,
|
26 |
+
temperature=0.8, # Слегка уменьшаем уверенность
|
27 |
+
top_k=100, # Уменьшаем количество рассматриваемых верхних k слов
|
28 |
+
top_p=0.95 # Уменьшаем "ядерность" распределения
|
29 |
+
)["generated_texts"][0]
|
30 |
+
response = response[:response.find("</s>")]
|
31 |
+
return { "response": response }
|