Spaces:
Sleeping
Sleeping
Upload 347 files
Browse files- configs/experiment/dti_experiment.yaml +1 -1
- configs/trainer/ddp.yaml +1 -1
- configs/trainer/default.yaml +1 -1
- configs/trainer/gpu.yaml +1 -1
- configs/webserver_inference.yaml +27 -0
- deepscreen/__pycache__/__init__.cpython-311.pyc +0 -0
- deepscreen/__pycache__/predict.cpython-311.pyc +0 -0
- deepscreen/__pycache__/test.cpython-311.pyc +0 -0
- deepscreen/data/__pycache__/dti.cpython-311.pyc +0 -0
- deepscreen/data/dti.py +4 -2
- deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc +0 -0
- deepscreen/data/utils/__pycache__/collator.cpython-311.pyc +0 -0
- deepscreen/data/utils/__pycache__/label.cpython-311.pyc +0 -0
- deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc +0 -0
- deepscreen/gui/app.py +0 -10
- deepscreen/models/__pycache__/dti.cpython-311.pyc +0 -0
- deepscreen/models/dti.py +27 -5
- deepscreen/models/metrics/__pycache__/sensitivity.cpython-311.pyc +0 -0
- deepscreen/predict.py +2 -3
- deepscreen/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- deepscreen/utils/__pycache__/hydra.cpython-311.pyc +0 -0
- deepscreen/utils/__pycache__/instantiators.cpython-311.pyc +0 -0
- deepscreen/utils/__pycache__/lightning.cpython-311.pyc +0 -0
- deepscreen/utils/__pycache__/logging.cpython-311.pyc +0 -0
- deepscreen/utils/__pycache__/rich.cpython-311.pyc +0 -0
- deepscreen/utils/__pycache__/utils.cpython-311.pyc +0 -0
- deepscreen/utils/hydra.py +24 -19
- deepscreen/utils/lightning.py +2 -2
configs/experiment/dti_experiment.yaml
CHANGED
@@ -9,7 +9,7 @@ seed: 12345
|
|
9 |
trainer:
|
10 |
min_epochs: 1
|
11 |
max_epochs: 500
|
12 |
-
precision:
|
13 |
|
14 |
callbacks:
|
15 |
early_stopping:
|
|
|
9 |
trainer:
|
10 |
min_epochs: 1
|
11 |
max_epochs: 500
|
12 |
+
precision: 16-mixed
|
13 |
|
14 |
callbacks:
|
15 |
early_stopping:
|
configs/trainer/ddp.yaml
CHANGED
@@ -7,4 +7,4 @@ accelerator: gpu
|
|
7 |
devices: 4
|
8 |
num_nodes: 1
|
9 |
sync_batchnorm: True
|
10 |
-
precision:
|
|
|
7 |
devices: 4
|
8 |
num_nodes: 1
|
9 |
sync_batchnorm: True
|
10 |
+
precision: 16-mixed
|
configs/trainer/default.yaml
CHANGED
@@ -5,7 +5,7 @@ default_root_dir: ${paths.output_dir}
|
|
5 |
min_epochs: 1
|
6 |
max_epochs: 50
|
7 |
|
8 |
-
precision:
|
9 |
|
10 |
gradient_clip_val: 0.5
|
11 |
gradient_clip_algorithm: norm
|
|
|
5 |
min_epochs: 1
|
6 |
max_epochs: 50
|
7 |
|
8 |
+
precision: 32
|
9 |
|
10 |
gradient_clip_val: 0.5
|
11 |
gradient_clip_algorithm: norm
|
configs/trainer/gpu.yaml
CHANGED
@@ -3,4 +3,4 @@ defaults:
|
|
3 |
|
4 |
accelerator: gpu
|
5 |
devices: 1
|
6 |
-
precision:
|
|
|
3 |
|
4 |
accelerator: gpu
|
5 |
devices: 1
|
6 |
+
precision: 16-mixed
|
configs/webserver_inference.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
defaults:
|
3 |
+
- model: dti_model # fixed for web server version
|
4 |
+
- task: null
|
5 |
+
- data: dti_data # fixed for web server version
|
6 |
+
- callbacks: null
|
7 |
+
- trainer: default
|
8 |
+
- paths: default
|
9 |
+
- extras: null
|
10 |
+
- hydra: default
|
11 |
+
- _self_
|
12 |
+
- preset: null
|
13 |
+
- experiment: null
|
14 |
+
- sweep: null
|
15 |
+
- debug: null
|
16 |
+
- optional local: default
|
17 |
+
|
18 |
+
job_name: "webserver_inference"
|
19 |
+
|
20 |
+
tags: null
|
21 |
+
|
22 |
+
# passing checkpoint path is necessary for prediction
|
23 |
+
ckpt_path: ???
|
24 |
+
|
25 |
+
paths:
|
26 |
+
output_dir: null
|
27 |
+
work_dir: null
|
deepscreen/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/deepscreen/__pycache__/__init__.cpython-311.pyc and b/deepscreen/__pycache__/__init__.cpython-311.pyc differ
|
|
deepscreen/__pycache__/predict.cpython-311.pyc
CHANGED
Binary files a/deepscreen/__pycache__/predict.cpython-311.pyc and b/deepscreen/__pycache__/predict.cpython-311.pyc differ
|
|
deepscreen/__pycache__/test.cpython-311.pyc
CHANGED
Binary files a/deepscreen/__pycache__/test.cpython-311.pyc and b/deepscreen/__pycache__/test.cpython-311.pyc differ
|
|
deepscreen/data/__pycache__/dti.cpython-311.pyc
CHANGED
Binary files a/deepscreen/data/__pycache__/dti.cpython-311.pyc and b/deepscreen/data/__pycache__/dti.cpython-311.pyc differ
|
|
deepscreen/data/dti.py
CHANGED
@@ -150,9 +150,11 @@ class DTIDataset(Dataset):
|
|
150 |
sample = self.df.loc[i]
|
151 |
return {
|
152 |
'N': i,
|
153 |
-
'X1':
|
|
|
154 |
'ID1': sample.get('ID1', sample['X1']),
|
155 |
-
'X2':
|
|
|
156 |
'ID2': sample.get('ID2', sample['X2']),
|
157 |
'Y': sample.get('Y'),
|
158 |
'IDX': sample['IDX'],
|
|
|
150 |
sample = self.df.loc[i]
|
151 |
return {
|
152 |
'N': i,
|
153 |
+
'X1': sample['X1'],
|
154 |
+
'X1^': self.drug_featurizer(sample['X1']),
|
155 |
'ID1': sample.get('ID1', sample['X1']),
|
156 |
+
'X2': sample['X2'],
|
157 |
+
'X2^': self.protein_featurizer(sample['X2']),
|
158 |
'ID2': sample.get('ID2', sample['X2']),
|
159 |
'Y': sample.get('Y'),
|
160 |
'IDX': sample['IDX'],
|
deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc differ
|
|
deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc
CHANGED
Binary files a/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc differ
|
|
deepscreen/data/utils/__pycache__/collator.cpython-311.pyc
CHANGED
Binary files a/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc differ
|
|
deepscreen/data/utils/__pycache__/label.cpython-311.pyc
CHANGED
Binary files a/deepscreen/data/utils/__pycache__/label.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/label.cpython-311.pyc differ
|
|
deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc
CHANGED
Binary files a/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc differ
|
|
deepscreen/gui/app.py
CHANGED
@@ -16,16 +16,6 @@ root = Path.cwd()
|
|
16 |
task_list = [f.stem for f in root.parent.joinpath("configs/task").iterdir() if f.suffix == ".yaml"]
|
17 |
preset_list = [f.stem for f in root.parent.joinpath("configs/preset").iterdir() if f.suffix == ".yaml"]
|
18 |
predictor_list = [f.stem for f in root.parent.joinpath("configs/model/predictor").iterdir() if f.suffix == ".yaml"]
|
19 |
-
drug_encoder_list = [f.stem for f in root.parent.joinpath("configs/model/predictor/drug_encoder").iterdir() if
|
20 |
-
f.suffix == ".yaml"]
|
21 |
-
drug_featurizer_list = [f.stem for f in root.parent.joinpath("configs/data/drug_featurizer").iterdir() if
|
22 |
-
f.suffix == ".yaml"]
|
23 |
-
protein_encoder_list = [f.stem for f in root.parent.joinpath("configs/model/predictor/protein_encoder").iterdir() if
|
24 |
-
f.suffix == ".yaml"]
|
25 |
-
protein_featurizer_list = [f.stem for f in root.parent.joinpath("configs/data/protein_featurizer").iterdir() if
|
26 |
-
f.suffix == ".yaml"]
|
27 |
-
classifier_list = [f.stem for f in root.parent.joinpath("configs/model/predictor/decoder").iterdir() if
|
28 |
-
f.suffix == ".yaml"]
|
29 |
|
30 |
|
31 |
def load_csv(file):
|
|
|
16 |
task_list = [f.stem for f in root.parent.joinpath("configs/task").iterdir() if f.suffix == ".yaml"]
|
17 |
preset_list = [f.stem for f in root.parent.joinpath("configs/preset").iterdir() if f.suffix == ".yaml"]
|
18 |
predictor_list = [f.stem for f in root.parent.joinpath("configs/model/predictor").iterdir() if f.suffix == ".yaml"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
def load_csv(file):
|
deepscreen/models/__pycache__/dti.cpython-311.pyc
CHANGED
Binary files a/deepscreen/models/__pycache__/dti.cpython-311.pyc and b/deepscreen/models/__pycache__/dti.cpython-311.pyc differ
|
|
deepscreen/models/dti.py
CHANGED
@@ -64,7 +64,7 @@ class DTILightningModule(LightningModule):
|
|
64 |
self.forward(dummy_batch)
|
65 |
|
66 |
def forward(self, batch):
|
67 |
-
output = self.predictor(batch['X1'], batch['X2'])
|
68 |
target = batch.get('Y')
|
69 |
indexes = batch.get('IDX')
|
70 |
preds = None
|
@@ -92,7 +92,12 @@ class DTILightningModule(LightningModule):
|
|
92 |
self.train_metrics(preds=preds, target=target, indexes=indexes.long())
|
93 |
self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
|
94 |
|
95 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def on_train_epoch_end(self):
|
98 |
pass
|
@@ -104,6 +109,13 @@ class DTILightningModule(LightningModule):
|
|
104 |
self.val_metrics(preds=preds, target=target, indexes=indexes.long())
|
105 |
self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
def on_validation_epoch_end(self):
|
108 |
pass
|
109 |
|
@@ -115,15 +127,25 @@ class DTILightningModule(LightningModule):
|
|
115 |
self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
|
116 |
|
117 |
# return a dictionary for callbacks like BasePredictionWriter
|
118 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
def on_test_epoch_end(self):
|
121 |
pass
|
122 |
|
123 |
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
124 |
-
preds,
|
125 |
# return a dictionary for callbacks like BasePredictionWriter
|
126 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
def configure_optimizers(self):
|
129 |
optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
|
|
|
64 |
self.forward(dummy_batch)
|
65 |
|
66 |
def forward(self, batch):
|
67 |
+
output = self.predictor(batch['X1^'], batch['X2^'])
|
68 |
target = batch.get('Y')
|
69 |
indexes = batch.get('IDX')
|
70 |
preds = None
|
|
|
92 |
self.train_metrics(preds=preds, target=target, indexes=indexes.long())
|
93 |
self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
|
94 |
|
95 |
+
return {
|
96 |
+
'N': batch['N'],
|
97 |
+
'ID1': batch['ID1'], 'X1': batch['X1'],
|
98 |
+
'ID2': batch['ID2'], 'X2': batch['X2'],
|
99 |
+
'Y^': preds, 'Y': target, 'loss': loss
|
100 |
+
}
|
101 |
|
102 |
def on_train_epoch_end(self):
|
103 |
pass
|
|
|
109 |
self.val_metrics(preds=preds, target=target, indexes=indexes.long())
|
110 |
self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
|
111 |
|
112 |
+
return {
|
113 |
+
'N': batch['N'],
|
114 |
+
'ID1': batch['ID1'], 'X1': batch['X1'],
|
115 |
+
'ID2': batch['ID2'], 'X2': batch['X2'],
|
116 |
+
'Y^': preds, 'Y': target, 'loss': loss
|
117 |
+
}
|
118 |
+
|
119 |
def on_validation_epoch_end(self):
|
120 |
pass
|
121 |
|
|
|
127 |
self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
|
128 |
|
129 |
# return a dictionary for callbacks like BasePredictionWriter
|
130 |
+
return {
|
131 |
+
'N': batch['N'],
|
132 |
+
'ID1': batch['ID1'], 'X1': batch['X1'],
|
133 |
+
'ID2': batch['ID2'], 'X2': batch['X2'],
|
134 |
+
'Y^': preds, 'Y': target, 'loss': loss
|
135 |
+
}
|
136 |
|
137 |
def on_test_epoch_end(self):
|
138 |
pass
|
139 |
|
140 |
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
141 |
+
preds, _, _, _ = self.forward(batch)
|
142 |
# return a dictionary for callbacks like BasePredictionWriter
|
143 |
+
return {
|
144 |
+
'N': batch['N'],
|
145 |
+
'ID1': batch['ID1'], 'X1': batch['X1'],
|
146 |
+
'ID2': batch['ID2'], 'X2': batch['X2'],
|
147 |
+
'Y^': preds
|
148 |
+
}
|
149 |
|
150 |
def configure_optimizers(self):
|
151 |
optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
|
deepscreen/models/metrics/__pycache__/sensitivity.cpython-311.pyc
CHANGED
Binary files a/deepscreen/models/metrics/__pycache__/sensitivity.cpython-311.pyc and b/deepscreen/models/metrics/__pycache__/sensitivity.cpython-311.pyc differ
|
|
deepscreen/predict.py
CHANGED
@@ -34,9 +34,6 @@ def predict(cfg: DictConfig) -> Tuple[list, dict]:
|
|
34 |
Returns:
|
35 |
Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
|
36 |
"""
|
37 |
-
assert cfg.ckpt_path, "Checkpoint path (`ckpt_path`) must be specified for predicting."
|
38 |
-
cfg = checkpoint_rerun_config(cfg)
|
39 |
-
|
40 |
log.info(f"Instantiating data <{cfg.data._target_}>")
|
41 |
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
|
42 |
|
@@ -65,6 +62,8 @@ def predict(cfg: DictConfig) -> Tuple[list, dict]:
|
|
65 |
|
66 |
@hydra.main(version_base="1.3", config_path="../configs", config_name="predict.yaml")
|
67 |
def main(cfg: DictConfig):
|
|
|
|
|
68 |
predictions, _ = predict(cfg)
|
69 |
return predictions
|
70 |
|
|
|
34 |
Returns:
|
35 |
Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
|
36 |
"""
|
|
|
|
|
|
|
37 |
log.info(f"Instantiating data <{cfg.data._target_}>")
|
38 |
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
|
39 |
|
|
|
62 |
|
63 |
@hydra.main(version_base="1.3", config_path="../configs", config_name="predict.yaml")
|
64 |
def main(cfg: DictConfig):
|
65 |
+
assert cfg.ckpt_path, "Checkpoint path (`ckpt_path`) must be specified for predicting."
|
66 |
+
cfg = checkpoint_rerun_config(cfg)
|
67 |
predictions, _ = predict(cfg)
|
68 |
return predictions
|
69 |
|
deepscreen/utils/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/deepscreen/utils/__pycache__/__init__.cpython-311.pyc and b/deepscreen/utils/__pycache__/__init__.cpython-311.pyc differ
|
|
deepscreen/utils/__pycache__/hydra.cpython-311.pyc
CHANGED
Binary files a/deepscreen/utils/__pycache__/hydra.cpython-311.pyc and b/deepscreen/utils/__pycache__/hydra.cpython-311.pyc differ
|
|
deepscreen/utils/__pycache__/instantiators.cpython-311.pyc
CHANGED
Binary files a/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc and b/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc differ
|
|
deepscreen/utils/__pycache__/lightning.cpython-311.pyc
ADDED
Binary file (4.24 kB). View file
|
|
deepscreen/utils/__pycache__/logging.cpython-311.pyc
CHANGED
Binary files a/deepscreen/utils/__pycache__/logging.cpython-311.pyc and b/deepscreen/utils/__pycache__/logging.cpython-311.pyc differ
|
|
deepscreen/utils/__pycache__/rich.cpython-311.pyc
CHANGED
Binary files a/deepscreen/utils/__pycache__/rich.cpython-311.pyc and b/deepscreen/utils/__pycache__/rich.cpython-311.pyc differ
|
|
deepscreen/utils/__pycache__/utils.cpython-311.pyc
CHANGED
Binary files a/deepscreen/utils/__pycache__/utils.cpython-311.pyc and b/deepscreen/utils/__pycache__/utils.cpython-311.pyc differ
|
|
deepscreen/utils/hydra.py
CHANGED
@@ -73,28 +73,32 @@ class CSVExperimentSummary(Callback):
|
|
73 |
override_dict['epoch'] = int(re.search(r'epoch_(\d+)', override_dict['ckpt_path']).group(1))
|
74 |
|
75 |
# Add metrics info
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
else:
|
90 |
-
|
91 |
-
metrics_df =
|
92 |
-
|
93 |
-
metrics_df = metrics_df.assign(**override_dict)
|
94 |
-
metrics_df.index = [0]
|
95 |
|
96 |
# Add extra info from the input batch experiment summary
|
97 |
-
if self.input_experiment_summary is not None:
|
98 |
orig_meta = self.input_experiment_summary[
|
99 |
self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
|
100 |
].head(1)
|
@@ -102,6 +106,7 @@ class CSVExperimentSummary(Callback):
|
|
102 |
metrics_df = metrics_df.combine_first(orig_meta)
|
103 |
|
104 |
summary_df = pd.concat([summary_df, metrics_df])
|
|
|
105 |
# Drop empty columns
|
106 |
summary_df.dropna(inplace=True, axis=1, how='all')
|
107 |
summary_df.to_csv(summary_file_path, index=False, mode='w')
|
|
|
73 |
override_dict['epoch'] = int(re.search(r'epoch_(\d+)', override_dict['ckpt_path']).group(1))
|
74 |
|
75 |
# Add metrics info
|
76 |
+
metrics_df = pd.DataFrame()
|
77 |
+
if config.get('logger'):
|
78 |
+
output_dir = Path(config.hydra.runtime.output_dir).resolve()
|
79 |
+
csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv"
|
80 |
+
if csv_metrics_path.is_file():
|
81 |
+
log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
|
82 |
+
# Use only columns that start with the specified prefix
|
83 |
+
metrics_df = pd.read_csv(csv_metrics_path)
|
84 |
+
# Find rows where any 'test/' column is not null and reset its epoch to the best model epoch
|
85 |
+
test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
|
86 |
+
mask = metrics_df[test_columns].notna().any(axis=1)
|
87 |
+
metrics_df.loc[mask, 'epoch'] = override_dict['epoch']
|
88 |
+
# Group and filter by best epoch
|
89 |
+
metrics_df = metrics_df.groupby('epoch').first()
|
90 |
+
metrics_df = metrics_df[metrics_df.index == override_dict['epoch']]
|
91 |
+
else:
|
92 |
+
log.info(f"No metrics.csv found in {output_dir}")
|
93 |
+
|
94 |
+
if metrics_df.empty:
|
95 |
+
metrics_df = pd.DataFrame(data=override_dict, index=[0])
|
96 |
else:
|
97 |
+
metrics_df = metrics_df.assign(**override_dict)
|
98 |
+
metrics_df.index = [0]
|
|
|
|
|
|
|
99 |
|
100 |
# Add extra info from the input batch experiment summary
|
101 |
+
if self.input_experiment_summary is not None and 'ckpt_path' in metrics_df.columns:
|
102 |
orig_meta = self.input_experiment_summary[
|
103 |
self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
|
104 |
].head(1)
|
|
|
106 |
metrics_df = metrics_df.combine_first(orig_meta)
|
107 |
|
108 |
summary_df = pd.concat([summary_df, metrics_df])
|
109 |
+
|
110 |
# Drop empty columns
|
111 |
summary_df.dropna(inplace=True, axis=1, how='all')
|
112 |
summary_df.to_csv(summary_file_path, index=False, mode='w')
|
deepscreen/utils/lightning.py
CHANGED
@@ -22,14 +22,14 @@ class CSVPredictionWriter(BasePredictionWriter):
|
|
22 |
output_df = self.outputs_to_dataframe(outputs)
|
23 |
output_df.to_csv(self.output_file,
|
24 |
mode='a',
|
25 |
-
|
26 |
header=not self.output_file.is_file())
|
27 |
|
28 |
def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
|
29 |
output_df = pd.concat([self.outputs_to_dataframe(outputs) for outputs in predictions])
|
30 |
output_df.to_csv(self.output_file,
|
31 |
mode='w',
|
32 |
-
|
33 |
header=True)
|
34 |
|
35 |
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int, dataloader_idx: int = 0):
|
|
|
22 |
output_df = self.outputs_to_dataframe(outputs)
|
23 |
output_df.to_csv(self.output_file,
|
24 |
mode='a',
|
25 |
+
index=False,
|
26 |
header=not self.output_file.is_file())
|
27 |
|
28 |
def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
|
29 |
output_df = pd.concat([self.outputs_to_dataframe(outputs) for outputs in predictions])
|
30 |
output_df.to_csv(self.output_file,
|
31 |
mode='w',
|
32 |
+
index=False,
|
33 |
header=True)
|
34 |
|
35 |
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int, dataloader_idx: int = 0):
|