Luis Oala commited on
Commit
0b2696b
·
unverified ·
1 Parent(s): 7a549c2

Update base.py

Browse files
Files changed (1) hide show
  1. utils/base.py +1 -45
utils/base.py CHANGED
@@ -20,17 +20,10 @@ import argparse
20
 
21
 
22
  class SmartFormatter(argparse.HelpFormatter):
23
- <<<<<<< HEAD
24
-
25
  def _split_lines(self, text, width):
26
  if text.startswith('R|'):
27
  return text[2:].splitlines()
28
- =======
29
-
30
- def _split_lines(self, text, width):
31
- if text.startswith('R|'):
32
- return text[2:].splitlines()
33
- >>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
34
  # this is the RawTextHelpFormatter._split_lines
35
  return argparse.HelpFormatter._split_lines(self, text, width)
36
 
@@ -136,13 +129,8 @@ def b2_list_files(folder=''):
136
 
137
  def get_b2_bucket():
138
  bucket_name = 'perturbed-minds'
139
- <<<<<<< HEAD
140
- application_key_id = '003d6b042de536a0000000004'
141
- application_key = 'K003E5Cr+BAYlvSHfg2ynLtvS5aNq78'
142
- =======
143
  application_key_id = '003d6b042de536a0000000008'
144
  application_key = 'K003HMNxnoa91Dy9c0V8JVCKNUnwR9U'
145
- >>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
146
  info = InMemoryAccountInfo()
147
  b2_api = B2Api(info)
148
  b2_api.authorize_account('production', application_key_id, application_key)
@@ -205,18 +193,9 @@ def b2_download_folder(b2_dir, local_dir, force_download=False, mirror_folder=Tr
205
  def get_name(obj):
206
  return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
207
 
208
-
209
- <<<<<<< HEAD
210
- def get_mlflow_model_by_name(experiment_name, run_name,
211
- tracking_uri = "http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
212
- download_model = True):
213
-
214
- =======
215
  def get_mlflow_model_by_name(experiment_name, run_name,
216
  tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
217
  download_model=True):
218
-
219
- >>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
220
  # 0. mlflow basics
221
  mlflow.set_tracking_uri(tracking_uri)
222
  os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: add your AWS access key if you want to write your results to our collaborative lab server"
@@ -229,17 +208,11 @@ def get_mlflow_model_by_name(experiment_name, run_name,
229
  if os.path.isfile('cache/runs_names.pkl'):
230
  runs = pd.read_pickle('cache/runs_names.pkl')
231
  if runs['tags.mlflow.runName'][runs['tags.mlflow.runName'] == run_name].empty:
232
- <<<<<<< HEAD
233
- runs = fetch_runs_list_mlflow(experiment) #returns a pandas data frame where each row is a run (if several exist under that name)
234
- else:
235
- runs = fetch_runs_list_mlflow(experiment) #returns a pandas data frame where each row is a run (if several exist under that name)
236
- =======
237
  # returns a pandas data frame where each row is a run (if several exist under that name)
238
  runs = fetch_runs_list_mlflow(experiment)
239
  else:
240
  # returns a pandas data frame where each row is a run (if several exist under that name)
241
  runs = fetch_runs_list_mlflow(experiment)
242
- >>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
243
 
244
  # 3. get the selected run between all runs inside the selected experiment
245
  run = runs.loc[runs['tags.mlflow.runName'] == run_name]
@@ -257,18 +230,10 @@ def get_mlflow_model_by_name(experiment_name, run_name,
257
  # model = mlflow.pytorch.load_model(os.path.join(
258
  # artifact_uri, "model"), map_location=torch.device(DEVICE))
259
  model = fetch_from_mlflow(os.path.join(
260
- <<<<<<< HEAD
261
- artifact_uri, "model"), use_cache=True, download_model=download_model)
262
-
263
- return state_dict, model
264
-
265
- =======
266
  artifact_uri, "model"), use_cache=True, download_model=download_model)
267
 
268
  return state_dict, model
269
 
270
-
271
- >>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
272
  def data_loader_mean_and_std(data_loader, transform=None):
273
  means = []
274
  stds = []
@@ -279,20 +244,11 @@ def data_loader_mean_and_std(data_loader, transform=None):
279
  stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
280
  return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
281
 
282
- <<<<<<< HEAD
283
- def fetch_runs_list_mlflow(experiment):
284
- =======
285
-
286
  def fetch_runs_list_mlflow(experiment):
287
- >>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
288
  runs = mlflow.search_runs(experiment.experiment_id)
289
  runs.to_pickle('cache/runs_names.pkl') # where to save it, usually as a .pkl
290
  return runs
291
 
292
- <<<<<<< HEAD
293
- =======
294
-
295
- >>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
296
  def fetch_from_mlflow(uri, use_cache=True, download_model=True):
297
  cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
298
  if use_cache and os.path.exists(cache_loc):
 
20
 
21
 
22
  class SmartFormatter(argparse.HelpFormatter):
23
+
 
24
  def _split_lines(self, text, width):
25
  if text.startswith('R|'):
26
  return text[2:].splitlines()
 
 
 
 
 
 
27
  # this is the RawTextHelpFormatter._split_lines
28
  return argparse.HelpFormatter._split_lines(self, text, width)
29
 
 
129
 
130
  def get_b2_bucket():
131
  bucket_name = 'perturbed-minds'
 
 
 
 
132
  application_key_id = '003d6b042de536a0000000008'
133
  application_key = 'K003HMNxnoa91Dy9c0V8JVCKNUnwR9U'
 
134
  info = InMemoryAccountInfo()
135
  b2_api = B2Api(info)
136
  b2_api.authorize_account('production', application_key_id, application_key)
 
193
  def get_name(obj):
194
  return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
195
 
 
 
 
 
 
 
 
196
  def get_mlflow_model_by_name(experiment_name, run_name,
197
  tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
198
  download_model=True):
 
 
199
  # 0. mlflow basics
200
  mlflow.set_tracking_uri(tracking_uri)
201
  os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: add your AWS access key if you want to write your results to our collaborative lab server"
 
208
  if os.path.isfile('cache/runs_names.pkl'):
209
  runs = pd.read_pickle('cache/runs_names.pkl')
210
  if runs['tags.mlflow.runName'][runs['tags.mlflow.runName'] == run_name].empty:
 
 
 
 
 
211
  # returns a pandas data frame where each row is a run (if several exist under that name)
212
  runs = fetch_runs_list_mlflow(experiment)
213
  else:
214
  # returns a pandas data frame where each row is a run (if several exist under that name)
215
  runs = fetch_runs_list_mlflow(experiment)
 
216
 
217
  # 3. get the selected run between all runs inside the selected experiment
218
  run = runs.loc[runs['tags.mlflow.runName'] == run_name]
 
230
  # model = mlflow.pytorch.load_model(os.path.join(
231
  # artifact_uri, "model"), map_location=torch.device(DEVICE))
232
  model = fetch_from_mlflow(os.path.join(
 
 
 
 
 
 
233
  artifact_uri, "model"), use_cache=True, download_model=download_model)
234
 
235
  return state_dict, model
236
 
 
 
237
  def data_loader_mean_and_std(data_loader, transform=None):
238
  means = []
239
  stds = []
 
244
  stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
245
  return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
246
 
 
 
 
 
247
  def fetch_runs_list_mlflow(experiment):
 
248
  runs = mlflow.search_runs(experiment.experiment_id)
249
  runs.to_pickle('cache/runs_names.pkl') # where to save it, usually as a .pkl
250
  return runs
251
 
 
 
 
 
252
  def fetch_from_mlflow(uri, use_cache=True, download_model=True):
253
  cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
254
  if use_cache and os.path.exists(cache_loc):