import yaml
import argparse
import numpy as np
import torch
import torch.nn as nn
from torcheval.metrics.functional import binary_auroc, binary_auprc
from sklearn.metrics import recall_score, precision_score
import mdtraj as md


def calc_dssp(pdb_file):
    """Calculate DSSP for a PDB file."""
    obj = md.load_pdb(pdb_file)
    dssp = md.compute_dssp(obj, simplified=True).tolist()[0]
    return ''.join(dssp).replace(' ', '.')
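
# Note (added): with simplified=True, compute_dssp assigns one code per residue
# ('H' helix, 'E' strand, 'C' coil); the codes are concatenated into a single
# string and any blank assignments are rewritten as '.' so every position has
# a visible character.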


def dict2namespace(config):
    """Generate namespace given a config file (dictionary)."""
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace


def load_config(path, return_dict=False):
    """Load a config file."""
    with open(path, "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)
    if return_dict:
        return config, config_dict
    return config
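
# Example (illustrative, keys are hypothetical): a YAML file containing
#   model:
#     hidden_dim: 128
#   training:
#     lr: 1.0e-4
# is loaded into nested namespaces, so values are read as
# config.model.hidden_dim and config.training.lr.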


def get_loss(logits, labels, mask):
    """Compute masked loss."""
    loss = nn.BCEWithLogitsLoss(reduction='none', pos_weight=torch.tensor(1.0))(logits, labels)
    masked_loss = loss * mask
    mean_loss = masked_loss.sum() / mask.sum()
    return masked_loss.sum(), mask.sum(), mean_loss
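
# Note (added): `mask` is expected to be a 0/1 tensor with the same shape as
# (or broadcastable to) `logits`; the return values are (sum of masked losses,
# number of valid positions, mean loss over valid positions), so callers can
# aggregate the sums across batches before dividing.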


def get_masked(logits, labels, mask):
    """Mask logits and labels using a given mask."""
    if isinstance(mask, np.ndarray):
        masked_logits = logits[mask.astype(bool)]
        masked_labels = labels[mask.astype(bool)]
    else:
        masked_logits = logits[mask.bool()].view(-1)
        masked_labels = labels[mask.bool()].view(-1)
    return masked_logits, masked_labels


def get_auroc(masked_logits, masked_labels):
    """Evaluate AUROC, AUPRC, and norm AUPRC for given logits, labels."""
    if isinstance(masked_labels, np.ndarray):
        masked_labels = torch.tensor(masked_labels)
    if isinstance(masked_logits, np.ndarray):
        masked_logits = torch.tensor(masked_logits)
    roc = binary_auroc(masked_logits, masked_labels)
    prc = binary_auprc(masked_logits, masked_labels)
    baseline = masked_labels.sum() / len(masked_labels)
    return abs(roc), abs(prc), abs(prc) - baseline
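
# Note (added): `baseline` is the positive-class frequency, i.e. the expected
# AUPRC of a random classifier, so the third return value is AUPRC above that
# baseline (the "norm AUPRC" in the docstring). For example, with 5% positive
# labels an AUPRC of 0.30 yields a normalized value of 0.25.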


def get_pr_metrics(masked_logits, masked_labels, t):
    """Evaluate F1-score, precision, recall for given logits, labels, and threshold."""
    pred_labels = (masked_logits > t).astype(int)
    r = recall_score(masked_labels, pred_labels)
    p = precision_score(masked_labels, pred_labels)
    f1 = 2 * r * p / (r + p + 1e-8)
    n_labels_tot = len(masked_labels)
    n_pos = masked_labels.sum()
    return f1, p, r, n_labels_tot, n_pos
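
# Note (added): this function expects numpy inputs (note the .astype call and
# the sklearn metrics), and the threshold `t` is compared against whatever
# scores are passed in; e.g. t = 0.0 on raw logits corresponds to a
# probability cutoff of 0.5.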


def prob_adjusted(logits, prior=0.05):
    """Adjust output logits by a class prior and map them to probabilities in [0, 1]."""
    adjustment = np.log(prior / (1 - prior))
    prob = torch.sigmoid(logits - adjustment)
    return prob
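
# Note (added): subtracting log(prior / (1 - prior)) shifts the decision
# boundary so that a raw logit equal to logit(prior) maps to an adjusted
# probability of 0.5.


# Illustrative usage sketch (added; not part of the original module). It runs
# the loss and metric helpers end to end on random tensors; the shapes, the
# seed, and the t=0.0 threshold are assumptions chosen only for demonstration.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 50)                    # (batch, n_residues)
    labels = torch.randint(0, 2, (2, 50)).float()  # random binary labels
    mask = torch.ones(2, 50)                       # all positions valid

    loss_sum, n_valid, mean_loss = get_loss(logits, labels, mask)
    masked_logits, masked_labels = get_masked(logits, labels, mask)
    roc, prc, norm_prc = get_auroc(masked_logits, masked_labels)
    f1, p, r, n_tot, n_pos = get_pr_metrics(
        masked_logits.numpy(), masked_labels.numpy().astype(int), t=0.0
    )
    probs = prob_adjusted(masked_logits)
    print(f"mean loss {mean_loss.item():.3f} | AUROC {roc.item():.3f} | "
          f"AUPRC {prc.item():.3f} | F1 {f1:.3f} | P {p:.3f} | R {r:.3f} | "
          f"max adjusted prob {probs.max().item():.3f}")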