Spaces:
Sleeping
Sleeping
import os | |
from tqdm import tqdm | |
import argparse | |
import pickle | |
from src.plotting.eval_matrix import matrix_plot, scatter_plot | |
from src.utils.paths import get_path | |
import matplotlib.pyplot as plt | |
# Command-line interface. --input names the batch-evaluation directory; it is
# resolved through get_path(..., "results") (see src.utils.paths for how the
# root directory is chosen).
parser = argparse.ArgumentParser(
    description="Plot precision / recall / F1 against model size and training step."
)
parser.add_argument(
    "--input",
    type=str,
    default="scouting_PFNano_signals2/SVJ_hadronic_std/batch_eval/small_dataset",
    help="Evaluation directory, passed to get_path(..., 'results').",
)
args = parser.parse_args()
path = get_path(args.input, "results")
def get_steps(config):
    """Return the training step of the checkpoint described by *config*.

    If ``config`` has a ``"ckpt_step"`` key, its value is returned as-is.
    Otherwise the step is parsed from the file name in
    ``config["load_model_weights"]``, which is expected to look like
    ``.../step_<N>_epoch_<M>.ckpt``; ``<N>`` is returned as an int.
    """
    if "ckpt_step" not in config:
        # Fallback: take the last path component and read the token after "step".
        ckpt_name = config["load_model_weights"].rsplit("/", 1)[-1]
        return int(ckpt_name.split("_")[1])
    return config["ckpt_step"]
def get_short(network_config): | |
if "transformer" in network_config.lower(): | |
return "Transformer" | |
if "lgatr" in network_config.lower(): | |
return "LGATr" | |
if "gatr" in network_config.lower(): | |
return "GATr" | |
return "Unknown" | |
def get_model_details(path_to_eval):
    """Summarise the model evaluated in *path_to_eval*.

    Reads ``run_config.pkl`` from that directory and returns the tuple
    ``(config["num_parameters"], get_short(config["network_config"]),
    get_steps(config))``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; run_config.pkl is
    # assumed to be a locally produced, trusted evaluation artefact.
    # Use a context manager so the file handle is closed after loading.
    with open(os.path.join(path_to_eval, "run_config.pkl"), "rb") as config_file:
        config = pickle.load(config_file)
    return config["num_parameters"], get_short(config["network_config"]), get_steps(config)
# Model run directories under `path`: every non-file entry whose name does not
# contain "AK8". The AK8 baselines are appended afterwards with placeholder
# details (0 parameters, step 0).
models = sorted(
    entry
    for entry in os.listdir(path)
    if not os.path.isfile(os.path.join(path, entry)) and "AK8" not in entry
)
# One (num_parameters, short_arch_name, training_step) tuple per entry of `models`.
data = [get_model_details(os.path.join(path, m)) for m in models]
data.extend([(0, "AK8", 0), (0, "AK8_GenJets", 0)])
models += ["AK8", "AK8_GenJets"]
out_file_PR = os.path.join(get_path(args.input, "results"), "precision_recall_n_params.pdf")
# Edge length (inches) of one subplot panel.
sz = 5
# Grid figure: 3 rows by one column per model.
fig, ax = plt.subplots(3, len(models), figsize=(sz * len(models), sz * 3))
# Family name (e.g. "Transformer") -> [n_params, precision, recall, F1] lists,
# filled from result_PR[700][20][0.7] and result_PR[900][20][0.3] respectively.
result_scatter = {}
result_scatter_900_03 = {}
# Per working point: architecture label -> [step, precision, recall, F1] lists.
result_by_step = {working_point: {} for working_point in ("900_03", "700_07")}
# (family, parameter count) -> architecture label used in the by-step plots.
# Only the parameter counts listed here are recognised.
_ARCH_NAMES = {
    ("Transformer", 4674): "Tr-2-16-4",
    ("Transformer", 1201108): "Tr",
    ("Transformer", 1322274): "Tr",
    ("Transformer", 167394): "Tr-5-64-4",
    ("LGATr", 8424): "LGATr-2-4-4",
    ("LGATr", 1201108): "LGATr",
    ("LGATr", 156332): "LGATr-3-16-16",
    ("GATr", 6533): "GATr-2-4-4",
    ("GATr", 926041): "GATr",
}


def get_arch_name(n_params, net_short):
    """Return the plot label for a model, or None if it is not recognised.

    Known (family, parameter-count) pairs map to the labels in _ARCH_NAMES.
    Any family name containing "AK8" is returned unchanged. Everything else
    yields None.
    """
    label = _ARCH_NAMES.get((net_short, n_params))
    if label is not None:
        return label
    return net_short if "AK8" in net_short else None
def _f1_score(precision, recall):
    """Return the harmonic mean of precision and recall.

    Defined as 0 when both are 0, instead of raising ZeroDivisionError.
    """
    denom = precision + recall
    if denom == 0:
        return 0.0
    return 2 * precision * recall / denom


def _load_pickle(file_path):
    """Load a single pickled object from *file_path* and close the file."""
    # NOTE(review): pickle.load can execute arbitrary code; these result files
    # are assumed to be locally produced evaluation outputs (trusted input).
    with open(file_path, "rb") as f:
        return pickle.load(f)


def _append_metrics(target, x_value, pr):
    """Append x_value, precision, recall and F1 to the 4 lists in *target*.

    pr[0] and pr[1] are read as precision and recall. This matches the
    commented-out matrix_plot calls, which label r[0] "Precision" and r[1]
    "Recall".
    """
    precision, recall = pr[0], pr[1]
    target[0].append(x_value)
    target[1].append(precision)
    target[2].append(recall)
    target[3].append(_f1_score(precision, recall))


# Collect precision / recall / F1 per model for two result_PR entries:
# [900][20][0.3] ("900_03") and [700][20][0.7] ("700_07").
for i, (model, details) in tqdm(enumerate(zip(models, data)), total=len(models)):
    output_path = os.path.join(path, model, "count_matched_quarks")
    # Models without a result.pkl have not been evaluated yet; skip them.
    if not os.path.exists(os.path.join(output_path, "result.pkl")):
        continue
    # Only the precision/recall table is used below.
    result_PR = _load_pickle(os.path.join(output_path, "result_PR.pkl"))
    # NOTE: re-enabling these fills the per-model grid figure `ax`.
    #matrix_plot(result_PR, "Oranges", "Precision (N matched dark quarks / N predicted jets)", metric_comp_func = lambda r: r[0], ax=ax[0, i])
    #matrix_plot(result_PR, "Reds", "Recall (N matched dark quarks / N dark quarks)", metric_comp_func = lambda r: r[1], ax=ax[1, i])
    #matrix_plot(result_PR, "Purples", r"$F_1$ score", metric_comp_func = lambda r: 2 * r[0] * r[1] / (r[0] + r[1]), ax=ax[2, i])
    n_params, net_short, step = details
    arch = get_arch_name(n_params, net_short)
    if arch is not None:
        # Track metrics vs. training step for recognised architectures.
        if arch not in result_by_step["900_03"]:
            for key in result_by_step:
                result_by_step[key][arch] = [[], [], [], []]
        _append_metrics(result_by_step["900_03"][arch], step, result_PR[900][20][0.3])
        _append_metrics(result_by_step["700_07"][arch], step, result_PR[700][20][0.7])
    # The parameter-count scatter only uses checkpoints at step 40000
    # (this also excludes the AK8 baselines, which carry step 0).
    if step != 40000:
        continue
    title = f"{n_params} {net_short}"
    for row in range(3):
        ax[row, i].set_title(title)
    if net_short not in result_scatter:
        result_scatter[net_short] = [[], [], [], []]
        result_scatter_900_03[net_short] = [[], [], [], []]
    _append_metrics(result_scatter[net_short], n_params, result_PR[700][20][0.7])
    _append_metrics(result_scatter_900_03[net_short], n_params, result_PR[900][20][0.3])
# Lay out and write the per-model grid figure. Then close it so pyplot stops
# holding a reference; `fig` is not drawn on again in this script.
fig.tight_layout()
fig.savefig(out_file_PR)
print("Saved to", out_file_PR)
plt.close(fig)
# Precision, recall and F1 (rows) versus parameter count for result_PR[700][20][0.7].
# One series per network family, log-scaled x axis.
fig_scatter, ax_scatter = plt.subplots(3, 1, figsize=(sz, sz * 3))
# NOTE(review): this family -> colour mapping is not passed to scatter_plot in
# the code visible here.
colors = {
    "Transformer": "green",
    "GATr": "blue",
    "LGATr": "red",
}
for key in result_scatter:
    # Row r plots metric list r + 1 (precision, recall, F1) against n_params.
    for row in range(3):
        scatter_plot(ax_scatter[row], result_scatter[key][0], result_scatter[key][row + 1], key)
for axis, ylabel in zip(ax_scatter, ("Precision", "Recall", "F1 score")):
    axis.set_ylabel(ylabel)
    axis.set_xlabel("N params")
    axis.legend()
    axis.grid()
    axis.set_xscale("log")
scatter_path = out_file_PR.replace(".pdf", "_scatter_700_07.pdf")
fig_scatter.tight_layout()
fig_scatter.savefig(scatter_path)
print("Saved to", scatter_path)
# Precision, recall and F1 (rows) versus parameter count for result_PR[900][20][0.3].
# One series per network family, log-scaled x axis.
fig_scatter, ax_scatter = plt.subplots(3, 1, figsize=(sz, sz * 3))
for key in result_scatter_900_03:
    # Row r plots metric list r + 1 (precision, recall, F1) against n_params.
    for row in range(3):
        scatter_plot(
            ax_scatter[row],
            result_scatter_900_03[key][0],
            result_scatter_900_03[key][row + 1],
            key,
        )
for axis, ylabel in zip(ax_scatter, ("Precision", "Recall", "F1 score")):
    axis.set_ylabel(ylabel)
    axis.set_xlabel("N params")
    axis.legend()
    axis.grid()
    axis.set_xscale("log")
scatter_path = out_file_PR.replace(".pdf", "_scatter_900_03.pdf")
fig_scatter.tight_layout()
fig_scatter.savefig(scatter_path)
print("Saved to", scatter_path)
# Precision, recall and F1 (rows) versus training step.
# One column per result_by_step key, in sorted order ("700_07", then "900_03").
fig_scatter, ax_scatter = plt.subplots(3, 2, figsize=(sz*2, sz*3))
# NOTE(review): fig_params_paper / ax_params_paper are not drawn on or saved in
# the code visible here; kept in case code further down relies on them.
fig_params_paper, ax_params_paper = plt.subplots(1, 2, figsize=(sz, sz*1.5))
for i, key in enumerate(sorted(result_by_step)):
    for model in result_by_step[key]:
        series = result_by_step[key][model]  # [steps, precision, recall, F1]
        if "AK8" in model:
            # AK8 baselines have a single point, so draw a dashed horizontal
            # reference line at their value instead of a scatter.
            colors = {"AK8": "gray", "AK8_GenJets": "black"}
            for row in range(3):
                ax_scatter[row, i].axhline(series[row + 1][0], label=model, color=colors[model], linestyle="--")
        else:
            for row in range(3):
                scatter_plot(ax_scatter[row, i], series[0], series[row + 1], model)
    # "F1 score" matches the other figures; the previous "F_1 score" rendered
    # a literal underscore because it was not wrapped in $...$ mathtext.
    for row, ylabel in enumerate(("Precision", "Recall", "F1 score")):
        axis = ax_scatter[row, i]
        axis.set_title(key)
        axis.set_ylabel(ylabel)
        axis.set_xlabel("training steps")
        axis.legend()
        axis.grid()
        axis.set_xscale("log")
fig_scatter.tight_layout()
fig_scatter.savefig(out_file_PR.replace(".pdf", "_by_step.pdf"))
print("Saved to", out_file_PR.replace(".pdf", "_by_step.pdf"))