Update train_lora_mistral.py
Browse files- train_lora_mistral.py +18 -7
train_lora_mistral.py
CHANGED
@@ -4,7 +4,7 @@ from fastapi.responses import JSONResponse
|
|
4 |
from datetime import datetime
|
5 |
from datasets import load_dataset
|
6 |
from huggingface_hub import HfApi
|
7 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
8 |
from peft import get_peft_model, LoraConfig, TaskType
|
9 |
import torch
|
10 |
|
@@ -36,7 +36,6 @@ def run_health_server():
|
|
36 |
threading.Thread(target=run_health_server, daemon=True).start()
|
37 |
|
38 |
# === Log
|
39 |
-
|
40 |
def log(message):
|
41 |
timestamp = datetime.now().strftime("%H:%M:%S")
|
42 |
print(f"[{timestamp}] {message}")
|
@@ -55,8 +54,11 @@ base_model.config.pad_token_id = tokenizer.pad_token_id
|
|
55 |
log("🎯 LoRA adapter uygulanıyor...")
|
56 |
peft_config = LoraConfig(
|
57 |
task_type=TaskType.CAUSAL_LM,
|
58 |
-
r=64,
|
59 |
-
|
|
|
|
|
|
|
60 |
)
|
61 |
model = get_peft_model(base_model, peft_config)
|
62 |
model.print_trainable_parameters()
|
@@ -65,6 +67,7 @@ log("📦 Parquet dosyaları listeleniyor...")
|
|
65 |
api = HfApi()
|
66 |
files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
|
67 |
selected_files = sorted([f for f in files if f.startswith("chunk_") and f.endswith(".parquet")])[START_NUMBER:END_NUMBER+1]
|
|
|
68 |
if not selected_files:
|
69 |
log("⚠️ Parquet bulunamadı. Eğitim iptal.")
|
70 |
exit(0)
|
@@ -84,6 +87,8 @@ training_args = TrainingArguments(
|
|
84 |
fp16=False
|
85 |
)
|
86 |
|
|
|
|
|
87 |
for file in selected_files:
|
88 |
try:
|
89 |
log(f"\n📄 Yükleniyor: {file}")
|
@@ -97,12 +102,18 @@ for file in selected_files:
|
|
97 |
if len(dataset) == 0:
|
98 |
continue
|
99 |
|
100 |
-
#
|
|
|
101 |
first_row = dataset[0]
|
102 |
decoded_prompt = tokenizer.decode(first_row["input_ids"], skip_special_tokens=True)
|
103 |
-
log(f"📌 Örnek prompt: {decoded_prompt}")
|
104 |
|
105 |
-
trainer = Trainer(
|
|
|
|
|
|
|
|
|
|
|
106 |
log("🚀 Eğitim başlıyor...")
|
107 |
trainer.train()
|
108 |
log("✅ Eğitim tamam.")
|
|
|
4 |
from datetime import datetime
|
5 |
from datasets import load_dataset
|
6 |
from huggingface_hub import HfApi
|
7 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
|
8 |
from peft import get_peft_model, LoraConfig, TaskType
|
9 |
import torch
|
10 |
|
|
|
36 |
threading.Thread(target=run_health_server, daemon=True).start()
|
37 |
|
38 |
# === Log
|
|
|
39 |
def log(message):
    """Print *message* to stdout, prefixed with the current HH:MM:SS time."""
    now = datetime.now().strftime("%H:%M:%S")
    print(f"[{now}] {message}")
|
|
|
54 |
log("🎯 LoRA adapter uygulanıyor...")
|
55 |
peft_config = LoraConfig(
|
56 |
task_type=TaskType.CAUSAL_LM,
|
57 |
+
r=64,
|
58 |
+
lora_alpha=16,
|
59 |
+
lora_dropout=0.1,
|
60 |
+
bias="none",
|
61 |
+
fan_in_fan_out=False
|
62 |
)
|
63 |
model = get_peft_model(base_model, peft_config)
|
64 |
model.print_trainable_parameters()
|
|
|
67 |
# List the tokenized parquet chunks in the dataset repo and keep only the
# requested slice [START_NUMBER, END_NUMBER] of the sorted chunk files.
api = HfApi()
files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
chunk_files = [f for f in files if f.startswith("chunk_") and f.endswith(".parquet")]
selected_files = sorted(chunk_files)[START_NUMBER:END_NUMBER + 1]

if not selected_files:
    # Nothing to train on — bail out cleanly rather than crash later.
    log("⚠️ Parquet bulunamadı. Eğitim iptal.")
    exit(0)
|
|
|
87 |
fp16=False
|
88 |
)
|
89 |
|
90 |
+
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
91 |
+
|
92 |
for file in selected_files:
|
93 |
try:
|
94 |
log(f"\n📄 Yükleniyor: {file}")
|
|
|
102 |
if len(dataset) == 0:
|
103 |
continue
|
104 |
|
105 |
+
# prompt tanımı: tokenize edilmiş dataset içinde input_ids zaten var
|
106 |
+
# sadece örnek bir tanesini loglayalım
|
107 |
first_row = dataset[0]
|
108 |
decoded_prompt = tokenizer.decode(first_row["input_ids"], skip_special_tokens=True)
|
109 |
+
log(f"📌 Örnek prompt: {decoded_prompt[:200]}...")
|
110 |
|
111 |
+
trainer = Trainer(
|
112 |
+
model=model,
|
113 |
+
args=training_args,
|
114 |
+
train_dataset=dataset,
|
115 |
+
data_collator=collator
|
116 |
+
)
|
117 |
log("🚀 Eğitim başlıyor...")
|
118 |
trainer.train()
|
119 |
log("✅ Eğitim tamam.")
|