File size: 6,289 Bytes
4f5540c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
import torch
import numpy as np
import pandas as pd

from torch_geometric.data import Batch, Data
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from polymerlearn.utils import make_like_batch
from polymerlearn.utils.train_graphs import check_early_stop\

from polymerlearn.utils.losses import evidential_loss


def clone_dict(d):
    '''
    Return a new dict with every value of ``d`` detached and converted
    to a plain Python scalar.

    Args:
        d (dict): Maps keys to single-element ``torch.Tensor`` values.

    Returns:
        dict: Same keys as ``d``; each value is ``v.detach().clone().item()``.
    '''
    # Bug fix: iterate d.items(), not the dict itself — iterating a dict
    # yields only keys, so the original ``for k, v in d`` failed at runtime.
    # Also avoid shadowing the function's own name with the local result.
    return {k: v.detach().clone().item() for k, v in d.items()}

def CV_eval_evidential(
        dataset,
        model_generator: torch.nn.Module,
        optimizer_generator,
        model_generator_kwargs: dict = None,
        optimizer_kwargs: dict = None,
        batch_size = 64,
        verbose = 1,
        epochs = 1000,
        use_val = False,
        get_scores = False,
        val_size = 0.1,
        stop_option = 0,
        early_stop_delay = 100):
    '''
    5-fold cross-validated training and evaluation of an evidential model.

    For each fold a fresh model/optimizer pair is built, trained for up to
    ``epochs`` epochs on bootstrap batches, optionally validated each epoch,
    and finally scored (r2 / MSE / MAE) on the held-out test fold.

    Args:
        dataset: Project dataset object exposing ``Kfold_CV``,
            ``get_train_batch`` and ``get_validation``.
        model_generator: Callable building a fresh model per fold.
        optimizer_generator: Callable ``(params, **optimizer_kwargs) -> optimizer``.
        model_generator_kwargs (dict, optional): Kwargs for ``model_generator``.
        optimizer_kwargs (dict, optional): Kwargs for ``optimizer_generator``.
        batch_size (int): Bootstrap samples drawn per epoch.
        verbose (int): If 1, print progress every 50 epochs.
        epochs (int): Max training epochs per fold.
        use_val (bool): Evaluate on a validation split each epoch. Forced on
            when ``stop_option >= 1`` (those options need validation losses).
        get_scores (bool): If True, return a summary dict instead of the
            per-sample tuple.
        val_size (float): Fraction of training data used for validation.
        stop_option (int): Option that specifies which method to use for early
            stopping/validation saving. 0 simply performs all epochs for each fold.
            1 performs all epochs but uses model with highest validation score for
            evaluation on test set. 2 stops early if the validation loss was at least
            `early_stop_delay` epochs ago; it loads that trial's model and evaluates
            on it.
        early_stop_delay (int): Patience used by ``stop_option == 2``.

    Returns:
        ``(all_predictions, all_y, all_reference_inds)`` — per-sample test
        outputs over all folds — or, when ``get_scores`` is True, a dict with
        averaged ``'r2'``/``'mae'`` plus those three lists.
    '''
    # Avoid the shared-mutable-default-argument pitfall.
    model_generator_kwargs = {} if model_generator_kwargs is None else model_generator_kwargs
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs

    # Bug fix: stop_option >= 1 reads cum_val_loss, which is only computed
    # when use_val is True. Without this, use_val=False + stop_option>=1
    # raised a NameError on the first epoch.
    if stop_option >= 1:
        use_val = True

    num_folds = 5
    fold_count = 0

    r2_test_per_fold = []
    mse_test_per_fold = []
    mae_test_per_fold = []

    all_predictions = []
    all_y = []
    all_reference_inds = []

    for test_batch, Ytest, add_test, test_inds in \
            dataset.Kfold_CV(folds = num_folds, val = use_val, val_size = val_size):

        # Instantiate fold-level model and optimizer:
        model = model_generator(**model_generator_kwargs)
        optimizer = optimizer_generator(model.parameters(), **optimizer_kwargs)

        fold_count += 1
        loss_list = []

        if stop_option >= 1:
            min_val_loss = 1e10
            min_val_state_dict = None

        for e in range(epochs):

            # Bootstrap batches:
            batch, Y, add_features = dataset.get_train_batch(size = batch_size)

            train_predictions = []
            cum_loss = 0

            for i in range(batch_size):

                # Predictions (per-sample; model returns an evidential
                # output dict with at least a 'gamma' mean prediction):
                train_prediction = model(*make_like_batch(batch[i]), torch.tensor(add_features[i]).float())
                train_predictions.append(train_prediction['gamma'].detach().clone().item())

                # Compute and backprop loss
                loss = evidential_loss(torch.tensor([Y[i]]),
                    output_dict=train_prediction,
                    coef = 1)
                optimizer.zero_grad()
                loss.backward()
                cum_loss += loss.item()
                optimizer.step()

            # Test on validation:
            if use_val:
                model.eval()
                val_batch, Yval, add_feat_val = dataset.get_validation()
                cum_val_loss = 0
                val_preds = []
                with torch.no_grad():
                    for i in range(Yval.shape[0]):
                        pred = model(*make_like_batch(val_batch[i]), add_feat_val[i])
                        val_preds.append(pred['gamma'].detach().clone().item())
                        cum_val_loss += evidential_loss(Yval[i], pred, coef = 1).item()

                loss_list.append(cum_val_loss)
                model.train() # Must switch back to train after eval

            if e % 50 == 0 and (verbose == 1):
                print_str = f'Fold: {fold_count} \t Epoch: {e}, \
                    \t Train r2: {r2_score(Y, train_predictions):.4f} \t Train Loss: {cum_loss:.4f}'
                if use_val:
                    print_str += f'\t Val r2: {r2_score(Yval, val_preds):.4f} \t Val Loss: {cum_val_loss:.4f}'
                print(print_str)

            if stop_option >= 1:
                if cum_val_loss < min_val_loss:
                    min_val_loss = cum_val_loss
                    # Bug fix: state_dict() returns live references to the
                    # parameter tensors, so storing it directly meant the
                    # "best" snapshot kept mutating with further training.
                    # Clone every tensor to take a real snapshot.
                    min_val_state_dict = {k: v.detach().clone()
                                          for k, v in model.state_dict().items()}

            # Check early stop if needed:
            if stop_option == 2:
                # Check criteria:
                if check_early_stop(loss_list, early_stop_delay) and e > early_stop_delay:
                    break


        if stop_option >= 1: # Loads the min val loss state dict even if we didn't break
            # Load in the model with min val loss
            model = model_generator(**model_generator_kwargs)
            model.load_state_dict(min_val_state_dict)

        # Test (eval mode for consistency with the validation pass):
        model.eval()
        test_preds = []
        with torch.no_grad():
            for i in range(Ytest.shape[0]):
                pred = model(*make_like_batch(test_batch[i]), torch.tensor(add_test[i]).float())
                test_preds.append(pred['gamma'].clone().detach().item())
                all_predictions.append(pred)
                all_y.append(Ytest[i].item())
                all_reference_inds.append(test_inds[i])

        r2_test = r2_score(Ytest.numpy(), test_preds)
        mse_test = mean_squared_error(Ytest.numpy(), test_preds)
        mae_test = mean_absolute_error(Ytest.numpy(), test_preds)

        print(f'Fold: {fold_count} \t Test r2: {r2_test:.4f} \t Test Loss: {mse_test:.4f} \t Test MAE: {mae_test:.4f}')

        r2_test_per_fold.append(r2_test)
        mse_test_per_fold.append(mse_test)
        mae_test_per_fold.append(mae_test)

    print('Final avg. r2: ', np.mean(r2_test_per_fold))
    print('Final avg. MSE:', np.mean(mse_test_per_fold))
    print('Final avg. MAE:', np.mean(mae_test_per_fold))

    r2_avg = np.mean(r2_test_per_fold)
    mae_avg = np.mean(mae_test_per_fold)

    big_ret_dict = {
        'r2': r2_avg,
        'mae': mae_avg,
        'all_predictions': all_predictions,
        'all_y': all_y,
        'all_reference_inds': all_reference_inds,
    }

    if get_scores:
        return big_ret_dict

    return all_predictions, all_y, all_reference_inds