Spaces:
Sleeping
Sleeping
import matplotlib.pyplot as plt | |
import numpy as np | |
def ax_tiny_histogram(ax, labels, colors, values):
    """Draw a minimal horizontal bar chart with text annotations on ``ax``.

    One horizontal bar is drawn per entry of ``values`` at y-positions
    ``0 .. len(labels) - 1``. For every bar, its label is written
    left-aligned just right of the left x-limit and its value (formatted
    with three decimals) is written right-aligned just left of the right
    x-limit. Ticks and spines are hidden, and the x-range is set to
    ``[min(values) - 0.005, max(values) + 0.005]``.

    Args:
        ax: Axes-like object to draw on (needs ``barh``, ``text``,
            ``set_xticks``, ``set_yticks``, ``set_xlim`` and ``spines``).
        labels: One text label per bar.
        colors: Forwarded unchanged as the ``color`` argument of ``barh``.
        values: Bar lengths, one per label. May be empty, in which case no
            bars or texts are drawn and the x-limits are left untouched.

    Returns:
        The same ``ax`` that was passed in.
    """
    has_bars = len(values) > 0
    if has_bars:
        # The extremes are the same for every bar, so compute them once
        # instead of once per annotation.
        lowest, highest = min(values), max(values)

        # Create bars
        bars = ax.barh(range(len(labels)), values, color=colors, alpha=0.5)

        # Add the label (left-aligned) and the value (right-aligned) at the
        # vertical centre of each bar.
        for bar, label, value in zip(bars, labels, values):
            y_center = bar.get_y() + bar.get_height() / 2
            ax.text(lowest - 0.004, y_center, label,
                    va='center', ha='left', fontsize=8, color='black',
                    clip_on=True)
            ax.text(highest + 0.001, y_center, f'{value:.3f}',
                    va='center', ha='right', fontsize=8)

    ax.set_yticks([])
    ax.set_xticks([])  # Hide ticks for minimal look
    for spine in ax.spines.values():
        spine.set_visible(False)

    if has_bars:
        ax.set_xlim(lowest - 0.005, highest + 0.005)
    return ax
def multiple_matrix_plot(result, labels, colors, custom_val_formula=lambda x: 2*x[0]*x[1]/(x[0]+x[1]), rename_dict=None):
    """Draw a grid of tiny bar charts, one per (mediator mass, r_inv) pair.

    Rows of the grid follow the sorted keys of ``result`` (mediator masses)
    and columns follow the sorted union of the second-level keys (r_inv).
    Grid cells with no entry in ``result`` are left as empty axes.

    Args:
        result: Nested mapping ``mMed -> rinv -> {label: value}``. Each
            ``value`` is passed through ``custom_val_formula`` before
            plotting.
        labels: Labels to plot, in plotting order. Every label must be
            present in every populated cell.
        colors: Bar colors, forwarded to ``ax_tiny_histogram``.
        custom_val_formula: Maps a raw cell value to the number that is
            plotted. The default computes the F1 score from
            ``[precision, recall]``; it divides by ``precision + recall``,
            so a pair summing to zero is not handled here.
        rename_dict: Optional mapping from a label to the text displayed for
            it; labels without an entry are displayed unchanged.

    Returns:
        ``(fig, ax)`` where ``ax`` is always a 2-D array of axes indexed as
        ``ax[row, column]``, or ``(None, None)`` if any populated cell lacks
        one of ``labels`` (a message is printed and the figure is closed).
    """
    if rename_dict is None:
        rename_dict = {}

    mediator_masses = sorted(result)
    r_invs = sorted({rinv for mMed in result for rinv in result[mMed]})
    sz = 3
    # squeeze=False keeps ``ax`` two-dimensional even for a single row or a
    # single column, so ``ax[i, k]`` below is valid for every grid shape.
    fig, ax = plt.subplots(
        len(mediator_masses),
        len(r_invs),
        figsize=(sz * len(r_invs), 0.65 * sz * len(mediator_masses)),
        squeeze=False,
    )
    display_labels = [rename_dict.get(label, label) for label in labels]

    for i, mMed in enumerate(mediator_masses):
        for k, rinv in enumerate(r_invs):
            if rinv not in result[mMed]:
                continue
            r = {key: custom_val_formula(val) for key, val in result[mMed][rinv].items()}
            for label in labels:
                if label not in r:
                    print("Label not in result:", label , " - skipping!")
                    # The half-built figure is never handed to the caller,
                    # so release it instead of leaving it open.
                    plt.close(fig)
                    return None, None
            ax_tiny_histogram(ax[i, k], display_labels, colors, [r[label] for label in labels])
            ax[i, k].set_title(f"$m_{{Z'}}$ = {mMed} GeV, $r_{{inv.}}$ = {rinv}")

    fig.tight_layout()
    return fig, ax
def matrix_plot(result, color_scheme, cbar_label, ax=None, metric_comp_func=None, is_qcd=False):
    """Draw an annotated heat map of a metric over mediator mass and r_inv.

    Rows follow the sorted keys of ``result`` (mediator masses) and columns
    follow the sorted union of all r_inv keys. Every cell is annotated with
    its value formatted to three decimals; combinations missing from
    ``result`` stay at 0.

    Args:
        result: Nested mapping ``mMed -> mDark -> rinv -> value``.
        color_scheme: Colormap passed to ``imshow`` as ``cmap``.
        cbar_label: Label text for the colorbar.
        ax: Axes to draw on. When ``None``, a new 5x5 figure is created.
        metric_comp_func: Optional callable applied to each raw value. If it
            raises an ``Exception`` for a cell, that cell is set to 0.
        is_qcd: Selects the dark-mass key that is read from ``result``:
            0 when true, 20 otherwise.

    Returns:
        The newly created figure when ``ax`` was ``None``; otherwise
        ``None`` (the caller already owns the figure).
    """
    make_fig = ax is None
    dark_masses = [0] if is_qcd else [20]
    if make_fig:
        fig, ax = plt.subplots(len(dark_masses), 1, figsize=(5, 5))
    mediator_masses = sorted(result)
    r_invs = sorted({
        rinv
        for mMed in result
        for mDark in result[mMed]
        for rinv in result[mMed][mDark]
    })
    if len(dark_masses) == 1:
        # A single panel is a bare Axes; wrap it so it can be indexed below.
        ax = [ax]

    for i, mDark in enumerate(dark_masses):
        data = np.zeros((len(mediator_masses), len(r_invs)))
        for j, mMed in enumerate(mediator_masses):
            for k, rinv in enumerate(r_invs):
                if mDark not in result[mMed]:
                    continue
                if rinv not in result[mMed][mDark]:
                    continue
                r = result[mMed][mDark][rinv]
                if metric_comp_func is not None:
                    try:
                        r = metric_comp_func(r)
                    except Exception:
                        # Best effort: a cell whose metric cannot be
                        # computed is shown as 0 rather than aborting.
                        r = 0
                data[j, k] = r

        panel = ax[i]
        # NOTE(review): an earlier comment mentioned fixing the colour range
        # to 0..1 for the "greens" scheme, but no limits were ever applied;
        # the range is still autoscaled from the data.
        image = panel.imshow(data, cmap=color_scheme)
        for (j, k), val in np.ndenumerate(data):
            panel.text(k, j, f'{val:.3f}', ha='center', va='center', color='black')
        panel.set_xticks(range(len(r_invs)))
        panel.set_xticklabels(r_invs)
        panel.set_yticks(range(len(mediator_masses)))
        panel.set_yticklabels(mediator_masses)
        panel.set_xlabel("$r_{inv}$")
        panel.set_ylabel("$m_{Z'}$ [GeV]")
        # Attach the colorbar through the figure that owns this panel, so a
        # caller-supplied ``ax`` works even if its figure is not current.
        cbar = panel.figure.colorbar(image, ax=panel)
        cbar.set_label(cbar_label)

    if not make_fig:
        return None
    fig.tight_layout()
    return fig
def scatter_plot(ax, xs, ys, label, color=None, pattern=".--"):
    """Plot ``ys`` against ``xs`` on ``ax`` with the points sorted by x.

    Args:
        ax: Axes-like object providing ``plot``.
        xs: X coordinates, in any order.
        ys: Y coordinates, paired element-wise with ``xs``.
        label: Legend label for the plotted line.
        color: Forwarded to ``ax.plot`` as ``color``; ``None`` is passed
            through unchanged.
        pattern: Format string passed to ``ax.plot`` as its third positional
            argument.
    """
    # Sort both coordinate arrays by x so the connecting line runs left to
    # right instead of zig-zagging in input order.
    order = np.argsort(xs)
    xs = np.array(xs)[order]
    ys = np.array(ys)[order]
    # ``color`` is forwarded as-is in every case, so no branching is needed.
    ax.plot(xs, ys, pattern, label=label, color=color)