Spaces:
Running
Running
Update loss.py
Browse files
loss.py
CHANGED
|
@@ -34,7 +34,7 @@ def discriminator_loss(generator, discriminator, mol_graph, adj, annot, batch_si
|
|
| 34 |
return node, edge,d_loss
|
| 35 |
|
| 36 |
|
| 37 |
-
def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,submodel):
|
| 38 |
|
| 39 |
# Compute loss with fake molecules.
|
| 40 |
|
|
@@ -53,7 +53,7 @@ def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty,
|
|
| 53 |
g_edges_hat_sample = torch.max(edge_sample, -1)[1]
|
| 54 |
g_nodes_hat_sample = torch.max(node_sample , -1)[1]
|
| 55 |
|
| 56 |
-
fake_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
|
| 57 |
for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
|
| 58 |
g_loss = prediction_fake
|
| 59 |
# Compute penalty loss.
|
|
@@ -116,7 +116,7 @@ def discriminator2_loss(generator, discriminator, mol_graph, adj, annot, batch_s
|
|
| 116 |
|
| 117 |
return d2_loss
|
| 118 |
|
| 119 |
-
def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,ak1_adj,akt1_annot, submodel):
|
| 120 |
|
| 121 |
# Generate molecules.
|
| 122 |
|
|
@@ -140,7 +140,7 @@ def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty
|
|
| 140 |
g2_loss_fake = - torch.mean(g_tra_logits_fake2)
|
| 141 |
|
| 142 |
# Reward
|
| 143 |
-
fake_mol_g = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
|
| 144 |
for e_, n_ in zip(dr_g_edges_hat_sample, dr_g_nodes_hat_sample)]
|
| 145 |
g2_loss = g2_loss_fake
|
| 146 |
if submodel == "RL":
|
|
|
|
| 34 |
return node, edge,d_loss
|
| 35 |
|
| 36 |
|
| 37 |
+
def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,submodel, dataset_name):
|
| 38 |
|
| 39 |
# Compute loss with fake molecules.
|
| 40 |
|
|
|
|
| 53 |
g_edges_hat_sample = torch.max(edge_sample, -1)[1]
|
| 54 |
g_nodes_hat_sample = torch.max(node_sample , -1)[1]
|
| 55 |
|
| 56 |
+
fake_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name)
|
| 57 |
for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
|
| 58 |
g_loss = prediction_fake
|
| 59 |
# Compute penalty loss.
|
|
|
|
| 116 |
|
| 117 |
return d2_loss
|
| 118 |
|
| 119 |
+
def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,ak1_adj,akt1_annot, submodel, drugs_name):
|
| 120 |
|
| 121 |
# Generate molecules.
|
| 122 |
|
|
|
|
| 140 |
g2_loss_fake = - torch.mean(g_tra_logits_fake2)
|
| 141 |
|
| 142 |
# Reward
|
| 143 |
+
fake_mol_g = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=drugs_name)
|
| 144 |
for e_, n_ in zip(dr_g_edges_hat_sample, dr_g_nodes_hat_sample)]
|
| 145 |
g2_loss = g2_loss_fake
|
| 146 |
if submodel == "RL":
|