Caden Shokat committed
Commit 5dfc78f · 1 Parent(s): fa6c34a

add training + utils

src/training/train.py ADDED
@@ -0,0 +1,100 @@
+ import os, torch
+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer, SentenceTransformerModelCardData
+ from sentence_transformers.losses import MultipleNegativesRankingLoss, MatryoshkaLoss
+ from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
+ from sentence_transformers.training_args import BatchSamplers
+
+ from src.utils.config import CFG
+ from src.utils.paths import TRAIN_JSON, TEST_JSON
+ from src.eval.ir_eval import build_eval
+
+ def _precision_and_optim():
+     """Pick safe precision/optimizer settings for the current device."""
+     use_cuda = torch.cuda.is_available()
+
+     cfg = dict(fp16=False, bf16=False, tf32=False, optim="adamw_torch")
+
+     if use_cuda:
+         # TF32 + fused AdamW are NVIDIA-only features
+         cfg["tf32"] = True
+         try:
+             cfg["bf16"] = torch.cuda.is_bf16_supported()
+         except Exception:
+             cfg["bf16"] = False
+         # fused AdamW needs compute capability >= 8.0 (Ampere or newer)
+         maj, _ = torch.cuda.get_device_capability()
+         cfg["optim"] = "adamw_torch_fused" if maj >= 8 else "adamw_torch"
+
+     # MPS/CPU: stick to fp32; fused/TF32/bf16 are unsupported in the HF trainer
+     return cfg, use_cuda
+
+
+ def main():
+     use_mps = getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
+     device = "cuda" if torch.cuda.is_available() else ("mps" if use_mps else "cpu")
+
+     # base model with SDPA attention
+     model = SentenceTransformer(
+         CFG.model_id,
+         device=device,
+         model_kwargs={"attn_implementation": "sdpa"},
+         model_card_data=SentenceTransformerModelCardData(
+             language="en",
+             license="apache-2.0",
+             model_name="Embed AWS Docs",
+         ),
+     )
+
+     train_dataset = load_dataset("json", data_files=str(TRAIN_JSON), split="train")
+     test_dataset = load_dataset("json", data_files=str(TEST_JSON), split="train")
+
+     evaluator = build_eval(CFG.matryoshka_dims)
+
+     base_loss = MultipleNegativesRankingLoss(model)
+     train_loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=list(CFG.matryoshka_dims))
+
+     prec_optim, on_cuda = _precision_and_optim()
+     # Smaller batches on CPU/MPS
+     train_bs = 32 if on_cuda else 8
+     eval_bs = 16 if on_cuda else 8
+     grad_acc = 16 if on_cuda else 4  # keeps the effective global batch reasonable
+
+     args = SentenceTransformerTrainingArguments(
+         output_dir=CFG.output_dir,
+         num_train_epochs=4,
+         per_device_train_batch_size=train_bs,
+         gradient_accumulation_steps=grad_acc,
+         per_device_eval_batch_size=eval_bs,
+         warmup_ratio=0.1,
+         learning_rate=2e-5,
+         lr_scheduler_type="cosine",
+         optim=prec_optim["optim"],
+         tf32=prec_optim["tf32"],
+         bf16=prec_optim["bf16"],
+         batch_sampler=BatchSamplers.NO_DUPLICATES,
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         logging_steps=10,
+         save_total_limit=3,
+         load_best_model_at_end=True,
+         metric_for_best_model="eval_dim_128_cosine_ndcg@10",
+         report_to="none",
+     )
+
+     trainer = SentenceTransformerTrainer(
+         model=model,
+         args=args,
+         train_dataset=train_dataset.select_columns(["positive", "anchor"]),
+         loss=train_loss,
+         evaluator=evaluator,
+     )
+
+     trainer.train()
+     trainer.save_model()
+
+     # push to the Hub only when a token is available
+     if os.getenv("HUGGINGFACE_HUB_TOKEN"):
+         trainer.model.push_to_hub(CFG.output_dir)
+
+
+ if __name__ == "__main__":
+     main()
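
A note on usage (not part of the commit): because training wraps the ranking loss in `MatryoshkaLoss`, the saved model can serve embeddings at any of the configured dimensions. A minimal sketch, assuming training completed and saved to the default `modernbert-embed-aws` output directory, and sentence-transformers >= 2.7 for the `truncate_dim` parameter:

```python
from sentence_transformers import SentenceTransformer

# load the fine-tuned model, truncating embeddings to 128 dims (Matryoshka)
model = SentenceTransformer("modernbert-embed-aws", truncate_dim=128)

# hypothetical query text, just for illustration
emb = model.encode(["How do I attach an IAM role to an EC2 instance?"])
print(emb.shape)  # (1, 128) instead of the full 768
```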
src/utils/config.py ADDED
@@ -0,0 +1,9 @@
+ from dataclasses import dataclass
+
+ @dataclass(frozen=True)
+ class Config:
+     model_id: str = "nomic-ai/modernbert-embed-base"
+     matryoshka_dims: tuple[int, ...] = (768, 512, 256, 128, 64)
+     output_dir: str = "modernbert-embed-aws"
+
+ CFG = Config()
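
Since `Config` is frozen, experiment variants are derived rather than mutated in place. A small sketch using the standard-library `dataclasses.replace` (the variant shown is illustrative, not part of the commit):

```python
from dataclasses import replace
from src.utils.config import CFG

# derive a variant config without mutating the frozen CFG instance
small_cfg = replace(CFG, matryoshka_dims=(256, 128, 64))
```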
src/utils/encoding.py ADDED
@@ -0,0 +1,21 @@
+ def read_text_safely(path: str) -> str:
+     # strict pass over common encodings; latin-1 is excluded because it
+     # decodes any byte sequence, which would make the fallbacks unreachable
+     for enc in ("utf-8", "utf-8-sig", "cp1252"):
+         try:
+             with open(path, "r", encoding=enc, errors="strict") as f:
+                 return f.read()
+         except UnicodeDecodeError:
+             continue
+
+     # optional dependency: charset detection when the strict pass fails
+     try:
+         from charset_normalizer import from_path
+         result = from_path(path).best()
+         if result is not None:
+             return str(result)
+     except Exception:
+         pass
+
+     # last resort: decode with replacement characters
+     with open(path, "rb") as f:
+         data = f.read()
+     return data.decode("utf-8", errors="replace")
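
`read_text_safely` is meant for ingesting raw documents whose encodings vary. A usage sketch, assuming a hypothetical `dataset/raw` directory of scraped files (the path and glob pattern are illustrative):

```python
from pathlib import Path
from src.utils.encoding import read_text_safely

# decode every scraped file, whatever its encoding turns out to be
texts = [read_text_safely(str(p)) for p in Path("dataset/raw").glob("*.txt")]
```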
src/utils/paths.py ADDED
@@ -0,0 +1,9 @@
+ from pathlib import Path
+
+ ROOT = Path(__file__).resolve().parents[2]
+ DATA_DIR = ROOT / "dataset"
+ PROC_DIR = DATA_DIR / "processed"
+ PROC_DIR.mkdir(parents=True, exist_ok=True)
+
+ TRAIN_JSON = PROC_DIR / "train_dataset.json"
+ TEST_JSON = PROC_DIR / "test_dataset.json"
src/utils/seed.py ADDED
@@ -0,0 +1,12 @@
+ import random, numpy as np
+
+ def set_seed(seed: int = 42):
+     """Seed Python, NumPy, and (when installed) PyTorch RNGs."""
+     random.seed(seed)
+     np.random.seed(seed)
+     try:
+         import torch
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed_all(seed)
+     except Exception:
+         pass
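
`set_seed` is defined here but not yet wired into `train.py`; a typical call site would be the top of `main()`, before model and dataset initialization:

```python
from src.utils.seed import set_seed

set_seed(42)  # fix RNG state before model init and data loading
```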