File size: 5,147 Bytes
e75a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import matplotlib.pyplot as plt
import numpy as np


def ax_tiny_histogram(ax, labels, colors, values):
    """Draw a compact horizontal bar chart with in-plot labels on ``ax``.

    One bar is drawn per label at y = 0, 1, ..., len(labels) - 1 (semi-transparent,
    alpha 0.5).
    - Each bar's label is written left-aligned just right of ``min(values)``.
    - Its value (3 decimals) is written right-aligned ending just past ``max(values)``.
    - Ticks and spines are hidden.
    - The x-range is zoomed to ``[min(values) - 0.005, max(values) + 0.005]`` so small
      differences between bars are visible.

    The offsets (0.004 / 0.001 / 0.005) are absolute data units. They suit values of
    order 0-1 (e.g. scores). NOTE(review): confirm if used with other scales.

    Args:
        ax: Matplotlib Axes to draw on.
        labels: Text for each bar, bottom to top.
        colors: Bar colour(s), forwarded to ``ax.barh``.
        values: Bar lengths, one per label.

    Returns:
        The same ``ax``. With empty ``values`` the axes are styled but not annotated
        or re-limited.
    """
    bars = ax.barh(range(len(labels)), values, color=colors, alpha=0.5)

    ax.set_yticks([])
    ax.set_xticks([])  # Hide ticks for minimal look
    for spine in ax.spines.values():
        spine.set_visible(False)

    if len(values) == 0:
        # Nothing to annotate, and min()/max() below would raise ValueError.
        return ax

    lo, hi = min(values), max(values)
    for bar, label, value in zip(bars, labels, values):
        y_mid = bar.get_y() + bar.get_height() / 2
        # Label sits inside the bar, starting just right of the shortest bar's end.
        ax.text(lo - 0.004, y_mid, label,
                va='center', ha='left', fontsize=8, color='black', clip_on=True)
        # Numeric value ends just past the longest bar's end.
        ax.text(hi + 0.001, y_mid, f'{value:.3f}',
                va='center', ha='right', fontsize=8)

    ax.set_xlim(lo - 0.005, hi + 0.005)
    return ax


def _f1_from_precision_recall(pr):
    """Return the F1 score of a ``[precision, recall]`` pair.

    Returns 0.0 when precision + recall == 0 instead of dividing by zero
    (F1 is conventionally 0 in that case).
    """
    precision, recall = pr[0], pr[1]
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def multiple_matrix_plot(result, labels, colors, custom_val_formula=_f1_from_precision_recall, rename_dict=None):
    """Draw a grid of tiny bar charts, one per (mediator mass, r_inv) point.

    Grid layout:
    - Rows are the sorted mediator masses (``result`` keys).
    - Columns are the sorted union of all r_inv values.
    - Each cell is drawn with ``ax_tiny_histogram``: one bar per entry of ``labels``,
      whose length is ``custom_val_formula`` applied to that label's entry.
    - Cells whose (mMed, rinv) point is absent from ``result`` are left empty.

    Args:
        result: Nested mapping ``result[mMed][rinv][label] -> [precision, recall]``.
        labels: Labels to plot, in bar order (bottom to top).
        colors: Bar colours, forwarded to ``ax_tiny_histogram``.
        custom_val_formula: Maps a ``[precision, recall]`` entry to a bar value.
            The default is the F1 score, returning 0.0 when both are 0.
        rename_dict: Optional mapping from label to display name. Labels not in it
            are shown unchanged.

    Returns:
        ``(fig, ax)``, where ``ax`` is always a 2-D array of Axes. If any plotted
        cell lacks one of ``labels``, the missing label is printed, the figure is
        closed, and ``(None, None)`` is returned.
    """
    if rename_dict is None:
        rename_dict = {}
    mediator_masses = sorted(result)
    r_invs = sorted({rinv for mMed in result for rinv in result[mMed]})
    sz = 3
    # squeeze=False always yields a 2-D array, so ax[i, k] also works for a
    # single row or a single column (a squeezed 1-D array would not).
    fig, ax = plt.subplots(len(mediator_masses), len(r_invs),
                           figsize=(sz * len(r_invs), 0.65 * sz * len(mediator_masses)),
                           squeeze=False)
    display_labels = [rename_dict.get(label, label) for label in labels]
    for i, mMed in enumerate(mediator_masses):
        for k, rinv in enumerate(r_invs):
            if rinv not in result[mMed]:
                continue  # point not evaluated for this mMed; leave the cell empty
            cell = result[mMed][rinv]
            missing = [label for label in labels if label not in cell]
            if missing:
                print("Label not in result:", missing[0], " - skipping!")
                plt.close(fig)  # don't leave an orphaned figure registered with pyplot
                return None, None
            # Only evaluate the formula for labels that are actually plotted.
            values = [custom_val_formula(cell[label]) for label in labels]
            ax_tiny_histogram(ax[i, k], display_labels, colors, values)
            ax[i, k].set_title(f"$m_{{Z'}}$ = {mMed} GeV, $r_{{inv.}}$ = {rinv}")
    fig.tight_layout()
    return fig, ax

def matrix_plot(result, color_scheme, cbar_label, ax=None, metric_comp_func=None, is_qcd=False, vmin=None, vmax=None):
    """Plot an annotated heat-map of a metric over mediator mass x r_inv.

    Layout:
    - Rows are the sorted mediator masses (``result`` keys).
    - Columns are the sorted union of all r_inv values.
    - Only one dark mass is plotted: 20, or 0 when ``is_qcd``.
    - Every cell is annotated with its value (3 decimals).
    - A colour bar labelled ``cbar_label`` is attached.
    - Points missing from ``result`` are shown as 0.

    Args:
        result: Nested mapping ``result[mMed][mDark][rinv] -> r``. ``r`` must be
            numeric unless ``metric_comp_func`` is given.
        color_scheme: Matplotlib colormap name for the heat-map.
        cbar_label: Colour-bar label.
        ax: Axes to draw into. If None, a new 5x5 figure is created.
        metric_comp_func: Optional function mapping ``r`` to a number. If it
            raises, the cell is plotted as 0.
        is_qcd: Use dark mass 0 instead of 20.
        vmin, vmax: Optional colour-scale limits forwarded to ``imshow``.
            None (the default) auto-scales to the data.

    Returns:
        The new Figure when ``ax`` is None, otherwise None.
    """
    make_fig = ax is None
    dark_masses = [0] if is_qcd else [20]
    if make_fig:
        fig, ax = plt.subplots(len(dark_masses), 1, figsize=(5, 5))
    mediator_masses = sorted(result)
    r_invs = sorted({rinv for mMed in result for mDark in result[mMed] for rinv in result[mMed][mDark]})
    if len(dark_masses) == 1:
        ax = [ax]  # single Axes -> list so ax[i] indexing below is uniform
    for i, mDark in enumerate(dark_masses):
        data = np.zeros((len(mediator_masses), len(r_invs)))
        for j, mMed in enumerate(mediator_masses):
            for k, rinv in enumerate(r_invs):
                if mDark not in result[mMed] or rinv not in result[mMed][mDark]:
                    continue  # missing point stays 0
                r = result[mMed][mDark][rinv]
                if metric_comp_func is not None:
                    try:
                        r = metric_comp_func(r)
                    except Exception:
                        # Best effort: a point whose metric cannot be computed is shown as 0.
                        r = 0
                data[j, k] = r
        # Draw the image once, with the requested colormap; it also feeds the colour bar.
        im = ax[i].imshow(data, cmap=color_scheme, vmin=vmin, vmax=vmax)
        for (j, k), val in np.ndenumerate(data):
            ax[i].text(k, j, f'{val:.3f}', ha='center', va='center', color='black')
        ax[i].set_xticks(range(len(r_invs)))
        ax[i].set_xticklabels(r_invs)
        ax[i].set_yticks(range(len(mediator_masses)))
        ax[i].set_yticklabels(mediator_masses)
        ax[i].set_xlabel("$r_{inv}$")
        ax[i].set_ylabel("$m_{Z'}$ [GeV]")
        # Attach the colour bar to the Axes' own figure, not pyplot's "current" one,
        # so a caller-supplied ax on another figure is handled correctly.
        cbar = ax[i].figure.colorbar(im, ax=ax[i])
        cbar.set_label(cbar_label)
    if make_fig:
        fig.tight_layout()
        return fig

def scatter_plot(ax, xs, ys, label, color=None, pattern=".--"):
    """Plot ``ys`` against ``xs`` on ``ax``, with the points ordered by x.

    Despite the name, this calls ``ax.plot`` with the format string ``pattern``.
    With the default ".--" that draws point markers joined by a dashed line.
    Sorting by x keeps that line monotone in x instead of zig-zagging.

    Args:
        ax: Matplotlib Axes to draw on.
        xs: x coordinates (any array-like).
        ys: y coordinates, same length as ``xs``.
        label: Legend label for the line.
        color: Line colour. May be None; it is forwarded as-is.
        pattern: Matplotlib format string.
    """
    xs = np.asarray(xs)
    ys = np.asarray(ys)
    # Stable sort: points sharing the same x keep their input order.
    order = np.argsort(xs, kind="stable")
    ax.plot(xs[order], ys[order], pattern, label=label, color=color)