Javierss commited on
Commit
fa48fcf
·
1 Parent(s): cde751a

Update model download method

Browse files
Files changed (1) hide show
  1. game.py +7 -16
game.py CHANGED
@@ -55,6 +55,8 @@ from tracking import (
55
  )
56
  from sentence_transformers import SentenceTransformer, util
57
  import warnings
 
 
58
 
59
  warnings.filterwarnings(action="ignore", category=UserWarning, module="gensim")
60
 
@@ -65,45 +67,34 @@ class Model_class:
65
 
66
  def __init__(self, lang=0, model_type="SentenceTransformer"):
67
 
68
- if model_type == "SentenceTransformer":
69
- repo_url = "[email protected]:Jsevisal/strans_models"
70
- dest_path = "config/strans_models/"
71
- else:
72
- repo_url = "[email protected]:Jsevisal/w2v_models"
73
- dest_path = "config/w2v_models/"
74
-
75
  # Check if the model exists, clone it if it doesn't
76
  if not os.path.exists(
77
  os.path.join(self.base_path, "config/strans_models/")
78
  ) or not os.path.exists(os.path.join(self.base_path, "config/w2v_models/")):
79
- os.system(f"git clone {repo_url} {dest_path}")
80
 
81
  if lang == 1:
82
  if model_type == "word2vec":
83
  self.model = KeyedVectors.load(
84
- os.path.join(self.base_path, "config/w2v_models/eng_w2v_model"),
85
  mmap="r",
86
  )
87
  elif model_type == "SentenceTransformer":
88
  self.model = KeyedVectors.load(
89
- os.path.join(
90
- self.base_path, "config/strans_models/eng_strans_model"
91
- ),
92
  mmap="r",
93
  )
94
 
95
  else:
96
  if model_type == "word2vec":
97
  self.model = KeyedVectors.load(
98
- os.path.join(self.base_path, "config/w2v_models/esp_w2v_model"),
99
  mmap="r",
100
  )
101
 
102
  elif model_type == "SentenceTransformer":
103
  self.model = KeyedVectors.load(
104
- os.path.join(
105
- self.base_path, "config/strans_models/esp_strans_model"
106
- ),
107
  mmap="r",
108
  )
109
 
 
55
  )
56
  from sentence_transformers import SentenceTransformer, util
57
  import warnings
58
+ from huggingface_hub import snapshot_download
59
+
60
 
61
  warnings.filterwarnings(action="ignore", category=UserWarning, module="gensim")
62
 
 
67
 
68
  def __init__(self, lang=0, model_type="SentenceTransformer"):
69
 
 
 
 
 
 
 
 
70
  # Check if the model exists, clone it if it doesn't
71
  if not os.path.exists(
72
  os.path.join(self.base_path, "config/strans_models/")
73
  ) or not os.path.exists(os.path.join(self.base_path, "config/w2v_models/")):
74
+ model_path = snapshot_download(repo_id="Jsevisal/strans_models")
75
 
76
  if lang == 1:
77
  if model_type == "word2vec":
78
  self.model = KeyedVectors.load(
79
+ model_path,
80
  mmap="r",
81
  )
82
  elif model_type == "SentenceTransformer":
83
  self.model = KeyedVectors.load(
84
+ model_path,
 
 
85
  mmap="r",
86
  )
87
 
88
  else:
89
  if model_type == "word2vec":
90
  self.model = KeyedVectors.load(
91
+ model_path,
92
  mmap="r",
93
  )
94
 
95
  elif model_type == "SentenceTransformer":
96
  self.model = KeyedVectors.load(
97
+ model_path,
 
 
98
  mmap="r",
99
  )
100