Spaces:
Sleeping
Sleeping
Upload dti.py
Browse files- deepscreen/models/dti.py +7 -7
deepscreen/models/dti.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from functools import partial
|
2 |
-
from typing import Optional, Sequence, Dict
|
3 |
|
4 |
from torch import nn, optim, Tensor
|
5 |
from lightning import LightningModule
|
@@ -53,12 +53,12 @@ class DTILightningModule(LightningModule):
|
|
53 |
dataloader = self.trainer.datamodule.train_dataloader()
|
54 |
dummy_batch = next(iter(dataloader))
|
55 |
self.forward(dummy_batch)
|
56 |
-
case 'validate':
|
57 |
-
|
58 |
-
case 'test':
|
59 |
-
|
60 |
-
case 'predict':
|
61 |
-
|
62 |
|
63 |
# for key, value in dummy_batch.items():
|
64 |
# if isinstance(value, Tensor):
|
|
|
1 |
from functools import partial
|
2 |
+
from typing import Optional, Sequence, Dict, Any
|
3 |
|
4 |
from torch import nn, optim, Tensor
|
5 |
from lightning import LightningModule
|
|
|
53 |
dataloader = self.trainer.datamodule.train_dataloader()
|
54 |
dummy_batch = next(iter(dataloader))
|
55 |
self.forward(dummy_batch)
|
56 |
+
# case 'validate':
|
57 |
+
# dataloader = self.trainer.datamodule.val_dataloader()
|
58 |
+
# case 'test':
|
59 |
+
# dataloader = self.trainer.datamodule.test_dataloader()
|
60 |
+
# case 'predict':
|
61 |
+
# dataloader = self.trainer.datamodule.predict_dataloader()
|
62 |
|
63 |
# for key, value in dummy_batch.items():
|
64 |
# if isinstance(value, Tensor):
|