Update inference_polymers_gnn.py
Browse files- inference_polymers_gnn.py +91 -51
inference_polymers_gnn.py
CHANGED
@@ -5,69 +5,109 @@ from pathlib import Path
|
|
5 |
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
6 |
import torch
|
7 |
import pandas as pd
|
8 |
-
from polymerlearn.utils import GraphDataset, get_Tg_add
|
9 |
-
from polymerlearn.models.gnn import PolymerGNN_Tg
|
10 |
from polymerlearn.utils import make_like_batch
|
11 |
|
12 |
|
13 |
-
def convert_to_graphdataset(df, acid_positions = (0, 12), glycol_positions = (13, 25)):
|
14 |
-
|
15 |
-
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
# Need to pass the target column to do the preprocessing - not sure why it is needed - must be clarified
|
21 |
for col in targets:
|
22 |
df[col] = 0
|
23 |
df_duplicated = pd.concat([df, df]) # Because must specify a test set - this must be changed in the code
|
24 |
-
|
|
|
|
|
25 |
print(f"Inference to be done on data of size {df.shape}")
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
return
|
66 |
|
67 |
|
68 |
-
def predict(df, acid_positions = (0, 12), glycol_positions = (13, 25), model_path="../polymerlearn/data_models/tg_model_test.pth"):
|
69 |
-
|
70 |
|
71 |
-
|
72 |
|
73 |
-
|
|
|
5 |
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
6 |
import torch
|
7 |
import pandas as pd
|
8 |
+
from polymerlearn.utils import GraphDataset, get_Tg_add, get_IV_add
|
9 |
+
from polymerlearn.models.gnn import PolymerGNN_Tg, PolymerGNN_IV
|
10 |
from polymerlearn.utils import make_like_batch
|
11 |
|
12 |
|
13 |
+
# def convert_to_graphdataset(df, acid_positions = (0, 12), glycol_positions = (13, 25)):
|
14 |
+
# """
|
15 |
+
# Takes a dataframe with the input data and converts it to a graph dataset for the prediction
|
16 |
|
17 |
+
# For the Tg the model takes the proportions of acid and glycols and the log of the Mw (PS) (no other parameters)
|
18 |
+
# """
|
19 |
+
# targets = ["Tg", "IV"]
|
20 |
+
# # Need to pass the target column to do the preprocessing - not sure why it is needed - must be clarified
|
21 |
+
# for col in targets:
|
22 |
+
# df[col] = 0
|
23 |
+
# df_duplicated = pd.concat([df, df]) # Because must specify a test set - this must be changed in the code
|
24 |
+
# add_duplicated = get_Tg_add(df_duplicated)
|
25 |
+
# print(f"Inference to be done on data of size {df.shape}")
|
26 |
+
|
27 |
+
# graph_data = GraphDataset(
|
28 |
+
# data = df_duplicated,
|
29 |
+
# structure_dir = './Structures/AG/xyz',
|
30 |
+
# Y_target=targets,
|
31 |
+
# test_size = 0.5,
|
32 |
+
# add_features = add_duplicated,
|
33 |
+
# ac = acid_positions,
|
34 |
+
# gc = glycol_positions
|
35 |
+
# )
|
36 |
+
|
37 |
+
# test_data, Ytest, add_test = graph_data.get_test()
|
38 |
+
# print(f"After data preprocessing, inference on {len(test_data)}")
|
39 |
+
# n_predictions = len(Ytest)
|
40 |
+
# return test_data, add_test, n_predictions
|
41 |
+
|
42 |
+
|
43 |
+
def predict(df, model_path="../polymerlearn/data_models/", acid_positions = (0, 12), glycol_positions = (13, 25)):
|
44 |
+
|
45 |
+
targets = ["Tg", "IV"]
|
46 |
# Need to pass the target column to do the preprocessing - not sure why it is needed - must be clarified
|
47 |
for col in targets:
|
48 |
df[col] = 0
|
49 |
df_duplicated = pd.concat([df, df]) # Because must specify a test set - this must be changed in the code
|
50 |
+
|
51 |
+
add_features = {"Tg": get_Tg_add(df_duplicated), "IV": get_IV_add(df_duplicated)}
|
52 |
+
|
53 |
print(f"Inference to be done on data of size {df.shape}")
|
54 |
|
55 |
+
print("Tg", add_features["Tg"].shape)
|
56 |
+
print("Tg", add_features["IV"].shape)
|
57 |
+
|
58 |
+
pred_all = []
|
59 |
+
|
60 |
+
for pred in targets:
|
61 |
+
graph_data = GraphDataset(
|
62 |
+
data = df_duplicated,
|
63 |
+
structure_dir = './Structures/AG/xyz',
|
64 |
+
Y_target=targets,
|
65 |
+
test_size = 0.5,
|
66 |
+
add_features = add_features[pred],
|
67 |
+
ac = acid_positions,
|
68 |
+
gc = glycol_positions
|
69 |
+
)
|
70 |
+
|
71 |
+
test_data, Ytest, add_test = graph_data.get_test()
|
72 |
+
print(f"After data preprocessing, inference on {len(test_data)}")
|
73 |
+
n_predictions = len(Ytest)
|
74 |
+
|
75 |
+
if pred == "Tg":
|
76 |
+
model = PolymerGNN_Tg(
|
77 |
+
input_feat= 6, # How many input features on each node; don't change this
|
78 |
+
hidden_channels= 32, # How many intermediate dimensions to use in model
|
79 |
+
# Can change this ^^
|
80 |
+
num_additional= add_features[pred].shape[1] # How many additional resin properties to include in the prediction
|
81 |
+
# Corresponds to the number in get_IV_add
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
model = PolymerGNN_IV(
|
85 |
+
input_feat= 6, # How many input features on each node; don't change this
|
86 |
+
hidden_channels= 32, # How many intermediate dimensions to use in model
|
87 |
+
# Can change this ^^
|
88 |
+
num_additional= add_features[pred].shape[1] # How many additional resin properties to include in the prediction
|
89 |
+
# Corresponds to the number in get_IV_add
|
90 |
+
)
|
91 |
+
|
92 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
93 |
+
state = torch.load(os.path.join(model_path, f"{pred}_model.pth"), map_location=device)
|
94 |
+
model.load_state_dict(state)
|
95 |
+
model.eval()
|
96 |
+
predictions = []
|
97 |
+
print(f"Prediction done on {n_predictions}")
|
98 |
+
with torch.no_grad():
|
99 |
+
for i in range(n_predictions):
|
100 |
+
batch_like_tup = make_like_batch(test_data[i])
|
101 |
+
pred = model(*batch_like_tup, add_test[i]).item()
|
102 |
+
predictions.append(pred)
|
103 |
+
pred_all.append(predictions)
|
104 |
|
105 |
+
return pred_all
|
106 |
|
107 |
|
108 |
+
# def predict(df, acid_positions = (0, 12), glycol_positions = (13, 25), model_path="../polymerlearn/data_models/tg_model_test.pth"):
|
109 |
+
# test_data, add_test, n_predictions = convert_to_graphdataset(df, acid_positions=acid_positions, glycol_positions=glycol_positions)
|
110 |
|
111 |
+
# predictions = predict_from_graph(test_data, add_test, n_predictions, model_path=model_path)
|
112 |
|
113 |
+
# return predictions
|