libokj commited on
Commit
ca1d737
·
1 Parent(s): ca59d96

Upload dti.py

Browse files
Files changed (1) hide show
  1. deepscreen/models/dti.py +7 -7
deepscreen/models/dti.py CHANGED
@@ -1,5 +1,5 @@
1
  from functools import partial
2
- from typing import Optional, Sequence, Dict
3
 
4
  from torch import nn, optim, Tensor
5
  from lightning import LightningModule
@@ -53,12 +53,12 @@ class DTILightningModule(LightningModule):
53
  dataloader = self.trainer.datamodule.train_dataloader()
54
  dummy_batch = next(iter(dataloader))
55
  self.forward(dummy_batch)
56
- case 'validate':
57
- dataloader = self.trainer.datamodule.val_dataloader()
58
- case 'test':
59
- dataloader = self.trainer.datamodule.test_dataloader()
60
- case 'predict':
61
- dataloader = self.trainer.datamodule.predict_dataloader()
62
 
63
  # for key, value in dummy_batch.items():
64
  # if isinstance(value, Tensor):
 
1
  from functools import partial
2
+ from typing import Optional, Sequence, Dict, Any
3
 
4
  from torch import nn, optim, Tensor
5
  from lightning import LightningModule
 
53
  dataloader = self.trainer.datamodule.train_dataloader()
54
  dummy_batch = next(iter(dataloader))
55
  self.forward(dummy_batch)
56
+ # case 'validate':
57
+ # dataloader = self.trainer.datamodule.val_dataloader()
58
+ # case 'test':
59
+ # dataloader = self.trainer.datamodule.test_dataloader()
60
+ # case 'predict':
61
+ # dataloader = self.trainer.datamodule.predict_dataloader()
62
 
63
  # for key, value in dummy_batch.items():
64
  # if isinstance(value, Tensor):