jetclustering / src /utils /wandb_utils.py
gregorkrzmanc's picture
.
e75a247
raw
history blame
3.58 kB
import os
import wandb
from src.utils.paths import get_path
# Module-level W&B public API client, created at import time and used by get_run_by_name.
api = wandb.Api()
def get_run_by_name(name):
    """Look up a single W&B run in ``fcc_ml/svj_clustering`` by display name.

    Args:
        name: Display name of the run. Surrounding whitespace is stripped
            before the name is used in the query filter.

    Returns:
        The matching run object, or ``None`` when the query does not match
        exactly one run (no match, or several runs sharing the name).
    """
    # Issue the query once; the filter is an exact match on the display name.
    runs = api.runs(
        path="fcc_ml/svj_clustering",
        filters={"display_name": {"$eq": name.strip()}},
    )
    # len() makes the paginator fetch its first page before counting.
    # NOTE(review): in some wandb versions the `length` attribute appears to
    # stay unset until a page has been loaded - confirm against the installed
    # version.
    if len(runs) != 1:
        return None
    return runs[0]
def get_steps_from_file(fname):
    """Extract the step number encoded in a checkpoint file name.

    The last ``/``-separated component of ``fname`` is expected to look like
    ``step_22000_epoch_1467.ckpt``; the second ``_``-separated field
    (``22000`` in that example) is returned as an int.
    """
    basename = fname.rsplit("/", 1)[-1]
    fields = basename.split("_")
    return int(fields[1])
def get_run_initial_steps(run):
    """Return how many steps were already trained before ``run`` started.

    A run whose ``load_model_weights`` config entry is empty starts at 0.
    Otherwise the parent run is identified by the second-to-last path
    component of ``load_model_weights``, and the result is the parent's own
    initial steps plus the step number encoded in the loaded checkpoint's
    file name.

    Raises:
        Exception: If the parent run cannot be resolved by name.
    """
    weights_path = run.config["load_model_weights"]
    if not weights_path:
        return 0

    parent_name = weights_path.split("/")[-2]
    parent_run = get_run_by_name(parent_name)
    if parent_run is None:
        raise Exception("Run doesn't exist: " + parent_name)

    # Walk up the chain first, then add the steps contributed by this link.
    inherited_steps = get_run_initial_steps(parent_run)
    return inherited_steps + get_steps_from_file(weights_path)
def extract_relative_path(run_path):
    """Resolve a run directory via ``get_path`` from its part after ``train/``.

    ``run_path`` looks like ``/a/b/c/d/train/e/f``. Everything after the last
    ``train/`` is kept, re-prefixed with ``train/`` and handed to ``get_path``
    with ``type="results"`` and ``fallback=True``.
    """
    # rpartition yields the whole string as the tail when "train/" is absent.
    tail = run_path.rpartition("train/")[2]
    relative = "train/" + tail
    # An earlier variant (left commented out in the source) returned
    # `relative` itself instead of the get_path result.
    return get_path(relative, type="results", fallback=True)
def get_run_step_direct(run_path, step):
    """Find the checkpoint saved at ``step`` inside a run's directory.

    Args:
        run_path: Path of the run; it is resolved to a directory through
            ``extract_relative_path``.
        step: Step number to look for, as encoded in checkpoint file names
            of the form ``step_<x>_epoch_<y>.ckpt``.

    Returns:
        The checkpoint path starting at ``train/`` (everything before the
        last ``train/`` of the full path is dropped).

    Raises:
        Exception: If the directory holds no checkpoint for ``step``.
    """
    run_dir = extract_relative_path(run_path)
    print("Run-path:", run_dir)

    # Parse every checkpoint name once and remember the first file seen for
    # each step. Files that end in ".ckpt" but do not follow the
    # step_<x>_epoch_<y> naming are skipped instead of aborting the lookup.
    ckpt_by_step = {}
    for fname in os.listdir(run_dir):
        if not fname.endswith(".ckpt"):
            continue
        try:
            file_step = int(fname.split("_")[1])
        except (IndexError, ValueError):
            continue
        ckpt_by_step.setdefault(file_step, fname)

    if step not in ckpt_by_step:
        print("Available steps:", list(ckpt_by_step))
        raise Exception("Step not found in run")

    full_path = os.path.join(run_dir, ckpt_by_step[step])
    # Return everything after "train/", re-prefixed with "train/".
    return "train/" + full_path.split("train/")[-1]
def get_run_step_ckpt(run, step, steps_from_zero):
    """Locate the checkpoint for ``step``, following the chain of parent runs.

    Args:
        run: Run object exposing ``config`` (with ``load_model_weights`` and
            ``run_path`` entries) and ``name``.
        step: Step number to locate.
        steps_from_zero: When truthy, ``step`` is looked up directly in
            ``run``'s own directory without consulting parent runs.

    Returns:
        A ``(checkpoint_path, run)`` tuple: the path returned by
        ``get_run_step_direct`` and the run object it was looked up for.

    Raises:
        Exception: If the parent run named in ``load_model_weights`` cannot
            be found, or if no checkpoint exists for the requested step.
    """
    weights_path = run.config["load_model_weights"]
    if not weights_path or steps_from_zero:
        return get_run_step_direct(run.config["run_path"], step), run

    # The parent run's name is the directory holding the loaded checkpoint.
    parent_name = weights_path.split("/")[-2]
    parent_run = get_run_by_name(parent_name)
    if parent_run is None:
        raise Exception("Run doesn't exist: " + parent_name)

    steps = get_run_initial_steps(run)
    if step > steps:
        print("Step", step, "is in run", run.name)
        # NOTE(review): the message names `run`, yet the lookup below uses the
        # parent run's path and returns the parent run - confirm which of the
        # two is intended. Behaviour is kept as in the original.
        return (
            get_run_step_direct(parent_run.config["run_path"], step - steps),
            parent_run,
        )

    # The step predates this run: continue the search in the parent. The
    # flag must be forwarded; the previous call omitted it and raised
    # TypeError whenever this branch was reached.
    return get_run_step_ckpt(parent_run, step, steps_from_zero)
# Names of run-config entries that update_args considers copying onto `args`.
args_to_update = [
    "validation_steps",
    "start_lr",
    "lr_scheduler",
    "optimizer",
    "embed_as_vectors",
    "epsilon",
    "min_samples",
    "min_cluster_size",
    "spatial_part_only",
    "scalars_oc",
    "lorentz_norm",
    "beta_type",
    "coord_loss_weight",
    "repul_loss_weight",
    "attr_loss_weight",
    "gt_radius",
    "loss",
    "num_steps",
    "num_epochs",
    "hidden_s_channels",
    "hidden_mv_channels",
    "n_heads",
    "internal_dim",
    "num_blocks",
    "network_config",
    "data_config",
    "no_pid",
]


def update_args(args, run):
    """Copy selected config values from ``run`` onto ``args``.

    Every name in ``args_to_update`` is visited in order. The clustering
    settings (``min_samples``, ``min_cluster_size``, ``epsilon``) are never
    copied, and names missing from ``run.config`` are skipped. All other
    names are set as attributes on ``args``. Finally ``args.parent_run`` is
    set to the run's name.

    Returns:
        The same ``args`` object, modified in place.
    """
    clustering_args = ("min_samples", "min_cluster_size", "epsilon")
    config = run.config
    for key in args_to_update:
        if key in clustering_args:
            print("Skipping setting clustering args")
        elif key not in config:
            print("Skipping setting", key)
        else:
            value = config[key]
            print("Setting", key, value)
            setattr(args, key, value)
    print("Loaded args from run", run.name)
    args.parent_run = run.name
    return args