import os
import sys
from pathlib import Path

# Make the repository root importable when this script is run from its own directory.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

import numpy as np
import pandas as pd
import torch

from polymerlearn.utils import GraphDataset, get_Tg_add, get_IV_add, make_like_batch
from polymerlearn.models.gnn import PolymerGNN_Tg, PolymerGNN_IV
# def convert_to_graphdataset(df, acid_positions=(0, 12), glycol_positions=(13, 25)):
#     """
#     Takes a dataframe with the input data and converts it to a graph dataset for the prediction.
#     For Tg the model takes the proportions of acids and glycols and the log of the Mw (PS) (no other parameters).
#     """
#     targets = ["Tg", "IV"]
#     # Need to pass the target column to do the preprocessing - not sure why it is needed - must be clarified
#     for col in targets:
#         df[col] = 0
#     df_duplicated = pd.concat([df, df])  # Because must specify a test set - this must be changed in the code
#     add_duplicated = get_Tg_add(df_duplicated)
#     print(f"Inference to be done on data of size {df.shape}")
#     graph_data = GraphDataset(
#         data=df_duplicated,
#         structure_dir='./Structures/AG/xyz',
#         Y_target=targets,
#         test_size=0.5,
#         add_features=add_duplicated,
#         ac=acid_positions,
#         gc=glycol_positions
#     )
#     test_data, Ytest, add_test = graph_data.get_test()
#     print(f"After data preprocessing, inference on {len(test_data)}")
#     n_predictions = len(Ytest)
#     return test_data, add_test, n_predictions
def predict(df, model_path="../polymerlearn/data_models/", acid_positions=(0, 12), glycol_positions=(13, 25)):
    """Run Tg and IV inference on ``df`` with the pretrained PolymerGNN models.

    For Tg the model uses the acid/glycol proportions plus log(Mw) (PS); IV uses its own
    additional resin properties (see get_Tg_add / get_IV_add). Returns a list
    [Tg_predictions, IV_predictions].
    """
    targets = ["Tg", "IV"]
    # GraphDataset's preprocessing expects the target columns to be present,
    # so add dummy all-zero columns; their values are never used at inference time.
    for col in targets:
        df[col] = 0
    # GraphDataset has no inference-only mode and always requires a test split,
    # so duplicate the rows and use test_size=0.5. This workaround should be
    # removed once the dataset code supports pure inference.
    df_duplicated = pd.concat([df, df])
    add_features = {"Tg": get_Tg_add(df_duplicated), "IV": get_IV_add(df_duplicated)}
    print(f"Inference to be done on data of size {df.shape}")
    print("Tg additional features:", add_features["Tg"].shape)
    print("IV additional features:", add_features["IV"].shape)

    pred_all = []
    for target in targets:
        graph_data = GraphDataset(
            data=df_duplicated,
            structure_dir='./Structures/AG/xyz',
            Y_target=targets,
            test_size=0.5,
            add_features=add_features[target],
            ac=acid_positions,
            gc=glycol_positions
        )
        test_data, Ytest, add_test = graph_data.get_test()
        print(f"After data preprocessing, inference on {len(test_data)} samples")
        n_predictions = len(Ytest)

        if target == "Tg":
            model = PolymerGNN_Tg(
                input_feat=6,        # number of input features per node; fixed by the data
                hidden_channels=32,  # intermediate embedding size; tunable
                num_additional=add_features[target].shape[1]  # number of additional resin
                                                              # properties, as produced by get_Tg_add
            )
        else:
            model = PolymerGNN_IV(
                input_feat=6,        # number of input features per node; fixed by the data
                hidden_channels=32,  # intermediate embedding size; tunable
                num_additional=add_features[target].shape[1]  # number of additional resin
                                                              # properties, as produced by get_IV_add
            )

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        state = torch.load(os.path.join(model_path, f"{target}_model.pth"), map_location=device)
        model.load_state_dict(state)
        model.eval()

        print(f"Running {target} prediction on {n_predictions} samples")
        predictions = []
        with torch.no_grad():
            for i in range(n_predictions):
                batch_like_tup = make_like_batch(test_data[i])
                value = np.round(model(*batch_like_tup, add_test[i]).item(), 1)
                predictions.append(value)
        pred_all.append(predictions)
    return pred_all
# def predict(df, acid_positions=(0, 12), glycol_positions=(13, 25), model_path="../polymerlearn/data_models/tg_model_test.pth"):
#     test_data, add_test, n_predictions = convert_to_graphdataset(df, acid_positions=acid_positions, glycol_positions=glycol_positions)
#     predictions = predict_from_graph(test_data, add_test, n_predictions, model_path=model_path)
#     return predictions
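
# Minimal usage sketch (an assumption, not part of the original script): it presumes the
# input CSV has the same column layout as the training data -- acid fractions in the
# acid_positions columns, glycol fractions in the glycol_positions columns, plus the
# resin-property columns read by get_Tg_add / get_IV_add. The file name
# "resins_to_predict.csv" is a placeholder.
if __name__ == "__main__":
    input_csv = sys.argv[1] if len(sys.argv) > 1 else "resins_to_predict.csv"
    resin_df = pd.read_csv(input_csv)
    tg_predictions, iv_predictions = predict(resin_df)
    print("Tg predictions:", tg_predictions)
    print("IV predictions:", iv_predictions)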