from pathlib import Path

from datasets import load_dataset
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling

from .constants import CHECKPOINT_PATH
from .telegram_data_extractor import TelegramDataExtractor
from .utils import modified_tokenizer


class FineTuner:
    def __init__(self,
                 model_name="ai-forever/rugpt3small_based_on_gpt2",
                 cache_dir="model_cache",
                 data_path=CHECKPOINT_PATH):
        self.data_path = Path(data_path)
        # Initialize the tokenizer and the model
        self.tokenizer = modified_tokenizer(model_name, cache_dir, self.data_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_name, cache_dir=str(self.data_path / cache_dir))

    def prepare_data(self):
        """
        Prepare the training data.
        """
        messages = TelegramDataExtractor.load_messages_from_json("/kaggle/input/chat-history/chat_history_small.json")
        dataset_path = TelegramDataExtractor.conversations_from_messages(self.data_path, self.tokenizer, messages)
        return dataset_path

    def fine_tune(self,
                  dataset_path,
                  output_name='fine_tuned_model',
                  num_train_epochs=10,
                  per_device_train_batch_size=8,
                  learning_rate=5e-5,
                  save_steps=10_000):
        """
        Fine-tune the model on the given dataset.
        """
        # Load the plain-text dataset produced by prepare_data()
        dataset = load_dataset("text", data_files={"train": str(dataset_path)})

        def preprocess(example):
            # Tokenize while preserving the conversation structure
            return self.tokenizer(example["text"], truncation=True, max_length=300)

        train_dataset = dataset.map(preprocess, batched=True)["train"]

        # Causal LM objective: mlm=False disables masked-language-modeling masking
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer, mlm=False
        )

        training_args = TrainingArguments(
            output_dir=str(self.data_path / output_name),
            overwrite_output_dir=True,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            # fp16=True,
            # gradient_accumulation_steps=2,
            save_steps=save_steps,
            learning_rate=learning_rate,
            torch_compile=True,
            save_total_limit=2,
            logging_dir=str(self.data_path / 'logs'),
            report_to="none"
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )
        trainer.train()

        # Save the fine-tuned model and tokenizer
        self.model.save_pretrained(str(self.data_path / output_name))
        self.tokenizer.save_pretrained(str(self.data_path / output_name))
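

# A minimal usage sketch, not part of the original module: it assumes the package
# layout above (CHECKPOINT_PATH, modified_tokenizer, the Kaggle chat-history path)
# and simply chains prepare_data() into fine_tune(). The epoch and batch-size
# values below are illustrative, not verified defaults. Because of the relative
# imports, run it as a module, e.g. `python -m <package_name>.fine_tuner`
# (package and file names here are assumptions).
if __name__ == "__main__":
    tuner = FineTuner()
    dataset_path = tuner.prepare_data()        # writes the conversation dataset to a text file
    tuner.fine_tune(dataset_path,
                    num_train_epochs=3,        # fewer epochs for a quick smoke test
                    per_device_train_batch_size=4)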