from pathlib import Path

from transformers import GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset

from .utils import modified_tokenizer
from .telegram_data_extractor import TelegramDataExtractor
from .constants import CHECKPOINT_PATH


class FineTuner:
    def __init__(self,
                 model_name="ai-forever/rugpt3small_based_on_gpt2",
                 cache_dir="model_cache",
                 data_path=CHECKPOINT_PATH):
        self.data_path = Path(data_path)
        # Initialize the tokenizer and the model
        self.tokenizer = modified_tokenizer(model_name, cache_dir, self.data_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_name, cache_dir=str(self.data_path / cache_dir))
    def prepare_data(self):
        """
        Prepare the training data.
        """
        messages = TelegramDataExtractor.load_messages_from_json("/kaggle/input/chat-history/chat_history_small.json")
        dataset_path = TelegramDataExtractor.conversations_from_messages(self.data_path, self.tokenizer, messages)
        return dataset_path
    def fine_tune(self,
                  dataset_path,
                  output_name='fine_tuned_model',
                  num_train_epochs=10,
                  per_device_train_batch_size=8,
                  learning_rate=5e-5,
                  save_steps=10_000):
        """
        Fine-tune the model on the given dataset.
        """
        # Load the plain-text dataset produced by prepare_data
        dataset = load_dataset("text", data_files={"train": str(dataset_path)})

        def preprocess(example):
            # Tokenize while preserving structure
            return self.tokenizer(example["text"], truncation=True, max_length=300)

        train_dataset = dataset.map(preprocess, batched=True)["train"]
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer, mlm=False
        )
        training_args = TrainingArguments(
            output_dir=str(self.data_path / output_name),
            overwrite_output_dir=True,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            # fp16=True,
            # gradient_accumulation_steps=2,
            save_steps=save_steps,
            learning_rate=learning_rate,
            torch_compile=True,
            save_total_limit=2,
            logging_dir=str(self.data_path / 'logs'),
            report_to="none"
        )
        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )
        trainer.train()

        # Save the fine-tuned model and tokenizer
        self.model.save_pretrained(str(self.data_path / output_name))
        self.tokenizer.save_pretrained(str(self.data_path / output_name))
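

# Minimal usage sketch (assumptions: this module sits inside a package so the relative
# imports resolve, and the chat-history JSON path hard-coded in prepare_data exists).
# Run with `python -m <package>.<this_module>`; the epoch/batch values below are
# illustrative, not tuned.
if __name__ == "__main__":
    tuner = FineTuner()
    dataset_path = tuner.prepare_data()  # builds the plain-text training file
    tuner.fine_tune(dataset_path, num_train_epochs=3, per_device_train_batch_size=4)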