ciyidogan commited on
Commit
9aa5822
·
verified ·
1 Parent(s): 40252a8

Update train_lora_mistral.py

Browse files
Files changed (1) hide show
  1. train_lora_mistral.py +18 -7
train_lora_mistral.py CHANGED
@@ -4,7 +4,7 @@ from fastapi.responses import JSONResponse
4
  from datetime import datetime
5
  from datasets import load_dataset
6
  from huggingface_hub import HfApi
7
- from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
8
  from peft import get_peft_model, LoraConfig, TaskType
9
  import torch
10
 
@@ -36,7 +36,6 @@ def run_health_server():
36
  threading.Thread(target=run_health_server, daemon=True).start()
37
 
38
  # === Log
39
-
40
  def log(message):
41
  timestamp = datetime.now().strftime("%H:%M:%S")
42
  print(f"[{timestamp}] {message}")
@@ -55,8 +54,11 @@ base_model.config.pad_token_id = tokenizer.pad_token_id
55
  log("🎯 LoRA adapter uygulanıyor...")
56
  peft_config = LoraConfig(
57
  task_type=TaskType.CAUSAL_LM,
58
- r=64, lora_alpha=16, lora_dropout=0.1,
59
- bias="none", fan_in_fan_out=False
 
 
 
60
  )
61
  model = get_peft_model(base_model, peft_config)
62
  model.print_trainable_parameters()
@@ -65,6 +67,7 @@ log("📦 Parquet dosyaları listeleniyor...")
65
  api = HfApi()
66
  files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
67
  selected_files = sorted([f for f in files if f.startswith("chunk_") and f.endswith(".parquet")])[START_NUMBER:END_NUMBER+1]
 
68
  if not selected_files:
69
  log("⚠️ Parquet bulunamadı. Eğitim iptal.")
70
  exit(0)
@@ -84,6 +87,8 @@ training_args = TrainingArguments(
84
  fp16=False
85
  )
86
 
 
 
87
  for file in selected_files:
88
  try:
89
  log(f"\n📄 Yükleniyor: {file}")
@@ -97,12 +102,18 @@ for file in selected_files:
97
  if len(dataset) == 0:
98
  continue
99
 
100
- # Eğitim öncesi örnek prompt kontrolü
 
101
  first_row = dataset[0]
102
  decoded_prompt = tokenizer.decode(first_row["input_ids"], skip_special_tokens=True)
103
- log(f"📌 Örnek prompt: {decoded_prompt}")
104
 
105
- trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
 
 
 
 
 
106
  log("🚀 Eğitim başlıyor...")
107
  trainer.train()
108
  log("✅ Eğitim tamam.")
 
4
  from datetime import datetime
5
  from datasets import load_dataset
6
  from huggingface_hub import HfApi
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
8
  from peft import get_peft_model, LoraConfig, TaskType
9
  import torch
10
 
 
36
  threading.Thread(target=run_health_server, daemon=True).start()
37
 
38
  # === Log
 
39
  def log(message):
40
  timestamp = datetime.now().strftime("%H:%M:%S")
41
  print(f"[{timestamp}] {message}")
 
54
  log("🎯 LoRA adapter uygulanıyor...")
55
  peft_config = LoraConfig(
56
  task_type=TaskType.CAUSAL_LM,
57
+ r=64,
58
+ lora_alpha=16,
59
+ lora_dropout=0.1,
60
+ bias="none",
61
+ fan_in_fan_out=False
62
  )
63
  model = get_peft_model(base_model, peft_config)
64
  model.print_trainable_parameters()
 
67
  api = HfApi()
68
  files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
69
  selected_files = sorted([f for f in files if f.startswith("chunk_") and f.endswith(".parquet")])[START_NUMBER:END_NUMBER+1]
70
+
71
  if not selected_files:
72
  log("⚠️ Parquet bulunamadı. Eğitim iptal.")
73
  exit(0)
 
87
  fp16=False
88
  )
89
 
90
+ collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
91
+
92
  for file in selected_files:
93
  try:
94
  log(f"\n📄 Yükleniyor: {file}")
 
102
  if len(dataset) == 0:
103
  continue
104
 
105
+ # prompt tanımı: tokenize edilmiş dataset içinde input_ids zaten var
106
+ # sadece örnek bir tanesini loglayalım
107
  first_row = dataset[0]
108
  decoded_prompt = tokenizer.decode(first_row["input_ids"], skip_special_tokens=True)
109
+ log(f"📌 Örnek prompt: {decoded_prompt[:200]}...")
110
 
111
+ trainer = Trainer(
112
+ model=model,
113
+ args=training_args,
114
+ train_dataset=dataset,
115
+ data_collator=collator
116
+ )
117
  log("🚀 Eğitim başlıyor...")
118
  trainer.train()
119
  log("✅ Eğitim tamam.")