Caden Shokat commited on
Commit
eb13318
·
1 Parent(s): 13f746f

push model to hub

Browse files
Files changed (1) hide show
  1. src/training/train.py +5 -3
src/training/train.py CHANGED
@@ -5,12 +5,13 @@ from sentence_transformers.losses import MultipleNegativesRankingLoss
5
  from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
6
  from sentence_transformers.training_args import BatchSamplers
7
  from sentence_transformers.losses import MatryoshkaLoss
8
-
9
  from src.utils.config import CFG
10
  from src.utils.paths import TRAIN_JSON, TEST_JSON
11
  from src.eval.ir_eval import build_eval
12
 
13
  def main():
 
14
  device = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
15
 
16
  # base model with SDPA
@@ -66,8 +67,9 @@ def main():
66
  trainer.train()
67
  trainer.save_model()
68
 
69
-
70
- trainer.model.push_to_hub(CFG.output_dir)
 
71
 
72
  if __name__ == "__main__":
73
  main()
 
5
  from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
6
  from sentence_transformers.training_args import BatchSamplers
7
  from sentence_transformers.losses import MatryoshkaLoss
8
+ from huggingface_hub import login
9
  from src.utils.config import CFG
10
  from src.utils.paths import TRAIN_JSON, TEST_JSON
11
  from src.eval.ir_eval import build_eval
12
 
13
  def main():
14
+ HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
15
  device = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
16
 
17
  # base model with SDPA
 
67
  trainer.train()
68
  trainer.save_model()
69
 
70
+ if HF_TOKEN:
71
+ login(token=HF_TOKEN)
72
+ trainer.model.push_to_hub(CFG.output_dir)
73
 
74
  if __name__ == "__main__":
75
  main()