import torch
from torch_scatter import scatter_max
import dgl


def calc_energy_loss(
    batch, cluster_space_coords, beta, beta_stabilizing="soft_q_scaling", qmin=0.1, radius=0.7,
    e_frac_loss_return_particles=False, y=None, select_centers_by_particle=True
):
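    """Measure how much of each particle's true hit energy its predicted
    cluster captures.

    For every graph in ``batch`` one condensation point per object is chosen
    (by default the highest-beta hit of each ground-truth particle), hits
    within ``radius`` of a center in cluster space are assigned to it via
    ``get_clustering``, and the clustered energy is compared to the particle's
    true energy. Returns the mean energy fraction and the mean correctly
    assigned fraction, or per-particle lists plus per-PID reconstruction
    counts when ``e_frac_loss_return_particles`` is True.
    """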
    # select_centers_by_particle: if True, use the ground-truth particle labels
    # to pick one center per particle (its highest-beta hit); otherwise pick
    # the top-beta hits overall.
    list_graphs = dgl.unbatch(batch)
    node_counter = 0
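    # Per-hit charge q = arctanh(beta)^2 + qmin; the modes differ in whether
    # and how beta is pulled away from 1.0, where arctanh diverges.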
    if beta_stabilizing == "paper":
        q = beta.arctanh() ** 2 + qmin
    elif beta_stabilizing == "clip":
        beta = beta.clip(0.0, 1 - 1e-4)
        q = beta.arctanh() ** 2 + qmin
    elif beta_stabilizing == "soft_q_scaling":
        q = (beta.clip(0.0, 1 - 1e-4) / 1.002).arctanh() ** 2 + qmin
    else:
        raise ValueError(f"beta_stabilizing mode {beta_stabilizing} is not known")
    loss_E_frac = []
    loss_E_frac_true = []
    particle_ids_all = []
    reco_count = {}  # per-PID count of reconstructed particles
    non_reco_count = {}  # per-PID count of missed particles
    total_count = {}  # per-PID count of all particles
    for g in list_graphs:
        particle_id = g.ndata["particle_number"]
        number_of_objects = len(particle_id.unique())
        print("No. of objects", number_of_objects)
        n_nodes = g.number_of_nodes()
        q_g = q[node_counter : node_counter + n_nodes]
        betas = beta[node_counter : node_counter + n_nodes]
        # Fallback selection: the highest-beta hits are taken as condensation points.
        _, indices = torch.sort(betas.view(-1), descending=True)
        selected_centers = indices[:number_of_objects]
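        # scatter_max returns (values, argmax): per particle, the index of its
        # highest-beta hit.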
        _, selected_centers_particles = scatter_max(
            betas.flatten().cpu(), particle_id.cpu().long() - 1
        )
        assert selected_centers.shape[0] == number_of_objects
        if select_centers_by_particle:
            selected_centers = selected_centers_particles.to(g.device)
        all_particles = set((particle_id.unique() - 1).long().tolist())
        reco_particles = set((particle_id[selected_centers] - 1).long().tolist())
        non_reco_particles = all_particles - reco_particles
        part_pids = y[:, 6].long()  # column 6 of y is taken to hold the particle PID
        for particle in all_particles:
            curr_pid = part_pids[particle].item()
            total_count[curr_pid] = total_count.get(curr_pid, 0) + 1
            if particle in reco_particles:
                reco_count[curr_pid] = reco_count.get(curr_pid, 0) + 1
            else:
                non_reco_count[curr_pid] = non_reco_count.get(curr_pid, 0) + 1
        X = cluster_space_coords[node_counter : node_counter + n_nodes]

        if radius == "dynamic":
            # Half the distance from each center to its nearest other center,
            # minimized over centers and floored at 0.1.
            center_dists = torch.cdist(X[selected_centers], X[selected_centers], p=2)
            pick_ = torch.argsort(center_dists, dim=1)[:, 1]  # column 0 is the center itself
            nearest_dist = center_dists.gather(1, pick_.view(-1, 1))
            current_radius = max(0.1, (nearest_dist / 2).flatten().min().item())
            print("Current radius", current_radius)
        else:
            print("Radius", radius)
            current_radius = radius
        clusterings = get_clustering(selected_centers, X, betas, td=current_radius)
        clusterings = clusterings.to(g.device)
        node_counter += n_nodes
        counter = 0
        frac_energy = []
        frac_energy_true = []
        particle_ids = []
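        # For each selected center, compare the energy of its predicted
        # cluster with the true energy of the particle the center belongs to.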
        for alpha in selected_centers:
            id_particle = particle_id[alpha]
            true_mask_particle = particle_id == id_particle
            true_energy = torch.sum(g.ndata["e_hits"][true_mask_particle])
            mask_clustering_particle = clusterings == counter
            clustered_energy = torch.sum(g.ndata["e_hits"][mask_clustering_particle])
            clustered_energy_true = torch.sum(
                g.ndata["e_hits"][
                    mask_clustering_particle & true_mask_particle.flatten()
                ]
            )  # only count energy that was assigned to the correct particle
            counter += 1
            frac_energy.append(clustered_energy / (true_energy + 1e-7))
            frac_energy_true.append(clustered_energy_true / (true_energy + 1e-7))
            particle_ids.append(id_particle.cpu().long().item())
        frac_energy = torch.stack(frac_energy, dim=0)
        if not e_frac_loss_return_particles:
            frac_energy = torch.mean(frac_energy)
        frac_energy_true = torch.stack(frac_energy_true, dim=0)
        if not e_frac_loss_return_particles:
            frac_energy_true = torch.mean(frac_energy_true)
        loss_E_frac.append(frac_energy)
        loss_E_frac_true.append(frac_energy_true)
        particle_ids_all.append(particle_ids)
    if e_frac_loss_return_particles:
        return loss_E_frac, [loss_E_frac_true, particle_ids_all, reco_count, non_reco_count, total_count]
    loss_E_frac = torch.mean(torch.stack(loss_E_frac, dim=0))
    loss_E_frac_true = torch.mean(torch.stack(loss_E_frac_true, dim=0))
    return loss_E_frac, loss_E_frac_true


def get_clustering(index_alpha_i, X, betas, td=0.7):
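    """Greedy clustering in cluster space.

    Iterates over the condensation points ``index_alpha_i`` in the given order
    and assigns every not-yet-assigned point within distance ``td`` of the
    current condensation point to that point's cluster. Points farther than
    ``td`` from every condensation point keep the label -1.
    """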
    n_points = betas.size(0)
    unassigned = torch.arange(n_points, device=betas.device)
    clustering = torch.full((n_points,), -1, dtype=torch.long, device=betas.device)
    for counter, index_condpoint in enumerate(index_alpha_i):
        # Assign still-unassigned points within td of this condensation point.
        d = torch.norm(X[unassigned] - X[index_condpoint], dim=-1)
        assigned = d < td
        clustering[unassigned[assigned]] = counter
        unassigned = unassigned[~assigned]
    # Make sure every condensation point carries its own cluster label.
    for counter, index_condpoint in enumerate(index_alpha_i):
        clustering[index_condpoint] = counter

    return clustering
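

# Minimal illustration of how get_clustering behaves on toy data (a sketch;
# the points, beta values, and centers below are invented for demonstration):
#
#     import torch
#     X = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0], [9.0, 9.0]])
#     betas = torch.tensor([0.9, 0.2, 0.8, 0.3, 0.1])
#     centers = torch.tensor([0, 2])  # condensation points
#     get_clustering(centers, X, betas, td=0.7)
#     # -> tensor([ 0,  0,  1,  1, -1]): each hit joins the first center within
#     #    td; hits farther than td from every center keep the label -1.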