File size: 3,582 Bytes
e75a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import wandb
from src.utils.paths import get_path

# Shared W&B public-API client, used by get_run_by_name to query runs.
# NOTE(review): created at import time, so merely importing this module requires
# wandb and (presumably) valid W&B credentials — confirm whether Api() contacts
# the server eagerly before relying on cheap imports.
api = wandb.Api()

def get_run_by_name(name):
    """Look up a run in the ``fcc_ml/svj_clustering`` W&B project by display name.

    Args:
        name: Display name of the run. Surrounding whitespace is stripped
            before querying.

    Returns:
        The matching run object, or ``None`` unless exactly one run carries
        that display name.
    """
    # A single filtered query is enough; the result is paginated lazily by wandb.
    runs = api.runs(
        path="fcc_ml/svj_clustering",
        filters={"display_name": {"$eq": name.strip()}},
    )
    # Both "no such run" and "ambiguous name" are reported as None, so callers
    # never silently pick one of several same-named runs.
    if runs.length != 1:
        return None
    return runs[0]

def get_steps_from_file(fname):
    """Return the step number encoded in a checkpoint path.

    Checkpoint files are named ``step_<step>_epoch_<epoch>.ckpt``, so e.g.
    ``".../results/train/<run>/step_22000_epoch_1467.ckpt"`` yields ``22000``.
    Only the final ``/``-separated component is inspected.
    """
    basename = fname.rsplit("/", 1)[-1]
    step_field = basename.split("_", 2)[1]
    return int(step_field)

def get_run_initial_steps(run):
    """Return the global step at which ``run`` started training.

    A run with a falsy ``config["load_model_weights"]`` starts at 0. Otherwise
    the run was initialised from a checkpoint of a parent run. The parent is
    looked up by the name of the checkpoint's directory. The result is the
    parent's own starting step (computed recursively) plus the step number
    encoded in the checkpoint filename.

    Raises:
        Exception: if a parent run in the chain cannot be found.
    """
    weights_path = run.config["load_model_weights"]
    if not weights_path:
        return 0
    parent_name = weights_path.split("/")[-2]
    parent = get_run_by_name(parent_name)
    if parent is None:
        raise Exception("Run doesn't exist: " + parent_name)
    return get_run_initial_steps(parent) + get_steps_from_file(weights_path)

def extract_relative_path(run_path):
    """Resolve a run directory through ``get_path`` using its ``train/...`` tail.

    ``run_path`` looks like ``/a/b/c/d/train/e/f``. Everything after the last
    ``train/`` is kept, or the whole string if ``train/`` does not occur. The
    kept part is re-prefixed with ``train/`` and resolved with
    ``get_path(..., type="results", fallback=True)``. Despite the name, the
    return value is whatever ``get_path`` produces for that relative path.
    """
    tail = run_path.split("train/")[-1]
    relative = "train/" + tail
    return get_path(relative, type="results", fallback=True)


def get_run_step_direct(run_path, step):
    """Return the checkpoint saved at ``step`` inside a single run directory.

    The directory is resolved with ``extract_relative_path`` and scanned for
    ``.ckpt`` files named ``step_<step>_epoch_<epoch>.ckpt``. Other ``.ckpt``
    files whose name carries no integer step (e.g. ``last.ckpt``) are ignored
    instead of aborting the lookup. If several files share a step, the first
    one listed by ``os.listdir`` wins.

    Args:
        run_path: Path of the run directory; only the part after ``train/`` is used.
        step: Step number local to this run.

    Returns:
        The checkpoint path starting at ``train/``, e.g. ``train/<run>/step_...ckpt``.

    Raises:
        Exception: if no checkpoint for ``step`` exists in the directory.
    """
    p = extract_relative_path(run_path)
    print("Run-path:", p)
    # Parse every checkpoint name once: step -> filename.
    ckpt_by_step = {}
    for fname in os.listdir(p):
        if not fname.endswith(".ckpt"):
            continue
        try:
            file_step = get_steps_from_file(fname)
        except (IndexError, ValueError):
            # Not a step_x_epoch_y checkpoint (e.g. "last.ckpt"); skip it.
            continue
        ckpt_by_step.setdefault(file_step, fname)
    if step not in ckpt_by_step:
        print("Available steps:", sorted(ckpt_by_step))
        raise Exception("Step not found in run")
    full_path = os.path.join(p, ckpt_by_step[step])
    # Return everything from "train/" onwards.
    return "train/" + full_path.split("train/")[-1]


def get_run_step_ckpt(run, step, steps_from_zero=False):
    """Find the checkpoint for ``step``, following a chain of resumed runs.

    A run may be initialised from a parent run's checkpoint via
    ``config["load_model_weights"]``; its global step then continues from that
    checkpoint's step. Steps after the hand-off live in ``run`` itself, where
    checkpoints are numbered locally. Steps at or before it are searched for
    in the parent chain.

    Args:
        run: W&B run to start searching from.
        step: Global step, or a step local to ``run`` when ``steps_from_zero``.
        steps_from_zero: If true, treat ``step`` as local to ``run`` and do not
            follow parent runs.

    Returns:
        ``(checkpoint_path, owning_run)``, where ``checkpoint_path`` is the
        ``train/...`` path produced by ``get_run_step_direct``.

    Raises:
        Exception: if a parent run cannot be found or the step has no checkpoint.
    """
    weights_path = run.config["load_model_weights"]
    if not weights_path or steps_from_zero:
        return get_run_step_direct(run.config["run_path"], step), run
    parent_name = weights_path.split("/")[-2]
    parent = get_run_by_name(parent_name)
    if parent is None:
        raise Exception("Run doesn't exist: " + parent_name)
    # Global step at which `run` started (same as get_run_initial_steps(run),
    # but reusing the parent already fetched instead of querying W&B again).
    start_step = get_run_initial_steps(parent) + get_steps_from_file(weights_path)
    if step > start_step:
        # Produced by `run` itself: look in its own directory at the local step.
        print("Step", step, "is in run", run.name)
        return get_run_step_direct(run.config["run_path"], step - start_step), run
    # At or before the hand-off point: the checkpoint belongs to the parent chain.
    return get_run_step_ckpt(parent, step, steps_from_zero=False)

# Config keys that update_args copies from a W&B run's config onto the local
# args namespace, in this order. "epsilon", "min_samples" and
# "min_cluster_size" are listed but update_args always skips them.
args_to_update = [
    "validation_steps",
    "start_lr",
    "lr_scheduler",
    "optimizer",
    "embed_as_vectors",
    "epsilon",
    "min_samples",
    "min_cluster_size",
    "spatial_part_only",
    "scalars_oc",
    "lorentz_norm",
    "beta_type",
    "coord_loss_weight",
    "repul_loss_weight",
    "attr_loss_weight",
    "gt_radius",
    "loss",
    "num_steps",
    "num_epochs",
    "hidden_s_channels",
    "hidden_mv_channels",
    "n_heads",
    "internal_dim",
    "num_blocks",
    "network_config",
    "data_config",
    "no_pid",
]

def update_args(args, run):
    """Copy selected hyperparameters from a W&B run's config onto ``args``.

    Each key in ``args_to_update`` is set as an attribute on ``args`` when
    present in ``run.config``. The clustering keys ("min_samples",
    "min_cluster_size", "epsilon") are never copied, and missing keys are
    skipped. Each decision is printed. ``args.parent_run`` is set to the run's
    name.

    Returns:
        The same ``args`` object, mutated in place.
    """
    clustering_keys = {"min_samples", "min_cluster_size", "epsilon"}
    for key in args_to_update:
        if key in clustering_keys:
            print("Skipping setting clustering args")
        elif key not in run.config:
            print("Skipping setting", key)
        else:
            value = run.config[key]
            print("Setting", key, value)
            setattr(args, key, value)
    print("Loaded args from run", run.name)
    args.parent_run = run.name
    return args