Luis Oala
committed on
Update base.py
Browse files
- utils/base.py +1 -45
utils/base.py
CHANGED
@@ -20,17 +20,10 @@ import argparse
|
|
20 |
|
21 |
|
22 |
class SmartFormatter(argparse.HelpFormatter):
|
23 |
-
|
24 |
-
|
25 |
def _split_lines(self, text, width):
|
26 |
if text.startswith('R|'):
|
27 |
return text[2:].splitlines()
|
28 |
-
=======
|
29 |
-
|
30 |
-
def _split_lines(self, text, width):
|
31 |
-
if text.startswith('R|'):
|
32 |
-
return text[2:].splitlines()
|
33 |
-
>>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
|
34 |
# this is the RawTextHelpFormatter._split_lines
|
35 |
return argparse.HelpFormatter._split_lines(self, text, width)
|
36 |
|
@@ -136,13 +129,8 @@ def b2_list_files(folder=''):
|
|
136 |
|
137 |
def get_b2_bucket():
|
138 |
bucket_name = 'perturbed-minds'
|
139 |
-
<<<<<<< HEAD
|
140 |
-
application_key_id = '003d6b042de536a0000000004'
|
141 |
-
application_key = 'K003E5Cr+BAYlvSHfg2ynLtvS5aNq78'
|
142 |
-
=======
|
143 |
application_key_id = '003d6b042de536a0000000008'
|
144 |
application_key = 'K003HMNxnoa91Dy9c0V8JVCKNUnwR9U'
|
145 |
-
>>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
|
146 |
info = InMemoryAccountInfo()
|
147 |
b2_api = B2Api(info)
|
148 |
b2_api.authorize_account('production', application_key_id, application_key)
|
@@ -205,18 +193,9 @@ def b2_download_folder(b2_dir, local_dir, force_download=False, mirror_folder=Tr
|
|
205 |
def get_name(obj):
|
206 |
return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
|
207 |
|
208 |
-
|
209 |
-
<<<<<<< HEAD
|
210 |
-
def get_mlflow_model_by_name(experiment_name, run_name,
|
211 |
-
tracking_uri = "http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
|
212 |
-
download_model = True):
|
213 |
-
|
214 |
-
=======
|
215 |
def get_mlflow_model_by_name(experiment_name, run_name,
|
216 |
tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
|
217 |
download_model=True):
|
218 |
-
|
219 |
-
>>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
|
220 |
# 0. mlflow basics
|
221 |
mlflow.set_tracking_uri(tracking_uri)
|
222 |
os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: add your AWS access key if you want to write your results to our collaborative lab server"
|
@@ -229,17 +208,11 @@ def get_mlflow_model_by_name(experiment_name, run_name,
|
|
229 |
if os.path.isfile('cache/runs_names.pkl'):
|
230 |
runs = pd.read_pickle('cache/runs_names.pkl')
|
231 |
if runs['tags.mlflow.runName'][runs['tags.mlflow.runName'] == run_name].empty:
|
232 |
-
<<<<<<< HEAD
|
233 |
-
runs = fetch_runs_list_mlflow(experiment) #returns a pandas data frame where each row is a run (if several exist under that name)
|
234 |
-
else:
|
235 |
-
runs = fetch_runs_list_mlflow(experiment) #returns a pandas data frame where each row is a run (if several exist under that name)
|
236 |
-
=======
|
237 |
# returns a pandas data frame where each row is a run (if several exist under that name)
|
238 |
runs = fetch_runs_list_mlflow(experiment)
|
239 |
else:
|
240 |
# returns a pandas data frame where each row is a run (if several exist under that name)
|
241 |
runs = fetch_runs_list_mlflow(experiment)
|
242 |
-
>>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
|
243 |
|
244 |
# 3. get the selected run between all runs inside the selected experiment
|
245 |
run = runs.loc[runs['tags.mlflow.runName'] == run_name]
|
@@ -257,18 +230,10 @@ def get_mlflow_model_by_name(experiment_name, run_name,
|
|
257 |
# model = mlflow.pytorch.load_model(os.path.join(
|
258 |
# artifact_uri, "model"), map_location=torch.device(DEVICE))
|
259 |
model = fetch_from_mlflow(os.path.join(
|
260 |
-
<<<<<<< HEAD
|
261 |
-
artifact_uri, "model"), use_cache=True, download_model=download_model)
|
262 |
-
|
263 |
-
return state_dict, model
|
264 |
-
|
265 |
-
=======
|
266 |
artifact_uri, "model"), use_cache=True, download_model=download_model)
|
267 |
|
268 |
return state_dict, model
|
269 |
|
270 |
-
|
271 |
-
>>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
|
272 |
def data_loader_mean_and_std(data_loader, transform=None):
|
273 |
means = []
|
274 |
stds = []
|
@@ -279,20 +244,11 @@ def data_loader_mean_and_std(data_loader, transform=None):
|
|
279 |
stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
|
280 |
return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
|
281 |
|
282 |
-
<<<<<<< HEAD
|
283 |
-
def fetch_runs_list_mlflow(experiment):
|
284 |
-
=======
|
285 |
-
|
286 |
def fetch_runs_list_mlflow(experiment):
|
287 |
-
>>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
|
288 |
runs = mlflow.search_runs(experiment.experiment_id)
|
289 |
runs.to_pickle('cache/runs_names.pkl') # where to save it, usually as a .pkl
|
290 |
return runs
|
291 |
|
292 |
-
<<<<<<< HEAD
|
293 |
-
=======
|
294 |
-
|
295 |
-
>>>>>>> ea1d33b387781225b4149b4b1b3b04f34dc42268
|
296 |
def fetch_from_mlflow(uri, use_cache=True, download_model=True):
|
297 |
cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
|
298 |
if use_cache and os.path.exists(cache_loc):
|
|
|
20 |
|
21 |
|
22 |
class SmartFormatter(argparse.HelpFormatter):
|
23 |
+
|
|
|
24 |
def _split_lines(self, text, width):
|
25 |
if text.startswith('R|'):
|
26 |
return text[2:].splitlines()
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
# this is the RawTextHelpFormatter._split_lines
|
28 |
return argparse.HelpFormatter._split_lines(self, text, width)
|
29 |
|
|
|
129 |
|
130 |
def get_b2_bucket():
|
131 |
bucket_name = 'perturbed-minds'
|
|
|
|
|
|
|
|
|
132 |
application_key_id = '003d6b042de536a0000000008'
|
133 |
application_key = 'K003HMNxnoa91Dy9c0V8JVCKNUnwR9U'
|
|
|
134 |
info = InMemoryAccountInfo()
|
135 |
b2_api = B2Api(info)
|
136 |
b2_api.authorize_account('production', application_key_id, application_key)
|
|
|
193 |
def get_name(obj):
|
194 |
return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
|
195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
def get_mlflow_model_by_name(experiment_name, run_name,
|
197 |
tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
|
198 |
download_model=True):
|
|
|
|
|
199 |
# 0. mlflow basics
|
200 |
mlflow.set_tracking_uri(tracking_uri)
|
201 |
os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: add your AWS access key if you want to write your results to our collaborative lab server"
|
|
|
208 |
if os.path.isfile('cache/runs_names.pkl'):
|
209 |
runs = pd.read_pickle('cache/runs_names.pkl')
|
210 |
if runs['tags.mlflow.runName'][runs['tags.mlflow.runName'] == run_name].empty:
|
|
|
|
|
|
|
|
|
|
|
211 |
# returns a pandas data frame where each row is a run (if several exist under that name)
|
212 |
runs = fetch_runs_list_mlflow(experiment)
|
213 |
else:
|
214 |
# returns a pandas data frame where each row is a run (if several exist under that name)
|
215 |
runs = fetch_runs_list_mlflow(experiment)
|
|
|
216 |
|
217 |
# 3. get the selected run between all runs inside the selected experiment
|
218 |
run = runs.loc[runs['tags.mlflow.runName'] == run_name]
|
|
|
230 |
# model = mlflow.pytorch.load_model(os.path.join(
|
231 |
# artifact_uri, "model"), map_location=torch.device(DEVICE))
|
232 |
model = fetch_from_mlflow(os.path.join(
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
artifact_uri, "model"), use_cache=True, download_model=download_model)
|
234 |
|
235 |
return state_dict, model
|
236 |
|
|
|
|
|
237 |
def data_loader_mean_and_std(data_loader, transform=None):
|
238 |
means = []
|
239 |
stds = []
|
|
|
244 |
stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
|
245 |
return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
|
246 |
|
|
|
|
|
|
|
|
|
247 |
def fetch_runs_list_mlflow(experiment):
|
|
|
248 |
runs = mlflow.search_runs(experiment.experiment_id)
|
249 |
runs.to_pickle('cache/runs_names.pkl') # where to save it, usually as a .pkl
|
250 |
return runs
|
251 |
|
|
|
|
|
|
|
|
|
252 |
def fetch_from_mlflow(uri, use_cache=True, download_model=True):
|
253 |
cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
|
254 |
if use_cache and os.path.exists(cache_loc):
|