libokj commited on
Commit
5eeb7c8
·
1 Parent(s): be4442c

Update deepscreen/models/dti.py

Browse files
Files changed (1) hide show
  1. deepscreen/models/dti.py +40 -24
deepscreen/models/dti.py CHANGED
@@ -17,6 +17,8 @@ class DTILightningModule(LightningModule):
17
  model: a fully initialized instance of class torch.nn.Module
18
  metrics: a list of fully initialized instances of class torchmetrics.Metric
19
  """
 
 
20
  def __init__(
21
  self,
22
  optimizer: optim.Optimizer,
@@ -58,13 +60,10 @@ class DTILightningModule(LightningModule):
58
  case 'predict':
59
  dataloader = self.trainer.datamodule.predict_dataloader()
60
 
61
-
62
  # for key, value in dummy_batch.items():
63
  # if isinstance(value, Tensor):
64
  # dummy_batch[key] = value.to(self.device)
65
 
66
-
67
-
68
  def forward(self, batch):
69
  output = self.predictor(batch['X1^'], batch['X2^'])
70
  target = batch.get('Y')
@@ -94,13 +93,18 @@ class DTILightningModule(LightningModule):
94
  self.train_metrics(preds=preds, target=target, indexes=indexes.long())
95
  self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
96
 
97
- return {
98
- 'N': batch['N'],
99
- 'ID1': batch['ID1'], 'X1': batch['X1'],
100
- 'ID2': batch['ID2'], 'X2': batch['X2'],
101
- 'Y^': preds, 'Y': target, 'loss': loss
102
  }
103
 
 
 
 
 
 
 
104
  def on_train_epoch_end(self):
105
  pass
106
 
@@ -111,13 +115,18 @@ class DTILightningModule(LightningModule):
111
  self.val_metrics(preds=preds, target=target, indexes=indexes.long())
112
  self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
113
 
114
- return {
115
- 'N': batch['N'],
116
- 'ID1': batch['ID1'], 'X1': batch['X1'],
117
- 'ID2': batch['ID2'], 'X2': batch['X2'],
118
- 'Y^': preds, 'Y': target, 'loss': loss
119
  }
120
 
 
 
 
 
 
 
121
  def on_validation_epoch_end(self):
122
  pass
123
 
@@ -128,27 +137,34 @@ class DTILightningModule(LightningModule):
128
  self.test_metrics(preds=preds, target=target, indexes=indexes.long())
129
  self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
130
 
131
- # return a dictionary for callbacks like BasePredictionWriter
132
- return {
133
- 'N': batch['N'],
134
- 'ID1': batch['ID1'], 'X1': batch['X1'],
135
- 'ID2': batch['ID2'], 'X2': batch['X2'],
136
- 'Y^': preds, 'Y': target, 'loss': loss
137
  }
138
 
 
 
 
 
 
 
139
  def on_test_epoch_end(self):
140
  pass
141
 
142
  def predict_step(self, batch, batch_idx, dataloader_idx=0):
143
  preds, _, _, _ = self.forward(batch)
144
  # return a dictionary for callbacks like BasePredictionWriter
145
- return {
146
- 'N': batch['N'],
147
- 'ID1': batch['ID1'], 'X1': batch['X1'],
148
- 'ID2': batch['ID2'], 'X2': batch['X2'],
149
- 'Y^': preds
150
  }
151
 
 
 
 
 
 
 
152
  def configure_optimizers(self):
153
  optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
154
  if self.hparams.get('scheduler'):
 
17
  model: a fully initialized instance of class torch.nn.Module
18
  metrics: a list of fully initialized instances of class torchmetrics.Metric
19
  """
20
+ extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N']
21
+
22
  def __init__(
23
  self,
24
  optimizer: optim.Optimizer,
 
60
  case 'predict':
61
  dataloader = self.trainer.datamodule.predict_dataloader()
62
 
 
63
  # for key, value in dummy_batch.items():
64
  # if isinstance(value, Tensor):
65
  # dummy_batch[key] = value.to(self.device)
66
 
 
 
67
  def forward(self, batch):
68
  output = self.predictor(batch['X1^'], batch['X2^'])
69
  target = batch.get('Y')
 
93
  self.train_metrics(preds=preds, target=target, indexes=indexes.long())
94
  self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
95
 
96
+ return_dict = {
97
+ 'Y^': preds,
98
+ 'Y': target,
99
+ 'loss': loss
 
100
  }
101
 
102
+ for key in self.extra_return_keys:
103
+ if key in batch:
104
+ return_dict[key] = batch[key]
105
+
106
+ return return_dict
107
+
108
  def on_train_epoch_end(self):
109
  pass
110
 
 
115
  self.val_metrics(preds=preds, target=target, indexes=indexes.long())
116
  self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
117
 
118
+ return_dict = {
119
+ 'Y^': preds,
120
+ 'Y': target,
121
+ 'loss': loss
 
122
  }
123
 
124
+ for key in self.extra_return_keys:
125
+ if key in batch:
126
+ return_dict[key] = batch[key]
127
+
128
+ return return_dict
129
+
130
  def on_validation_epoch_end(self):
131
  pass
132
 
 
137
  self.test_metrics(preds=preds, target=target, indexes=indexes.long())
138
  self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
139
 
140
+ return_dict = {
141
+ 'Y^': preds,
142
+ 'Y': target,
143
+ 'loss': loss
 
 
144
  }
145
 
146
+ for key in self.extra_return_keys:
147
+ if key in batch:
148
+ return_dict[key] = batch[key]
149
+
150
+ return return_dict
151
+
152
  def on_test_epoch_end(self):
153
  pass
154
 
155
  def predict_step(self, batch, batch_idx, dataloader_idx=0):
156
  preds, _, _, _ = self.forward(batch)
157
  # return a dictionary for callbacks like BasePredictionWriter
158
+ return_dict = {
159
+ 'Y^': preds,
 
 
 
160
  }
161
 
162
+ for key in self.extra_return_keys:
163
+ if key in batch:
164
+ return_dict[key] = batch[key]
165
+
166
+ return return_dict
167
+
168
  def configure_optimizers(self):
169
  optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
170
  if self.hparams.get('scheduler'):