Caden Shokat
commited on
Commit
·
eb13318
1
Parent(s):
13f746f
push model to hub
Browse files- src/training/train.py +5 -3
src/training/train.py
CHANGED
@@ -5,12 +5,13 @@ from sentence_transformers.losses import MultipleNegativesRankingLoss
|
|
5 |
from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
|
6 |
from sentence_transformers.training_args import BatchSamplers
|
7 |
from sentence_transformers.losses import MatryoshkaLoss
|
8 |
-
|
9 |
from src.utils.config import CFG
|
10 |
from src.utils.paths import TRAIN_JSON, TEST_JSON
|
11 |
from src.eval.ir_eval import build_eval
|
12 |
|
13 |
def main():
|
|
|
14 |
device = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
|
15 |
|
16 |
# base model with SDPA
|
@@ -66,8 +67,9 @@ def main():
|
|
66 |
trainer.train()
|
67 |
trainer.save_model()
|
68 |
|
69 |
-
|
70 |
-
|
|
|
71 |
|
72 |
if __name__ == "__main__":
|
73 |
main()
|
|
|
5 |
from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
|
6 |
from sentence_transformers.training_args import BatchSamplers
|
7 |
from sentence_transformers.losses import MatryoshkaLoss
|
8 |
+
from huggingface_hub import login
|
9 |
from src.utils.config import CFG
|
10 |
from src.utils.paths import TRAIN_JSON, TEST_JSON
|
11 |
from src.eval.ir_eval import build_eval
|
12 |
|
13 |
def main():
|
14 |
+
HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
|
15 |
device = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
|
16 |
|
17 |
# base model with SDPA
|
|
|
67 |
trainer.train()
|
68 |
trainer.save_model()
|
69 |
|
70 |
+
if HF_TOKEN:
|
71 |
+
login(token=HF_TOKEN)
|
72 |
+
trainer.model.push_to_hub(CFG.output_dir)
|
73 |
|
74 |
if __name__ == "__main__":
|
75 |
main()
|