Caden Shokat
committed on
Commit · 5dfc78f
Parent(s): fa6c34a
add training + utils
- src/training/train.py +100 -0
- src/utils/config.py +9 -0
- src/utils/encoding.py +21 -0
- src/utils/paths.py +9 -0
- src/utils/seed.py +12 -0
src/training/train.py
ADDED
@@ -0,0 +1,100 @@
import os

import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerModelCardData
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

from src.utils.config import CFG
from src.utils.paths import TRAIN_JSON, TEST_JSON
from src.eval.ir_eval import build_eval


def _precision_and_optim():
    """Pick safe precision/optimizer settings for the current device."""
    use_cuda = torch.cuda.is_available()

    cfg = dict(fp16=False, bf16=False, tf32=False, optim="adamw_torch")

    if use_cuda:
        # TF32 and fused AdamW are only available on NVIDIA GPUs.
        cfg["tf32"] = True
        try:
            cfg["bf16"] = torch.cuda.is_bf16_supported()
        except Exception:
            cfg["bf16"] = False
        major, _ = torch.cuda.get_device_capability()
        cfg["optim"] = "adamw_torch_fused" if major >= 8 else "adamw_torch"

    # MPS/CPU: stick to fp32; fused AdamW, TF32, and bf16 are unsupported in the HF trainer there.
    return cfg, use_cuda


def main():
    if torch.cuda.is_available():
        device = "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Base model with SDPA attention.
    model = SentenceTransformer(
        CFG.model_id,
        device=device,
        model_kwargs={"attn_implementation": "sdpa"},
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name="Embed AWS Docs",
        ),
    )

    train_dataset = load_dataset("json", data_files=str(TRAIN_JSON), split="train")
    test_dataset = load_dataset("json", data_files=str(TEST_JSON), split="train")

    evaluator = build_eval(CFG.matryoshka_dims)

    base_loss = MultipleNegativesRankingLoss(model)
    train_loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=list(CFG.matryoshka_dims))

    prec_optim, on_cuda = _precision_and_optim()
    # Smaller batches on CPU/MPS.
    train_bs = 32 if on_cuda else 8
    eval_bs = 16 if on_cuda else 8
    grad_acc = 16 if on_cuda else 4  # keeps the effective batch size reasonable

    args = SentenceTransformerTrainingArguments(
        output_dir=CFG.output_dir,
        num_train_epochs=4,
        per_device_train_batch_size=train_bs,
        gradient_accumulation_steps=grad_acc,
        per_device_eval_batch_size=eval_bs,
        warmup_ratio=0.1,
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        optim=prec_optim["optim"],
        fp16=prec_optim["fp16"],
        bf16=prec_optim["bf16"],
        tf32=prec_optim["tf32"],
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=10,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_dim_128_cosine_ndcg@10",
        report_to="none",
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset.select_columns(["positive", "anchor"]),
        loss=train_loss,
        evaluator=evaluator,
    )

    trainer.train()
    trainer.save_model()

    if os.getenv("HUGGINGFACE_HUB_TOKEN"):
        trainer.model.push_to_hub(CFG.output_dir)


if __name__ == "__main__":
    main()
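
Note that the script imports build_eval from src.eval.ir_eval, which is not part of this commit. Below is a plausible sketch of that helper, assuming the test split's "anchor"/"positive" columns become the query/corpus pair and that the eval_dim_{d}_cosine_ndcg@10 metric names come from one InformationRetrievalEvaluator per Matryoshka dimension wrapped in a SequentialEvaluator. The id scheme and column handling here are assumptions, not the author's code.

# Hypothetical sketch of src/eval/ir_eval.build_eval (not included in this commit).
# Assumes the test split has "anchor" (query) and "positive" (document) columns.
from datasets import load_dataset
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.util import cos_sim

from src.utils.paths import TEST_JSON


def build_eval(dims):
    ds = load_dataset("json", data_files=str(TEST_JSON), split="train")

    # One synthetic id per row; each query is relevant to exactly one document.
    corpus = {str(i): doc for i, doc in enumerate(ds["positive"])}
    queries = {str(i): q for i, q in enumerate(ds["anchor"])}
    relevant_docs = {qid: {qid} for qid in queries}

    evaluators = [
        InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            name=f"dim_{d}",
            truncate_dim=d,                       # evaluate at the truncated embedding size
            score_functions={"cosine": cos_sim},  # yields dim_{d}_cosine_ndcg@10 metric names
        )
        for d in dims
    ]
    # The trainer prefixes metrics with "eval_", giving e.g. eval_dim_128_cosine_ndcg@10.
    return SequentialEvaluator(evaluators)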
src/utils/config.py
ADDED
@@ -0,0 +1,9 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
    model_id: str = "nomic-ai/modernbert-embed-base"
    matryoshka_dims: tuple[int, ...] = (768, 512, 256, 128, 64)
    output_dir: str = "modernbert-embed-aws"


CFG = Config()
src/utils/encoding.py
ADDED
@@ -0,0 +1,21 @@
def read_text_safely(path: str) -> str:
    """Read a text file, trying common encodings before falling back to detection."""
    # latin-1 is deliberately not tried here: it accepts any byte sequence,
    # which would make the fallbacks below unreachable.
    for enc in ("utf-8", "utf-8-sig", "cp1252"):
        try:
            with open(path, "r", encoding=enc, errors="strict") as f:
                return f.read()
        except UnicodeDecodeError:
            continue

    # Optional dependency: let charset_normalizer guess the encoding.
    try:
        from charset_normalizer import from_path

        result = from_path(path).best()
        if result is not None:
            return str(result)
    except Exception:
        pass

    # Last resort: decode as UTF-8 and replace undecodable bytes.
    with open(path, "rb") as f:
        data = f.read()
    return data.decode("utf-8", errors="replace")
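
A quick usage sketch for read_text_safely; the file path below is made up for illustration:

# Hypothetical usage; "dataset/raw/example_page.html" is not a real path in this repo.
from src.utils.encoding import read_text_safely

text = read_text_safely("dataset/raw/example_page.html")
print(f"read {len(text)} characters")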
src/utils/paths.py
ADDED
@@ -0,0 +1,9 @@
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
DATA_DIR = ROOT / "dataset"
PROC_DIR = DATA_DIR / "processed"
PROC_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_JSON = PROC_DIR / "train_dataset.json"
TEST_JSON = PROC_DIR / "test_dataset.json"
src/utils/seed.py
ADDED
@@ -0,0 +1,12 @@
import random

import numpy as np


def set_seed(seed: int = 42):
    """Seed Python, NumPy, and (if available) PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except Exception:
        # torch is optional; skip silently if it is not installed.
        pass
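
set_seed is defined here but not called anywhere in this commit. A minimal sketch of how it could be wired into the training entry point; the call site is an assumption, not something this commit does:

# Hypothetical call site, e.g. at the top of main() in src/training/train.py.
from src.utils.seed import set_seed

set_seed(42)  # fix Python, NumPy, and PyTorch RNGs before building the model and trainer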