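"""Factory helpers that construct datasets, dataloaders, models, losses,
optimizers, schedulers, and tasks from an attribute-style config object."""
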
import re

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# `data`, `losses`, `optim`, and `tasks` are assumed to be sibling modules of
# this package (alongside `models`): they provide the project-specific
# transforms, datasets, samplers, losses, optimizers/schedulers, and tasks
# referenced by the builders below.
from . import data, losses, models, optim, tasks


def get_name_and_params(base):
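    """Return the `name` and `params` entries of a config node.

    `params` may be missing or None, in which case an empty dict is returned.
    """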
    name = getattr(base, 'name')
    params = getattr(base, 'params') or {}
    return name, params


def get_transform(base, transform, mode=None):
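    """Instantiate the transform named by `transform` from `data.transforms`.

    Returns None if `base` is falsy or the requested transform is not
    configured. If `mode` is given, it is passed through as a parameter.
    """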
    if not base:
        return None
    transform = getattr(base, transform)
    if not transform:
        return None
    name, params = get_name_and_params(transform)
    if mode:
        params.update({'mode': mode})
    return getattr(data.transforms, name)(**params)


def build_transforms(cfg, mode):
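    """Build the resize / augment / crop / preprocess transforms.

    Augmentations are only created when `mode == 'train'`; the crop
    transform additionally receives `mode`.
    """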
    resizer = get_transform(cfg.transform, 'resize')

    augmenter = None
    if mode == 'train':
        augmenter = get_transform(cfg.transform, 'augment')

    cropper = get_transform(cfg.transform, 'crop', mode=mode)
    preprocessor = get_transform(cfg.transform, 'preprocess')

    return {
        'resize': resizer,
        'augment': augmenter,
        'crop': cropper,
        'preprocess': preprocessor,
    }


def build_dataset(cfg, data_info, mode):
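    """Instantiate the dataset class named in `cfg.data.dataset`.

    Non-feature datasets receive the transform dict from `build_transforms`;
    `data_info` is merged into the dataset constructor arguments.
    """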
    dataset_class = getattr(data.datasets, cfg.data.dataset.name)
    dataset_params = cfg.data.dataset.params
    dataset_params.test_mode = mode != 'train'
    dataset_params = dict(dataset_params)
    # FeatureDataset variants do not take image transforms.
    if 'FeatureDataset' not in cfg.data.dataset.name:
        transforms = build_transforms(cfg, mode)
        dataset_params.update(transforms)
    dataset_params.update(data_info)
    return dataset_class(**dataset_params)


def build_dataloader(cfg, dataset, mode):
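    """Wrap `dataset` in a torch DataLoader configured for `mode`.

    Batch size comes from `cfg.train` for training and from `cfg.evaluate`
    (falling back to `cfg.train`) otherwise. A custom sampler, if configured,
    is attached for training and wrapped for DDP when needed; DDP is not
    supported for inference.
    """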
    def worker_init_fn(worker_id):
        # Re-seed numpy in each worker so workers do not share an RNG stream.
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    dataloader_params = {}
    dataloader_params['num_workers'] = cfg.data.num_workers
    dataloader_params['drop_last'] = mode == 'train'
    dataloader_params['shuffle'] = mode == 'train'
    dataloader_params['pin_memory'] = cfg.data.get('pin_memory', True)

    if mode in ('train', 'valid'):
        if mode == 'train':
            dataloader_params['batch_size'] = cfg.train.batch_size
        elif mode == 'valid':
            dataloader_params['batch_size'] = cfg.evaluate.get('batch_size') or cfg.train.batch_size
        sampler = None
        if cfg.data.get('sampler') and mode == 'train':
            name, params = get_name_and_params(cfg.data.sampler)
            sampler = getattr(data.samplers, name)(dataset, **params)
        if sampler:
            # A custom sampler replaces the default shuffling.
            dataloader_params['shuffle'] = False
            if cfg.strategy == 'ddp':
                sampler = data.samplers.DistributedSamplerWrapper(sampler)
            dataloader_params['sampler'] = sampler
            print(f'Using sampler {sampler} for training ...')
        elif cfg.strategy == 'ddp':
            dataloader_params['shuffle'] = False
            dataloader_params['sampler'] = DistributedSampler(dataset, shuffle=(mode == 'train'))
    else:
        assert cfg.strategy != 'ddp', 'DDP currently not supported for inference'
        dataloader_params['batch_size'] = cfg.evaluate.get('batch_size') or cfg.train.batch_size

    loader = DataLoader(dataset,
                        **dataloader_params,
                        worker_init_fn=worker_init_fn)
    return loader


def build_model(cfg):
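    """Create the model named in `cfg.model`, optionally loading weights.

    Any 'foldx' placeholder in checkpoint paths is replaced with the current
    outer fold (`cfg.data.outer_fold`) before loading.
    """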
    name, params = get_name_and_params(cfg.model)

    # Substitute the outer fold into a 'foldx' backbone checkpoint path.
    if cfg.model.params.get('cnn_params', None):
        cnn_params = cfg.model.params.cnn_params
        if cnn_params.get('load_pretrained_backbone', None):
            if 'foldx' in cnn_params.load_pretrained_backbone:
                cfg.model.params.cnn_params.load_pretrained_backbone = \
                    cnn_params.load_pretrained_backbone.replace('foldx', f'fold{cfg.data.outer_fold}')

    print(f'Creating model <{name}> ...')
    model = getattr(models.engine, name)(**params)
    if 'backbone' in cfg.model.params:
        print(f' Using backbone <{cfg.model.params.backbone}> ...')
    if 'pretrained' in cfg.model.params:
        print(f' Pretrained : {cfg.model.params.pretrained}')

    if 'load_pretrained' in cfg.model:
        if 'foldx' in cfg.model.load_pretrained:
            cfg.model.load_pretrained = cfg.model.load_pretrained.replace('foldx', f'fold{cfg.data.outer_fold}')
        print(f' Loading pretrained checkpoint from {cfg.model.load_pretrained}')
        weights = torch.load(cfg.model.load_pretrained,
                             map_location=lambda storage, loc: storage)['state_dict']
        # Strip the leading 'model.' prefix from checkpoint keys and drop
        # loss-function entries.
        weights = {re.sub(r'^model\.', '', k): v for k, v in weights.items() if 'loss_fn' not in k}
        model.load_state_dict(weights)

    return model


def build_loss(cfg):
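    """Instantiate the loss function named in `cfg.loss`.

    A `pos_weight` entry in the params is converted to a torch tensor.
    """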
    name, params = get_name_and_params(cfg.loss)
    print(f'Using loss function <{name}> ...')
    params = dict(params)
    if 'pos_weight' in params:
        params['pos_weight'] = torch.tensor(params['pos_weight'])
    criterion = getattr(losses, name)(**params)
    return criterion


def build_scheduler(cfg, optimizer):
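    """Build the LR scheduler named in `cfg.scheduler` for `optimizer`.

    The returned scheduler is tagged with an `update_frequency` attribute
    ('on_batch', 'on_valid', or 'on_epoch') indicating when it should step.
    """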
    name, params = get_name_and_params(cfg.scheduler)
    print(f'Using learning rate schedule <{name}> ...')

    if name == 'CosineAnnealingLR':
        params = {
            # Placeholder step count; assumed to be corrected elsewhere once
            # the true number of training steps is known.
            'T_max': 100000,
            'eta_min': max(params.final_lr, 1.0e-8)
        }

    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        # Derive the one-cycle div factors from the configured initial,
        # peak, and final learning rates.
        lr_0 = cfg.optimizer.params.lr
        lr_1 = params.max_lr
        lr_2 = params.final_lr
        pct_start = params.pct_start

        params = {}
        # Placeholder; assumed to be corrected elsewhere once the number of
        # steps per epoch is known.
        params['steps_per_epoch'] = 100000
        params['epochs'] = cfg.train.num_epochs
        params['max_lr'] = lr_1
        params['pct_start'] = pct_start
        params['div_factor'] = lr_1 / lr_0
        params['final_div_factor'] = lr_0 / max(lr_2, 1.0e-8)

    scheduler = getattr(optim, name)(optimizer=optimizer, **params)

    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        scheduler.pct_start = params['pct_start']

    # Tag the scheduler with how often it should be stepped.
    if name in ('OneCycleLR', 'CustomOneCycleLR', 'CosineAnnealingLR'):
        scheduler.update_frequency = 'on_batch'
    elif name == 'ReduceLROnPlateau':
        scheduler.update_frequency = 'on_valid'
    else:
        scheduler.update_frequency = 'on_epoch'

    return scheduler


def build_optimizer(cfg, parameters):
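    """Instantiate the optimizer named in `cfg.optimizer` over `parameters`."""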
    name, params = get_name_and_params(cfg.optimizer)
    print(f'Using optimizer <{name}> ...')
    optimizer = getattr(optim, name)(parameters, **params)
    return optimizer


def build_task(cfg, model):
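    """Wrap `model` in the task class named in `cfg.task`."""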
    name, params = get_name_and_params(cfg.task)
    print(f'Building task <{name}> ...')
    return getattr(tasks, name)(cfg, model, **params)