private_polymer_compound_prediction / inference_polymers_gnn.py
bndl's picture
Upload 115 files
4f5540c
raw
history blame
2.97 kB
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import torch
import pandas as pd
from polymerlearn.utils import GraphDataset, get_Tg_add
from polymerlearn.models.gnn import PolymerGNN_Tg
from polymerlearn.utils import make_like_batch
def convert_to_graphdataset(df, acid_positions = (0, 12), glycol_positions = (13, 25), structure_dir = './Structures/AG/xyz'):
    """
    Convert an input dataframe into the graph inputs used for Tg prediction.

    For the Tg the model takes the proportions of acid and glycols and the
    log of the Mw (PS) (no other parameters).

    The caller's dataframe is left untouched: the placeholder target column
    is added to an internal copy.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data, one row per sample to predict.
    acid_positions : tuple
        Forwarded to ``GraphDataset`` as ``ac``
        (presumably the column range holding the acid proportions - TODO confirm).
    glycol_positions : tuple
        Forwarded to ``GraphDataset`` as ``gc``
        (presumably the column range holding the glycol proportions - TODO confirm).
    structure_dir : str
        Directory forwarded to ``GraphDataset`` as ``structure_dir``. The
        default is relative to the current working directory.

    Returns
    -------
    tuple
        ``(test_data, add_test, n_predictions)``: the test split and its
        additional features as returned by ``GraphDataset.get_test()``, and
        the number of target rows in that split.
    """
    targets = ["Tg"]

    # Work on a copy: the placeholder column below must not leak into (or
    # overwrite an existing "Tg" column of) the dataframe owned by the caller.
    df = df.copy()

    # Need to pass the target column to do the preprocessing - not sure why it is needed - must be clarified
    for col in targets:
        df[col] = 0

    # Because must specify a test set - this must be changed in the code.
    # The data is duplicated and half of it is requested as the test split.
    # NOTE(review): this only yields each input row exactly once, in order,
    # if GraphDataset splits deterministically without shuffling - confirm.
    df_duplicated = pd.concat([df, df])
    add_duplicated = get_Tg_add(df_duplicated)

    print(f"Inference to be done on data of size {df.shape}")
    graph_data = GraphDataset(
        data = df_duplicated,
        structure_dir = structure_dir,
        Y_target=targets,
        test_size = 0.5,
        add_features = add_duplicated,
        ac = acid_positions,
        gc = glycol_positions
    )

    test_data, Ytest, add_test = graph_data.get_test()
    print(f"After data preprocessing, inference on {len(test_data)}")
    n_predictions = len(Ytest)
    return test_data, add_test, n_predictions
def predict_from_graph(input_graph_data, input_additional_features, n_predictions, model_path="../polymerlearn/data_models/tg_model_test.pth"):
    """
    Load the trained Tg model and predict one value per prepared sample.

    Parameters
    ----------
    input_graph_data : indexable
        Per-sample graph inputs; element ``i`` is passed to ``make_like_batch``.
    input_additional_features : indexable
        Per-sample additional features; element ``i`` is passed to the model
        after the batch-like tuple.
    n_predictions : int
        Number of samples to predict, taken from index 0 upwards.
    model_path : str
        Path to the saved ``state_dict``. The default is relative to the
        current working directory.

    Returns
    -------
    list of float
        The predictions, in sample order.
    """
    model = PolymerGNN_Tg(
        input_feat=6,  # How many input features on each node; don't change this
        # How many intermediate dimensions to use in model.
        # NOTE(review): presumably has to match the saved checkpoint, since
        # the state dict is loaded strictly below - confirm before changing.
        hidden_channels=32,
        # How many additional resin properties to include in the prediction.
        # Presumably the number of features produced by get_Tg_add - TODO confirm.
        num_additional=1,
    )

    # Load the weights onto the CPU: nothing here moves the model or the
    # inputs to a GPU, and load_state_dict copies the values into the model's
    # existing parameters, so mapping the checkpoint to CUDA first would only
    # cost GPU memory without changing the result.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()

    print(f"Prediction done on {n_predictions}")
    with torch.no_grad():
        # Index explicitly (rather than zip) so a too-large n_predictions
        # still fails loudly instead of silently truncating.
        predictions = [
            model(
                *make_like_batch(input_graph_data[i]),
                input_additional_features[i],
            ).item()
            for i in range(n_predictions)
        ]
    return predictions
def predict(df, acid_positions = (0, 12), glycol_positions = (13, 25), model_path="../polymerlearn/data_models/tg_model_test.pth"):
    """
    Run the whole inference pipeline on a dataframe.

    The dataframe is first converted with ``convert_to_graphdataset`` and the
    result is then fed to ``predict_from_graph``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data to predict on.
    acid_positions, glycol_positions : tuple
        Forwarded unchanged to ``convert_to_graphdataset``.
    model_path : str
        Forwarded unchanged to ``predict_from_graph``.

    Returns
    -------
    The value returned by ``predict_from_graph``.
    """
    # Step 1: dataframe -> graph inputs, additional features and sample count.
    graphs, extra_features, sample_count = convert_to_graphdataset(
        df,
        acid_positions=acid_positions,
        glycol_positions=glycol_positions,
    )
    # Step 2: run the trained model over the prepared samples.
    return predict_from_graph(
        graphs,
        extra_features,
        sample_count,
        model_path=model_path,
    )