bndl's picture
Upload 115 files
4f5540c
import os

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from torch_geometric.data import Batch, Data

from polymerlearn.utils import make_like_batch
from polymerlearn.utils.losses import evidential_loss
from polymerlearn.utils.train_graphs import check_early_stop
def clone_dict(d):
    '''
    Return a new dict whose values are detached copies of the tensors in ``d``.

    Useful for snapshotting something like ``model.state_dict()``, whose
    values are references to live tensors that later updates would mutate.

    Args:
        d (dict): Mapping from keys to torch.Tensors.

    Returns:
        dict: Same keys as ``d``; each value is ``v.detach().clone()``, i.e.
        a tensor that shares no storage with the original and carries no
        autograd history.
    '''
    # Iterate over (key, value) pairs; iterating the dict directly would
    # yield keys only.
    return {k: v.detach().clone() for k, v in d.items()}
def CV_eval_evidential(
        dataset,
        model_generator: torch.nn.Module,
        optimizer_generator,
        model_generator_kwargs: dict = None,
        optimizer_kwargs: dict = None,
        batch_size = 64,
        verbose = 1,
        epochs = 1000,
        use_val = False,
        get_scores = False,
        val_size = 0.1,
        stop_option = 0,
        early_stop_delay = 100):
    '''
    Run 5-fold cross-validation of an evidential regression model.

    For every fold yielded by ``dataset.Kfold_CV`` a fresh model and optimizer
    are built, trained for up to ``epochs`` epochs on batches drawn from
    ``dataset.get_train_batch`` with ``evidential_loss``, and then evaluated
    on that fold's held-out test set. The ``'gamma'`` entry of the model's
    output dict is used as the point prediction for r2 / MSE / MAE.

    Args:
        dataset: Object providing ``Kfold_CV(folds, val, val_size)``,
            ``get_train_batch(size)`` and, when ``use_val`` is True,
            ``get_validation()``.
        model_generator: Callable invoked as
            ``model_generator(**model_generator_kwargs)`` to build a model.
        optimizer_generator: Callable invoked as
            ``optimizer_generator(model.parameters(), **optimizer_kwargs)``.
        model_generator_kwargs (dict, optional): Extra kwargs for the model.
        optimizer_kwargs (dict, optional): Extra kwargs for the optimizer.
        batch_size (int): Samples requested from ``get_train_batch`` per epoch;
            one optimizer step is taken per sample.
        verbose (int): If 1, print progress every 50 epochs.
        epochs (int): Maximum number of training epochs per fold.
        use_val (bool): Whether to compute a validation loss each epoch.
        get_scores (bool): Selects the return format (see Returns).
        val_size (float): Validation fraction forwarded to ``Kfold_CV``.
        stop_option (int): Option that specifies which method to use for early
            stopping/validation saving. 0 simply performs all epochs for each
            fold. 1 performs all epochs but uses the model with the lowest
            validation loss for evaluation on the test set. 2 additionally
            stops early once ``check_early_stop`` reports that the lowest
            validation loss is at least ``early_stop_delay`` epochs old; the
            lowest-loss model is then loaded and evaluated. Options 1 and 2
            require ``use_val=True``.
        early_stop_delay (int): Patience (in epochs) used by ``stop_option=2``.

    Returns:
        If ``get_scores`` is True, a dict with the fold-averaged ``'r2'``,
        ``'mae'`` and ``'mse'`` plus ``'all_predictions'``, ``'all_y'`` and
        ``'all_reference_inds'``. Otherwise the tuple
        ``(all_predictions, all_y, all_reference_inds)``. Each entry of
        ``all_predictions`` is the raw model output dict for one test sample.

    Raises:
        ValueError: If ``stop_option >= 1`` but ``use_val`` is False, since
            model selection needs a validation loss.
    '''
    if stop_option >= 1 and not use_val:
        raise ValueError('stop_option >= 1 requires use_val=True '
                         '(model selection needs a validation loss)')

    # Avoid shared mutable default arguments.
    if model_generator_kwargs is None:
        model_generator_kwargs = {}
    if optimizer_kwargs is None:
        optimizer_kwargs = {}

    num_folds = 5
    fold_count = 0

    r2_test_per_fold = []
    mse_test_per_fold = []
    mae_test_per_fold = []

    all_predictions = []
    all_y = []
    all_reference_inds = []

    for test_batch, Ytest, add_test, test_inds in \
            dataset.Kfold_CV(folds = num_folds, val = use_val, val_size = val_size):

        # Instantiate fold-level model and optimizer:
        model = model_generator(**model_generator_kwargs)
        optimizer = optimizer_generator(model.parameters(), **optimizer_kwargs)

        fold_count += 1
        loss_list = []

        if stop_option >= 1:
            min_val_loss = 1e10
            min_val_state_dict = None

        for e in range(epochs):

            # Bootstrap batches:
            batch, Y, add_features = dataset.get_train_batch(size = batch_size)
            train_predictions = []
            cum_loss = 0

            for i in range(batch_size):
                # Predictions:
                train_prediction = model(*make_like_batch(batch[i]),
                                         torch.tensor(add_features[i]).float())
                train_predictions.append(train_prediction['gamma'].detach().clone().item())

                # Compute and backprop loss (one optimizer step per sample):
                loss = evidential_loss(torch.tensor([Y[i]]),
                                       output_dict = train_prediction,
                                       coef = 1)
                optimizer.zero_grad()
                loss.backward()
                cum_loss += loss.item()
                optimizer.step()

            # Test on validation:
            if use_val:
                model.eval()
                val_batch, Yval, add_feat_val = dataset.get_validation()
                cum_val_loss = 0
                val_preds = []
                with torch.no_grad():
                    for i in range(Yval.shape[0]):
                        # Cast additional features to float, matching the
                        # train and test paths.
                        pred = model(*make_like_batch(val_batch[i]),
                                     torch.as_tensor(add_feat_val[i]).float())
                        val_preds.append(pred['gamma'].detach().clone().item())
                        cum_val_loss += evidential_loss(Yval[i], pred, coef = 1).item()

                loss_list.append(cum_val_loss)
                model.train()  # Must switch back to train after eval

            if e % 50 == 0 and (verbose == 1):
                print_str = (f'Fold: {fold_count} \t Epoch: {e}, '
                             f'\t Train r2: {r2_score(Y, train_predictions):.4f} '
                             f'\t Train Loss: {cum_loss:.4f}')
                if use_val:
                    print_str += f'\t Val r2: {r2_score(Yval, val_preds):.4f} \t Val Loss: {cum_val_loss:.4f}'
                print(print_str)

            if stop_option >= 1:
                if cum_val_loss < min_val_loss:
                    min_val_loss = cum_val_loss
                    # state_dict() returns references to the live parameter
                    # tensors; clone them so later optimizer steps do not
                    # overwrite the best-so-far snapshot.
                    min_val_state_dict = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                # Check early stop if needed:
                if stop_option == 2:
                    if check_early_stop(loss_list, early_stop_delay) and e > early_stop_delay:
                        break

        # Load the min-val-loss weights even if we didn't break early. If no
        # snapshot exists (e.g. epochs == 0), fall back to the current model.
        if stop_option >= 1 and min_val_state_dict is not None:
            model = model_generator(**model_generator_kwargs)
            model.load_state_dict(min_val_state_dict)

        # Test (inference mode: disables dropout / uses running BN stats):
        model.eval()
        test_preds = []
        with torch.no_grad():
            for i in range(Ytest.shape[0]):
                pred = model(*make_like_batch(test_batch[i]), torch.tensor(add_test[i]).float())
                test_preds.append(pred['gamma'].clone().detach().item())
                all_predictions.append(pred)
                all_y.append(Ytest[i].item())
                all_reference_inds.append(test_inds[i])

        r2_test = r2_score(Ytest.numpy(), test_preds)
        mse_test = mean_squared_error(Ytest.numpy(), test_preds)
        mae_test = mean_absolute_error(Ytest.numpy(), test_preds)
        print(f'Fold: {fold_count} \t Test r2: {r2_test:.4f} \t Test Loss: {mse_test:.4f} \t Test MAE: {mae_test:.4f}')
        r2_test_per_fold.append(r2_test)
        mse_test_per_fold.append(mse_test)
        mae_test_per_fold.append(mae_test)

    print('Final avg. r2: ', np.mean(r2_test_per_fold))
    print('Final avg. MSE:', np.mean(mse_test_per_fold))
    print('Final avg. MAE:', np.mean(mae_test_per_fold))

    big_ret_dict = {
        'r2': np.mean(r2_test_per_fold),
        'mae': np.mean(mae_test_per_fold),
        'mse': np.mean(mse_test_per_fold),
        'all_predictions': all_predictions,
        'all_y': all_y,
        'all_reference_inds': all_reference_inds,
    }

    if get_scores:
        return big_ret_dict

    return all_predictions, all_y, all_reference_inds