File size: 9,391 Bytes
4f5540c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import math
from operator import index
import numpy as np
import torch
import torch_geometric
from typing import Dict, Iterable, Callable, Tuple
from polymerlearn.utils import make_like_batch
from polymerlearn.utils.graph_prep import get_AG_info
from polymerlearn.explain.custom_gcam import LayerGradCam

# Source: https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904

class FeatureExtractor(torch.nn.Module):
    '''
    Extracts inputs/outputs to each layer in the model

    Source: https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904
    
    '''
    

    def __init__(self, model: torch.nn.Module, use_mono: bool = False):
        super().__init__()
        self.model = model
        self.use_mono = use_mono
        self.layers = ['sage'] if self.use_mono else ['Asage', 'Gsage']
        #print(self.layers)
        if self.use_mono:
            self._features = {layer: None for layer in ['Asage', 'Gsage']}
        else:
            self._features = {layer: torch.empty(0) for layer in self.layers}

        for layer_id in self.layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            # Register forward hook to get intermediate outputs of layers
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        '''
        Hook function for saving outputs of intermediate layers
        '''
        if self.use_mono:
            def fn(_, __, output):
                if self._features['Asage'] is not None:
                    self._features['Gsage'] = output
                    #print('Reg G')
                else:
                    self._features['Asage'] = output
                    #print('Reg A')
        else:
            def fn(_, __, output):
                self._features[layer_id] = output
        return fn

    def forward(self, input_tup) -> Dict[str, torch.Tensor]:
        _ = self.model(*input_tup)
        # print('Features', self._features)
        # print('Features', len(self._features['Asage']))
        # print('Features', len(self._features['Gsage']))
        # exit()
        if self.use_mono:
            feat_copy = self._features
            self._features = {layer: None for layer in ['Asage', 'Gsage']}
            return feat_copy
        else:
            return self._features

def parse_batches(
        batch: torch_geometric.data.Batch, 
        add_test: torch.Tensor):
    '''
    Split a combined polymer batch into acid/glycol graph tensors plus the
    additional-feature tensor, in the positional order the GNN expects.

    Returns:
        tuple: (A_x, A_edge_index, A_batch, G_x, G_edge_index, G_batch,
            additional features as a float tensor)
    '''
    acid, glycol = make_like_batch(batch)

    # Additional (non-graph) features, coerced to float32.
    extras = torch.tensor(add_test).float()

    return (
        acid.x,
        acid.edge_index,
        acid.batch,
        glycol.x,
        glycol.edge_index,
        glycol.batch,
        extras,
    )

def index_to_batch_mapper(batch, ratio = 0.5):
    '''
    Computes a backwards map from index in a SAGPool output
      to the original sample inputs.

    Each input sample of size ``s`` keeps ``ceil(s * ratio)`` nodes after
    pooling; pooled nodes are laid out sample-by-sample, so cumulative
    sums of the pooled sizes delimit each sample's index range.

    Args:
        batch (torch.Tensor): 1-D batch-assignment vector (as produced by
            torch_geometric), mapping each node to its sample index.
        ratio (float): Pooling ratio used by the SAGPool layer.

    Returns:
        dict: Maps pooled-node index -> original sample index.
    '''
    # batch.max() runs on-tensor; the builtin max() would iterate the
    # tensor element-by-element at Python level.
    num_batches = int(batch.max().item()) + 1
    batch_sizes = [int((batch == b).sum().item()) for b in range(num_batches)]

    # Pooled size per sample, then cumulative boundaries.
    final_sizes = np.cumsum([math.ceil(s * ratio) for s in batch_sizes])

    # Map every pooled index in [bottom, top) back to sample i.
    ind_map = {}
    bottom = 0
    for i, top in enumerate(final_sizes):
        for j in range(bottom, top):
            ind_map[j] = i
        bottom = top

    return ind_map

def dim1_sum(t):
    '''Aggregate a (molecules, channels) score tensor by summing over dim 1.'''
    return torch.sum(t, dim=1)


def dim1_L1norm(t):
    '''Aggregate a (molecules, channels) score tensor by its L1 norm over dim 1.'''
    return torch.norm(t, p=1, dim=1)

class PolymerGNNExplainer:
    '''
    Explainer for the PolymerGNN. Uses Grad CAM with Captum implementation.
    '''

    def __init__(self, model: torch.nn.Module, explain_layer = 'fc1',
            pool_ratio = 0.5, use_mono: bool = False):
        '''
        Args:
            model: Trained PolymerGNN to explain.
            explain_layer (str): Attribute name on ``model`` of the layer
                whose (input) attributions Grad-CAM computes.
            pool_ratio (float): Pooling ratio of the model's SAGPool layers;
                needed to map pooled indices back to input nodes.
            use_mono (bool): Whether the model shares one 'sage' layer for
                both acid and glycol inputs.
        '''
        self.model = model
        self.explain_layer = explain_layer
        self.ratio = pool_ratio
        self.use_mono = use_mono
        self.gcam = LayerGradCam(model, getattr(model, explain_layer))
        self.extractor = FeatureExtractor(model, use_mono = self.use_mono)

    def get_attribution(self, 
            batch: Tuple,
            add_test: torch.Tensor,
            mol_rep_agg = dim1_sum):
        '''
        Get explanation for a given sample from the dataset on the model.

        ..note:: Assumes max pooling. Would need to implement another expansion
            to work backwards through another pooling method.

        Args:
            batch: Sample batch, parsed by ``parse_batches``.
            add_test: Additional (non-graph) input features for the sample.
            mol_rep_agg (callable or None): Aggregation applied to each
                molecule's per-channel scores; identity if None.

        Returns:
            dict: 'A' and 'G' hold per-molecule scores; 'add' holds
                attributions for the additional features.
        '''
        # Parse the batches for captum usage
        batches_tup = parse_batches(batch, add_test)
        input_tup = tuple(batches_tup[1:])

        if mol_rep_agg is None:
            mol_rep_agg = lambda x: x

        # Compute the attribution from captum (wrt. the explain layer input)
        attribution = self.gcam.attribute(
            batches_tup[0],
            additional_forward_args = input_tup,
            attribute_to_layer_input = True
        )

        # Get intermediate features in a feedforward step
        features = self.extractor(batches_tup)

        def attr_scores(key = 'A', hc = 32):
            '''Expand pooled attributions for one molecule type ('A' or 'G')
            back to per-molecule scores; ``hc`` is the hidden channel count.'''
            bind = 2 if key == 'A' else -2 # Location of batch vector in batches_tup
            # Glycol channels sit after the hc acid channels in the fc input.
            # (Was hard-coded 32 even though hc existed as a parameter.)
            add_to_bottom = 0 if key == 'A' else hc
            # Map indices to batches
            ind_map = index_to_batch_mapper(batches_tup[bind], ratio = self.ratio)

            # Set which layer to get attributions from
            str_key = '{}sage'.format(key)

            # Max pooling: the per-channel argmax identifies the node each
            # pooled channel value came from.
            feat_argmax = torch.argmax(features[str_key][0], dim = 0)

            # Expand scores backward from the max pooling:
            scores = torch.zeros((len(set(ind_map.values())), hc))
            for j in range(feat_argmax.shape[0]):
                score_ind = ind_map[feat_argmax[j].item()]
                scores[score_ind, j] = attribution[add_to_bottom + j]

            return scores

        # Aggregates molecular representations together in scores:
        scores = {
            'A': mol_rep_agg(attr_scores('A')).detach().clone(),
            'G': mol_rep_agg(attr_scores('G')).detach().clone()
        }

        # Score individual attributes: the additional (non-graph) features
        # occupy the trailing entries of the attribution vector.
        num_add = add_test.shape[0]
        scores['add'] = attribution[-num_add:].detach().clone()

        return scores

    def get_testing_explanation(self,
            dataset,
            test_inds = None,
            add_data_keys = ['Mw', 'AN', 'OHN', '%TMP']):
        '''
        Compute per-sample attributions over a test split and aggregate
        them by acid name, glycol name, and additional-feature key.

        Args:
            dataset: Dataset object from which to extract
            test_inds (list of ints, optional): If given, extracts testing 
                data from the dataset with respect to the indices.
            add_data_keys (list of str): List that should have the same
                length as additional features per sample.

        Returns:
            tuple: (per-sample score dicts, acid-name -> scores,
                glycol-name -> scores, additional-key -> scores)
        '''
        if test_inds is None:
            test_batch, Ytest, add_test = dataset.get_test()
            test_inds = dataset.test_mask
        else:
            test_batch = dataset.make_dataloader_by_mask(test_inds)
            Ytest = np.array(dataset.get_Y_by_mask(test_inds))
            add_test = dataset.get_additional_by_mask(test_inds)

        exp_summary = []

        # Summary tools for acid/glycol scores
        acid_key = {a:[] for a in dataset.acid_names}
        glycol_key = {g:[] for g in dataset.glycol_names}
        additional_key = {a:[] for a in add_data_keys}
        acids, glycols, _, _ = get_AG_info(dataset.data)

        for i in range(Ytest.shape[0]):
            scores = self.get_attribution(test_batch[i], add_test[i], mol_rep_agg=dim1_L1norm)
            Ti = test_inds[i]
            scores['table_ind'] = Ti

            # Single-molecule samples yield a 0-d/1-element score tensor,
            # hence the .item() special case.
            for a in range(len(acids[Ti])):
                Ascore = scores['A'].item() if len(acids[Ti]) == 1 else scores['A'][a].item()
                acid_key[acids[Ti][a]].append(Ascore)
            
            for g in range(len(glycols[Ti])):
                Gscore = scores['G'].item() if len(glycols[Ti]) == 1 else scores['G'][g].item()
                glycol_key[glycols[Ti][g]].append(Gscore)

            # Assign attributions to additional elements (negative indexing
            # reads the trailing num_add entries of scores['add']):
            for j in range(len(add_data_keys)):
                v = scores['add'][j - len(add_data_keys)].item()
                scores[add_data_keys[j]] = v
                additional_key[add_data_keys[j]].append(v)

            exp_summary.append(scores)

        return exp_summary, acid_key, glycol_key, additional_key