mistral7b / train_lora_turkcell.py
import sys, os, zipfile, shutil, time, traceback, threading, uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from datetime import datetime
from datasets import load_dataset
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType
import torch
# === Constants ===
START_NUMBER = 0
END_NUMBER = 9
MODEL_NAME = "TURKCELL/Turkcell-LLM-7b-v1"
TOKENIZED_DATASET_ID = "UcsTurkey/turkish-train-tokenized"
ZIP_UPLOAD_REPO = "UcsTurkey/trained-zips"
HF_TOKEN = os.environ.get("HF_TOKEN")
BATCH_SIZE = 1
EPOCHS = 2
MAX_LENGTH = 2048
OUTPUT_DIR = "/data/output"
ZIP_FOLDER = "/data/zip_temp"
zip_name = f"trained_model_{START_NUMBER:03d}_{END_NUMBER:03d}.zip"
ZIP_PATH = os.path.join(ZIP_FOLDER, zip_name)
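# Overview: the script downloads the pre-tokenized parquet chunks numbered
# START_NUMBER..END_NUMBER, fine-tunes the Turkcell 7B base model with a LoRA
# adapter on each chunk in turn, zips the resulting adapter plus tokenizer files,
# and uploads the archive to ZIP_UPLOAD_REPO.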
# === Health check
app = FastAPI()
@app.get("/")
def health():
    return JSONResponse(content={"status": "ok"})

def run_health_server():
    uvicorn.run(app, host="0.0.0.0", port=7860)

threading.Thread(target=run_health_server, daemon=True).start()
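# Note: Hugging Face Spaces expects an HTTP server on port 7860 for its health check;
# running uvicorn in a daemon thread keeps the Space responsive while the training
# work occupies the main thread.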
# === Log
def log(message):
    timestamp = datetime.now().strftime("%H:%M:%S")
    print(f"[{timestamp}] {message}")
    sys.stdout.flush()
# === Training starts
log("🛠️ Preparing environment...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    # The tokenizer ships without a pad token, so reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
log("🧠 Downloading model...")
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
base_model.config.pad_token_id = tokenizer.pad_token_id
log("🎯 Applying LoRA adapter...")
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    fan_in_fan_out=False
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
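# Note: target_modules is left unset, so peft falls back to its built-in default
# mapping for the base architecture; for Mistral-style models this typically means
# the attention projections (e.g. q_proj and v_proj). Set target_modules explicitly
# if more layers should be adapted.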
log("📦 Parquet dosyaları listeleniyor...")
api = HfApi()
files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
selected_files = sorted([f for f in files if f.startswith("chunk_") and f.endswith(".parquet")])[START_NUMBER:END_NUMBER+1]
if not selected_files:
log("⚠️ Parquet bulunamadı. Eğitim iptal.")
exit(0)
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=2e-4,
    disable_tqdm=True,
    logging_strategy="steps",
    logging_steps=10,
    report_to=[],
    bf16=True,
    fp16=False
)
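# Note: bf16=True assumes the training hardware supports bfloat16 (e.g. an
# Ampere-class or newer GPU); on older hardware switch to fp16=True or plain fp32.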
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
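# With mlm=False the collator prepares batches for causal language modelling: it pads
# the tokenized examples and copies input_ids into labels, masking padded positions
# out of the loss.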
for file in selected_files:
    try:
        log(f"\n📄 Loading: {file}")
        dataset = load_dataset(
            path=TOKENIZED_DATASET_ID,
            data_files={"train": file},
            split="train",
            token=HF_TOKEN
        )
        log(f"🔍 {len(dataset)} examples")
        if len(dataset) == 0:
            continue
        # Prompt note: the dataset is already tokenized, so input_ids are present;
        # decode and log only the first example as a sanity check.
        first_row = dataset[0]
        decoded_prompt = tokenizer.decode(first_row["input_ids"], skip_special_tokens=True)
        log(f"📌 Sample prompt: {decoded_prompt[:200]}...")
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=collator
        )
        log("🚀 Training starting...")
        trainer.train()
        log("✅ Training finished.")
    except Exception as e:
        log(f"❌ Error in {file}: {e}")
        traceback.print_exc()
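# Note: a fresh Trainer is created for every chunk, so the LoRA weights keep
# accumulating updates across chunks while the optimizer state and learning-rate
# schedule restart from scratch for each file.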
# === Zip
log("📦 Zipping model...")
try:
    tmp_dir = os.path.join(ZIP_FOLDER, "temp_save")
    os.makedirs(tmp_dir, exist_ok=True)
    model.save_pretrained(tmp_dir)
    tokenizer.save_pretrained(tmp_dir)
    with zipfile.ZipFile(ZIP_PATH, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(tmp_dir):
            for file in files:
                filepath = os.path.join(root, file)
                arcname = os.path.relpath(filepath, tmp_dir)
                zipf.write(filepath, arcname=os.path.join("output", arcname))
    log(f"✅ Zip created: {ZIP_PATH}")
except Exception as e:
    log(f"❌ Zip error: {e}")
    traceback.print_exc()
# === Upload
try:
    log("☁️ Uploading to Hugging Face...")
    api.upload_file(
        path_or_fileobj=ZIP_PATH,
        path_in_repo=zip_name,
        repo_id=ZIP_UPLOAD_REPO,
        repo_type="model",
        token=HF_TOKEN
    )
    log("✅ Upload complete.")
except Exception as e:
    log(f"❌ Upload error: {e}")
    traceback.print_exc()
log("⏸️ Eğitim tamamlandı. Servis bekleme modunda...")
while True:
time.sleep(60)
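# A minimal sketch of how a consumer might fetch and apply the uploaded adapter
# (hypothetical usage, not part of this service; note that the zip nests the adapter
# under an "output/" folder because of the arcname prefix used above):
#
#   from huggingface_hub import hf_hub_download
#   from peft import PeftModel
#
#   zip_path = hf_hub_download(repo_id=ZIP_UPLOAD_REPO, filename=zip_name, repo_type="model")
#   shutil.unpack_archive(zip_path, "/data/adapter", format="zip")
#   base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
#   tuned = PeftModel.from_pretrained(base, "/data/adapter/output")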