jetclustering / scripts /plot_eval_n_params.py
gregorkrzmanc's picture
.
e75a247
raw
history blame
9.83 kB
import os
from tqdm import tqdm
import argparse
import pickle
from src.plotting.eval_matrix import matrix_plot, scatter_plot
from src.utils.paths import get_path
import matplotlib.pyplot as plt
# Command-line interface: a single optional --input argument naming the
# evaluation directory, which get_path resolves within the "results" location.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--input",
    type=str,
    required=False,
    default="scouting_PFNano_signals2/SVJ_hadronic_std/batch_eval/small_dataset",
)
args = parser.parse_args()
path = get_path(args.input, "results")
def get_steps(config):
    """Return the training step of the checkpoint described by ``config``.

    ``config["ckpt_step"]`` is returned as-is when that key is present.
    Otherwise the step is parsed from the file name stored under
    ``config["load_model_weights"]``, which is expected to look like
    ``/.../step_xxxx_epoch_y.ckpt``.
    """
    if "ckpt_step" not in config:
        # Fallback: take the second "_"-separated token of the file name.
        checkpoint_file = config["load_model_weights"].split("/")[-1]
        step_token = checkpoint_file.split("_")[1]
        return int(step_token)
    return config["ckpt_step"]
def get_short(network_config):
    """Map a network config string to a short architecture label.

    The match is a case-insensitive substring search. Returns "Transformer",
    "LGATr", "GATr", or "Unknown" when none of the substrings occurs.
    """
    lowered = network_config.lower()
    # Order matters: "lgatr" contains "gatr", so it has to be tested first.
    labels_by_substring = (
        ("transformer", "Transformer"),
        ("lgatr", "LGATr"),
        ("gatr", "GATr"),
    )
    for substring, label in labels_by_substring:
        if substring in lowered:
            return label
    return "Unknown"
def get_model_details(path_to_eval):
    """Read ``run_config.pkl`` from an evaluation directory and summarize it.

    Args:
        path_to_eval: Directory that contains ``run_config.pkl``.

    Returns:
        A tuple ``(num_parameters, short_network_name, training_step)`` built
        from ``config["num_parameters"]``, ``get_short(config["network_config"])``
        and ``get_steps(config)``.

    Raises:
        FileNotFoundError: If ``run_config.pkl`` does not exist.
        KeyError: If the pickled config lacks a required key.
    """
    config_file = os.path.join(path_to_eval, "run_config.pkl")
    # Use a context manager so the file handle is closed deterministically.
    # NOTE(review): pickle can execute arbitrary code on load; only use this
    # on trusted evaluation outputs.
    with open(config_file, "rb") as handle:
        config = pickle.load(handle)
    return config["num_parameters"], get_short(config["network_config"]), get_steps(config)
# Evaluated models: every entry of `path` that is not a regular file and whose
# name does not contain "AK8", in sorted order.
models = sorted(
    entry
    for entry in os.listdir(path)
    if not os.path.isfile(os.path.join(path, entry)) and "AK8" not in entry
)
# One (n_params, short network name, training step) tuple per model, followed
# by placeholder tuples for the two AK8 entries appended to `models` below.
data = [get_model_details(os.path.join(path, model)) for model in models]
data += [(0, "AK8", 0), (0, "AK8_GenJets", 0)]
models = models + ["AK8", "AK8_GenJets"]

out_file_PR = os.path.join(get_path(args.input, "results"), "precision_recall_n_params.pdf")

# Side length of one subplot; the grid has three rows and one column per model.
sz = 5
fig, ax = plt.subplots(3, len(models), figsize=(sz * len(models), sz * 3))

# Short network name -> [n_params, P, R, f1] series, one dict per operating point.
result_scatter = {}
result_scatter_900_03 = {}
# Operating point -> architecture name -> [step, P, R, f1] series.
result_by_step = {point: {} for point in ("900_03", "700_07")}
def get_arch_name(n_params, net_short):
    """Return the display name of an architecture, or None if it is unknown.

    Known (network, parameter count) pairs map to a fixed label. Any other
    input whose ``net_short`` contains "AK8" is returned unchanged; everything
    else yields None.
    """
    names_by_network = {
        "Transformer": {
            4674: "Tr-2-16-4",
            1201108: "Tr",
            1322274: "Tr",
            167394: "Tr-5-64-4",
        },
        "LGATr": {
            8424: "LGATr-2-4-4",
            # NOTE(review): same count as one of the Transformer entries - confirm.
            1201108: "LGATr",
            156332: "LGATr-3-16-16",
        },
        "GATr": {
            6533: "GATr-2-4-4",
            926041: "GATr",
        },
    }
    label = names_by_network.get(net_short, {}).get(n_params)
    if label is not None:
        return label
    return net_short if "AK8" in net_short else None
# Collect precision / recall / F1 for every evaluated model.
# Every series is a list of four lists [x, P, R, f1]: x is the parameter count
# in the result_scatter* dicts and the training step in result_by_step.


def _load_pickle(file_path):
    """Unpickle ``file_path`` and close the file handle afterwards.

    NOTE(review): pickle can execute arbitrary code on load; only point the
    script at trusted evaluation outputs.
    """
    with open(file_path, "rb") as handle:
        return pickle.load(handle)


def _f1_score(precision, recall):
    """Return the harmonic mean of precision and recall (0.0 if both are 0)."""
    denominator = precision + recall
    if denominator == 0:
        # Avoid a division by zero when nothing was matched at all.
        return 0.0
    return 2 * precision * recall / denominator


# Plot key -> the three indices used to look up result_PR[a][b][c].
# NOTE(review): presumably (mediator mass, dark-quark mass, r_inv) - confirm.
_OPERATING_POINTS = {"900_03": (900, 20, 0.3), "700_07": (700, 20, 0.7)}
# Destination dict of the metric-vs-n_params series for each operating point.
_SCATTER_TARGETS = {"700_07": result_scatter, "900_03": result_scatter_900_03}
# Only checkpoints at this training step enter the metric-vs-n_params plots.
_FINAL_STEP = 40000

# Wrap the list (not the enumerate object) so tqdm knows the total.
for i, model in enumerate(tqdm(models)):
    output_path = os.path.join(path, model, "count_matched_quarks")
    if not os.path.exists(os.path.join(output_path, "result.pkl")):
        # No result.pkl in this directory: skip the model.
        continue
    # Only result_PR.pkl is needed here; result.pkl, result_fakes.pkl and
    # result_bc.pkl used to be loaded as well but were never used.
    result_PR = _load_pickle(os.path.join(output_path, "result_PR.pkl"))
    n_params, net_short, step = data[i]

    # Per-model matrix plots are currently disabled:
    #matrix_plot(result_PR, "Oranges", "Precision (N matched dark quarks / N predicted jets)", metric_comp_func = lambda r: r[0], ax=ax[0, i])
    #matrix_plot(result_PR, "Reds", "Recall (N matched dark quarks / N dark quarks)", metric_comp_func = lambda r: r[1], ax=ax[1, i])
    #matrix_plot(result_PR, "Purples", r"$F_1$ score", metric_comp_func = lambda r: 2 * r[0] * r[1] / (r[0] + r[1]), ax=ax[2, i])

    # Entry 0 is the precision and entry 1 the recall at each operating point.
    metrics = {}
    for key, (idx_a, idx_b, idx_c) in _OPERATING_POINTS.items():
        pr = result_PR[idx_a][idx_b][idx_c]
        metrics[key] = (pr[0], pr[1], _f1_score(pr[0], pr[1]))

    # Metric vs. training step, grouped by named architecture.
    arch = get_arch_name(n_params, net_short)
    if arch is not None:
        for key, (precision, recall, f1) in metrics.items():
            series = result_by_step[key].setdefault(arch, [[], [], [], []])
            series[0].append(step)
            series[1].append(precision)
            series[2].append(recall)
            series[3].append(f1)

    if step != _FINAL_STEP:
        continue

    title = str(n_params) + " " + net_short
    for row in range(3):
        ax[row, i].set_title(title)

    # Metric vs. number of parameters, grouped by network family.
    for key, target in _SCATTER_TARGETS.items():
        precision, recall, f1 = metrics[key]
        series = target.setdefault(net_short, [[], [], [], []])
        series[0].append(n_params)
        series[1].append(precision)
        series[2].append(recall)
        series[3].append(f1)

fig.tight_layout()
fig.savefig(out_file_PR)
print("Saved to", out_file_PR)
def _plot_metrics_vs_n_params(results, out_file):
    """Draw precision, recall and F1 against parameter count and save the figure.

    Args:
        results: Mapping from a legend label to a ``[n_params, P, R, f1]``
            series (four parallel lists).
        out_file: Path the figure is written to.

    Returns:
        The ``(figure, axes)`` pair created for the plot.
    """
    figure, axes = plt.subplots(3, 1, figsize=(sz, sz * 3))
    # (index into the series, y-axis label) for each of the three rows.
    panels = ((1, "Precision"), (2, "Recall"), (3, "F1 score"))
    for axis, (metric_idx, ylabel) in zip(axes, panels):
        for key, series in results.items():
            scatter_plot(axis, series[0], series[metric_idx], key)
        axis.set_ylabel(ylabel)
        axis.set_xlabel("N params")
        axis.legend()
        axis.grid()
        axis.set_xscale("log")
    figure.tight_layout()
    figure.savefig(out_file)
    print("Saved to", out_file)
    return figure, axes


# NOTE(review): not passed to scatter_plot in this part of the script; kept
# because it is a module-level name.
colors = {
    "Transformer": "green",
    "GATr": "blue",
    "LGATr": "red",
}
fig_scatter, ax_scatter = _plot_metrics_vs_n_params(
    result_scatter, out_file_PR.replace(".pdf", "_scatter_700_07.pdf")
)
fig_scatter, ax_scatter = _plot_metrics_vs_n_params(
    result_scatter_900_03, out_file_PR.replace(".pdf", "_scatter_900_03.pdf")
)
# Metric vs. training step: one column per operating point, one row per metric.
fig_scatter, ax_scatter = plt.subplots(3, 2, figsize=(sz*2, sz*3))
# NOTE(review): not drawn on in the visible part of the script; kept in case
# later code uses it.
fig_params_paper, ax_params_paper = plt.subplots(1, 2, figsize=(sz, sz*1.5))
# Line colors of the AK8 baselines; built once instead of once per model.
colors = {"AK8": "gray", "AK8_GenJets": "black"}
# (index into the [step, P, R, f1] series, y-axis label) for each row.
# The F1 label uses mathtext so the subscript renders instead of a literal "_".
metric_rows = ((1, "Precision"), (2, "Recall"), (3, r"$F_1$ score"))
for i, key in enumerate(sorted(result_by_step)):
    for model, series in result_by_step[key].items():
        for row, (metric_idx, _) in enumerate(metric_rows):
            axis = ax_scatter[row, i]
            if "AK8" in model:
                # put a horizontal dotted line instead of a scatterplot, as there is only one dot
                axis.axhline(series[metric_idx][0], label=model, color=colors[model], linestyle="--")
            else:
                scatter_plot(axis, series[0], series[metric_idx], model)
    for row, (_, ylabel) in enumerate(metric_rows):
        axis = ax_scatter[row, i]
        axis.set_title(key)
        axis.set_ylabel(ylabel)
        axis.set_xlabel("training steps")
        axis.legend()
        axis.grid()
        axis.set_xscale("log")
fig_scatter.tight_layout()
fig_scatter.savefig(out_file_PR.replace(".pdf", "_by_step.pdf"))
print("Saved to", out_file_PR.replace(".pdf", "_by_step.pdf"))