# mistral7b / train_lora_mistral.py
import sys, os, zipfile, time, traceback, threading, uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from datetime import datetime
from datasets import load_dataset
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
import torch
# === Constants ===
START_NUMBER = 0
END_NUMBER = 9
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
TOKENIZED_DATASET_ID = "UcsTurkey/turkish-general-culture-tokenized"
ZIP_UPLOAD_REPO = "UcsTurkey/trained-zips"
HF_TOKEN = os.environ.get("HF_TOKEN")
BATCH_SIZE = 1
EPOCHS = 2
MAX_LENGTH = 2048  # defined but not referenced below; the chunks arrive pre-tokenized
OUTPUT_DIR = "/data/output"
ZIP_FOLDER = "/data/zip_temp"
zip_name = f"trained_model_{START_NUMBER:03d}_{END_NUMBER:03d}.zip"
ZIP_PATH = os.path.join(ZIP_FOLDER, zip_name)
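# The two numbers above slice the sorted list of chunk_*.parquet files in the
# dataset repo, so 0..9 trains on the first ten chunks, and the artifact is
# named to match (here: trained_model_000_009.zip).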
# === Health check ===
app = FastAPI()

@app.get("/")
def health():
    return JSONResponse(content={"status": "ok"})

def run_health_server():
    # 7860 is the port Hugging Face Spaces routes traffic to by default
    uvicorn.run(app, host="0.0.0.0", port=7860)

threading.Thread(target=run_health_server, daemon=True).start()
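# Liveness can be checked from inside the container, e.g.:
#   curl http://localhost:7860/   ->  {"status": "ok"}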
# === Logging ===
def log(message):
    timestamp = datetime.now().strftime("%H:%M:%S")
    print(f"[{timestamp}] {message}")
    sys.stdout.flush()
# === Training ===
log("🛠️ Preparing environment...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # Mistral's tokenizer ships without a pad token; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token

log("🧠 Downloading model...")
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
base_model.config.pad_token_id = tokenizer.pad_token_id

log("🎯 Applying LoRA adapter...")
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=64, lora_alpha=16, lora_dropout=0.1,
    bias="none", fan_in_fan_out=False
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
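# Note: target_modules is not set, so PEFT falls back to its built-in default
# for Mistral (q_proj and v_proj at the time of writing). A sketch of a wider
# configuration, if more of the attention block should be adapted:
# peft_config = LoraConfig(
#     task_type=TaskType.CAUSAL_LM, r=64, lora_alpha=16, lora_dropout=0.1,
#     bias="none", target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
# )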
log("📦 Parquet dosyaları listeleniyor...")
api = HfApi()
files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
selected_files = sorted(
    f for f in files if f.startswith("chunk_") and f.endswith(".parquet")
)[START_NUMBER:END_NUMBER + 1]
if not selected_files:
    log("⚠️ No Parquet files found. Training cancelled.")
    sys.exit(0)
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=2e-4,
    disable_tqdm=True,
    logging_strategy="steps",
    logging_steps=10,
    report_to=[],
    bf16=True,
    fp16=False
)
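# bf16 training assumes hardware with bfloat16 support (e.g. NVIDIA Ampere or
# newer); on older GPUs one would set fp16=True, bf16=False instead. With a
# per-device batch of 1, gradient accumulation is the usual way to raise the
# effective batch size, e.g. adding gradient_accumulation_steps=8 to the
# arguments above (not enabled here).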
for file in selected_files:
    try:
        log(f"\n📄 Loading: {file}")
        dataset = load_dataset(
            path=TOKENIZED_DATASET_ID,
            data_files={"train": file},
            split="train",
            token=HF_TOKEN
        )
        log(f"🔍 {len(dataset)} examples")
        if len(dataset) == 0:
            continue
        # No tokenizer/data_collator is passed, so Trainer falls back to the
        # default collator; the chunks are expected to be pre-tokenized.
        trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
        log("🚀 Training starting...")
        trainer.train()
        log("✅ Training done.")
    except Exception as e:
        log(f"❌ Error in {file}: {e}")
        traceback.print_exc()
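# Alternative sketch (not what this script does): load all selected chunks as
# one dataset and run a single training pass instead of one pass per file:
# dataset = load_dataset(
#     TOKENIZED_DATASET_ID,
#     data_files={"train": selected_files},
#     split="train",
#     token=HF_TOKEN
# )
# Trainer(model=model, args=training_args, train_dataset=dataset).train()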
# === Zip ===
log("📦 Zipping model...")
try:
    tmp_dir = os.path.join(ZIP_FOLDER, "temp_save")
    os.makedirs(tmp_dir, exist_ok=True)
    model.save_pretrained(tmp_dir)  # on a PeftModel this saves only the adapter weights
    tokenizer.save_pretrained(tmp_dir)
    with zipfile.ZipFile(ZIP_PATH, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(tmp_dir):
            for file in files:
                filepath = os.path.join(root, file)
                arcname = os.path.relpath(filepath, tmp_dir)
                zipf.write(filepath, arcname=os.path.join("output", arcname))
    log(f"✅ Zip created: {ZIP_PATH}")
except Exception as e:
    log(f"❌ Zip error: {e}")
    traceback.print_exc()
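# Alternative sketch (not used here): skip the zip step and push the saved
# adapter folder to the Hub directly:
# api.upload_folder(
#     folder_path=tmp_dir,
#     repo_id=ZIP_UPLOAD_REPO,
#     repo_type="model",
#     token=HF_TOKEN
# )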
# === Upload ===
try:
    log("☁️ Uploading to Hugging Face...")
    api.upload_file(
        path_or_fileobj=ZIP_PATH,
        path_in_repo=zip_name,
        repo_id=ZIP_UPLOAD_REPO,
        repo_type="model",
        token=HF_TOKEN
    )
    log("✅ Upload done.")
except Exception as e:
    log(f"❌ Upload error: {e}")
    traceback.print_exc()
log("⏸️ Eğitim tamamlandı. Servis bekleme modunda...")
while True:
time.sleep(60)