Spaces:
Sleeping
Sleeping
import os | |
import wandb | |
from src.utils.paths import get_path | |
# Module-level W&B public-API client, created once at import time and used by
# get_run_by_name to query runs.
# NOTE(review): constructing it presumably needs W&B credentials to be
# available when this module is imported — confirm.
api = wandb.Api()
def get_run_by_name(name, path="fcc_ml/svj_clustering"):
    """Look up a single W&B run by its display name.

    Args:
        name: Display name of the run; surrounding whitespace is stripped
            before matching.
        path: ``"<entity>/<project>"`` to search in. Defaults to the project
            previously hard-coded here, so existing callers are unaffected.

    Returns:
        The matching run if exactly one run has that display name,
        otherwise ``None`` (no match, or an ambiguous name).
    """
    # A single query is enough (the original issued the identical query twice
    # and discarded the first result).
    runs = api.runs(
        path=path,
        filters={"display_name": {"$eq": name.strip()}},
    )
    # Use len() rather than the `.length` property.
    # NOTE(review): in some wandb versions `.length` is None until a page has
    # been fetched, while the paginator's __len__ fetches on demand — confirm
    # against the installed wandb version.
    if len(runs) != 1:
        return None
    return runs[0]
def get_steps_from_file(fname):
    """Return the step number encoded in a checkpoint file path.

    The file name (last ``/``-separated component) is split on ``_`` and its
    second token is parsed as an integer, e.g.
    ``".../lgatr_CONT_ds_cap_1000_2025_01_21_19_41_51/step_22000_epoch_1467.ckpt"``
    yields ``22000``.
    """
    basename = fname.rsplit("/", 1)[-1]
    step_token = basename.split("_")[1]
    return int(step_token)
def get_run_initial_steps(run):
    """Return the number of training steps that preceded the start of ``run``.

    If ``run.config["load_model_weights"]`` is empty/falsy the run started
    from scratch and 0 is returned. Otherwise the checkpoint's parent
    directory name is treated as the name of the run that produced it, and the
    result is that run's own initial steps plus the step parsed from the
    checkpoint file name (applied recursively along the chain).

    Raises:
        Exception: if a run named by a checkpoint's parent directory cannot be
            found with ``get_run_by_name``.
    """
    weights = run.config["load_model_weights"]
    if not weights:
        return 0
    # The directory holding the checkpoint is named after the run that wrote it.
    parent_name = weights.split("/")[-2]
    parent = get_run_by_name(parent_name)
    if parent is None:
        raise Exception("Run doesn't exist: " + parent_name)
    steps_before_parent = get_run_initial_steps(parent)
    return steps_before_parent + get_steps_from_file(weights)
def extract_relative_path(run_path):
    """Resolve the part of ``run_path`` below its ``train`` directory via ``get_path``.

    ``run_path`` looks like ``/a/b/c/d/train/e/f``. Everything after the last
    ``train`` path component (here ``e/f``) is kept, and
    ``get_path("train/e/f", type="results", fallback=True)`` is returned.
    If ``run_path`` has no ``train`` component, the whole of ``run_path`` is
    appended to ``"train/"`` instead.
    """
    # Match a whole "/train/" path component rather than the bare substring
    # "train/", which would also hit directories such as "pretrain/" or
    # "<name>_train/". Prefixing "/" lets relative inputs like "train/e/f" match.
    _, sep, rel = ("/" + run_path).rpartition("/train/")
    if not sep:
        rel = run_path
    return get_path("train/" + rel, type="results", fallback=True)
def get_run_step_direct(run_path, step):
    """Return the ``train/``-relative path of the checkpoint for ``step`` in one run directory.

    The step is matched against the number in the checkpoint file name
    (files are expected to be named ``step_<step>_epoch_<epoch>.ckpt``).

    Args:
        run_path: Run directory as recorded in the run config; it is resolved
            to a local directory with ``extract_relative_path``.
        step: Step number to look for.

    Returns:
        A string of the form ``"train/<...>/step_<step>_epoch_<epoch>.ckpt"``.

    Raises:
        Exception: if no checkpoint with that step exists in the directory.
    """
    p = extract_relative_path(run_path)
    print("Run-path:", p)
    # Build step -> file name once. Sorting makes the choice deterministic when
    # several files share a step. .ckpt files without a step in the expected
    # position (e.g. "last.ckpt") are skipped instead of crashing the lookup.
    ckpt_by_step = {}
    for name in sorted(os.listdir(p)):
        if not name.endswith(".ckpt"):
            continue
        try:
            file_step = int(name.split("_")[1])
        except (IndexError, ValueError):
            continue
        ckpt_by_step.setdefault(file_step, name)
    if step not in ckpt_by_step:
        print("Available steps:", sorted(ckpt_by_step))
        raise Exception("Step not found in run")
    full_path = os.path.join(p, ckpt_by_step[step])
    # Return everything after the last "/train/" path component; matching the
    # whole component avoids false hits on names like "pretrain/".
    _, sep, rel = ("/" + full_path).rpartition("/train/")
    if not sep:
        rel = full_path
    return "train/" + rel
def get_run_step_ckpt(run, step, steps_from_zero=False):
    """Locate the checkpoint for ``step``, following runs resumed from other runs' weights.

    If ``run`` did not load weights, or ``steps_from_zero`` is true, ``step``
    is looked up directly in ``run``'s own directory. Otherwise ``step`` is
    counted from the start of the whole chain of runs:

    * steps after ``get_run_initial_steps(run)`` were trained in ``run``
      itself and are looked up there, offset by those initial steps;
    * earlier steps (up to and including the loaded checkpoint) belong to an
      ancestor, so the lookup recurses into the parent run.

    Args:
        run: W&B run whose ``config`` has ``load_model_weights`` and ``run_path``.
        step: Step to look up.
        steps_from_zero: Treat ``step`` as local to ``run``. Defaults to False.

    Returns:
        Tuple ``(checkpoint_path, owning_run)``.

    Raises:
        Exception: if a parent run cannot be found or no checkpoint exists
            for the step.
    """
    if not run.config["load_model_weights"] or steps_from_zero:
        return get_run_step_direct(run.config["run_path"], step), run
    # Steps accumulated before `run` started (also raises if a parent run
    # cannot be found).
    steps = get_run_initial_steps(run)
    if step > steps:
        # The step was trained inside `run` itself: search run's own directory
        # using the step local to this run (the original searched the parent's
        # directory and returned the parent).
        print("Step", step, "is in run", run.name)
        return get_run_step_direct(run.config["run_path"], step - steps), run
    # The step belongs to an ancestor: recurse into the parent with the same
    # cumulative step.
    run_name_1 = run.config["load_model_weights"].split("/")[-2]
    run_1 = get_run_by_name(run_name_1)
    if run_1 is None:
        raise Exception("Run doesn't exist: " + run_name_1)
    # The original recursive call omitted this required argument (TypeError).
    return get_run_step_ckpt(run_1, step, steps_from_zero)
# Run-config keys that update_args copies from a W&B run onto an args object
# (when present in the run's config). Order is preserved as iteration order.
# NOTE: update_args always skips "epsilon", "min_samples" and
# "min_cluster_size" (it treats them as clustering args).
args_to_update = [
    "validation_steps",
    "start_lr",
    "lr_scheduler",
    "optimizer",
    "embed_as_vectors",
    "epsilon",
    "min_samples",
    "min_cluster_size",
    "spatial_part_only",
    "scalars_oc",
    "lorentz_norm",
    "beta_type",
    "coord_loss_weight",
    "repul_loss_weight",
    "attr_loss_weight",
    "gt_radius",
    "loss",
    "num_steps",
    "num_epochs",
    "hidden_s_channels",
    "hidden_mv_channels",
    "n_heads",
    "internal_dim",
    "num_blocks",
    "network_config",
    "data_config",
    "no_pid",
]
def update_args(args, run):
    """Copy selected settings from a W&B run's config onto ``args``.

    Every key in ``args_to_update`` that is present in ``run.config`` is set
    as an attribute on ``args``. The clustering keys ``min_samples``,
    ``min_cluster_size`` and ``epsilon`` are never copied. Finally
    ``args.parent_run`` is set to ``run.name``.

    Returns:
        The same ``args`` object, modified in place.
    """
    clustering_keys = ("min_samples", "min_cluster_size", "epsilon")
    config = run.config
    for key in args_to_update:
        if key in clustering_keys:
            print("Skipping setting clustering args")
        elif key not in config:
            print("Skipping setting", key)
        else:
            print("Setting", key, config[key])
            setattr(args, key, config[key])
    print("Loaded args from run", run.name)
    args.parent_run = run.name
    return args