import os
import sys
import copy
import argparse

import torch
import torch.nn as nn
import mlflow.pytorch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import torchvision.transforms as T
from pytorch_lightning.metrics.functional import accuracy
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
from utils.debug import debug
from utils.augmentation import get_augmentation
from utils.dataset import Subset, get_dataset, k_fold
from processingpipeline.pipeline import RawProcessingPipeline
from processingpipeline.torch_pipeline import raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
from models.classifier import log_tensor, resnet_model, LitModel, TrackImagesCallback

import segmentation_models_pytorch as smp

from utils.pytorch_ssim import SSIM

# args to set up task
parser = argparse.ArgumentParser(description="classification_task")
parser.add_argument("--tracking_uri", type=str,
                    default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
                    help='URI of the mlflow server on AWS')
parser.add_argument("--processor_uri", type=str, default=None,
                    help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
parser.add_argument("--classifier_uri", type=str, default=None,
                    help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
parser.add_argument("--state_dict_uri", type=str, default=None,
                    help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training)')

parser.add_argument("--experiment_name", type=str, default='classification learnable pipeline',
                    help='Specify the experiment you are running, e.g. end2end segmentation')
parser.add_argument("--run_name", type=str, default='test run',
                    help='Specify the name of your run')

parser.add_argument("--log_model", type=str2bool, default=True, help='Enables model logging')
parser.add_argument("--save_locally", action='store_true',
                    help='If set, also save the model locally')  # TODO: bypass mlflow

parser.add_argument("--track_processing", action='store_true',
                    help='Save images after each transformation of the pipeline for the test set')
parser.add_argument("--track_processing_gradients", action='store_true',
                    help='Save images of gradients after each transformation of the pipeline for the test set')
parser.add_argument("--track_save_tensors", action='store_true',
                    help='Save the torch tensors after each transformation of the pipeline for the test set')
parser.add_argument("--track_predictions", action='store_true',
                    help='Save images after each transformation of the pipeline for the test set + input gradient')
parser.add_argument("--track_n_images", type=int, default=5,
                    help='Track the first n elements of the dataset. Only used for args.track_processing=True')
parser.add_argument("--track_every_epoch", action='store_true',
                    help='Track images every epoch instead of only once after training')

# args to create dataset
parser.add_argument("--seed", type=int, default=1, help='Global seed')
parser.add_argument("--dataset", type=str, default='Microscopy',
                    choices=["Drone", "DroneSegmentation", "Microscopy"], help='Select dataset')
parser.add_argument("--n_splits", type=int, default=1, help='Number of splits used for training')
parser.add_argument("--train_size", type=float, default=0.8, help='Fraction of training points in dataset')

# args for training
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate used for training")
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
parser.add_argument("--augmentation", type=str, default='none', choices=["none", "weak", "strong"],
                    help="Applies augmentation to training")
parser.add_argument("--augmentation_on_valid_epoch", action='store_true',
                    help='Also apply augmentation during validation epochs')  # TODO: implement; should be disabled by default for 'val' and 'test'
parser.add_argument("--check_val_every_n_epoch", type=int, default=1)

# args to specify the processing
parser.add_argument("--processing_mode", type=str, default="parametrized",
                    choices=["parametrized", "static", "neural_network", "none"],
                    help="Which type of raw to rgb processing should be used")

# args to specify model
parser.add_argument("--classifier_network", type=str, default='ResNet18',
                    help='Type of pretrained network')  # TODO: implement different choices
parser.add_argument("--classifier_pretrained", action='store_true',
                    help='Whether to use a pre-trained model or not')
parser.add_argument("--smp_encoder", type=str, default='resnet34', help='Segmentation model encoder')

parser.add_argument("--freeze_processor", action='store_true', help="Freeze raw to rgb processing model weights")
parser.add_argument("--freeze_classifier", action='store_true', help="Freeze classification model weights")

# args to specify static pipeline transformations
parser.add_argument("--sp_debayer", type=str, default='bilinear',
                    choices=['bilinear', 'malvar2004', 'menon2007'],
                    help="Specify algorithm used as debayer")
parser.add_argument("--sp_sharpening", type=str, default='sharpening_filter',
                    choices=['sharpening_filter', 'unsharp_masking'],
                    help="Specify algorithm used for sharpening")
parser.add_argument("--sp_denoising", type=str, default='gaussian_denoising',
                    choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'],
                    help="Specify algorithm used for denoising")

# args to choose training mode
parser.add_argument("--adv_training", action='store_true', help="Enable adversarial training")
parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxiliary loss")
parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
                    help="Type of adversarial auxiliary regularization loss")

parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
parser.add_argument('--test_run', action='store_true')
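# When this script runs inside a Jupyter/IPython kernel, sys.argv[0] contains
# 'ipykernel_launcher', so the hard-coded argument list below is used instead of
# the command line; otherwise arguments are parsed from the CLI as usual.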
if 'ipykernel_launcher' in sys.argv[0]:
    args = parser.parse_args([
        '--dataset=Microscopy',
        '--epochs=100',
        '--augmentation=strong',
        '--lr=1e-5',
        '--freeze_processor',
        # '--track_processing',
        # '--test_run',
        # '--track_predictions',
        # '--track_every_epoch',
        # '--adv_training',
        # '--adv_aux_weight=100',
        # '--adv_aux_loss=l2',
        # '--log_model=',
    ])
else:
    args = parser.parse_args()


def run_train(args):

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    training_mode = 'adversarial' if args.adv_training else 'default'

    # set tracking uri, this is the address of the mlflow server where light experimental data will be stored
    mlflow.set_tracking_uri(args.tracking_uri)
    mlflow.set_experiment(args.experiment_name)
    os.environ["AWS_ACCESS_KEY_ID"] = ''      # TODO: add your AWS access key if you want to write your results to our collaborative lab server
    os.environ["AWS_SECRET_ACCESS_KEY"] = ''  # TODO: add your AWS secret access key if you want to write your results to our collaborative lab server

    # dataset
    dataset = get_dataset(args.dataset)

    print(f'dataset: {type(dataset).__name__}[{len(dataset)}]')
    print(f'task: {dataset.task}')
    print(f'mode: {training_mode} training')
    print(f'# cross-validation subsets: {args.n_splits}')

    pl.seed_everything(args.seed)
    idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)

    with mlflow.start_run(run_name=args.run_name) as parent_run:

        for k_iter, idxs in enumerate(idxs_kfold):

            print(f"K_fold subset: {k_iter + 1}/{args.n_splits}")

            if args.processing_mode == 'static':
                if args.dataset == "Drone" or args.dataset == "DroneSegmentation":
                    mean = torch.tensor([0.35, 0.36, 0.35])
                    std = torch.tensor([0.12, 0.11, 0.12])
                elif args.dataset == "Microscopy":
                    mean = torch.tensor([0.91, 0.84, 0.94])
                    std = torch.tensor([0.08, 0.12, 0.05])

                dataset.transform = T.Compose([
                    RawProcessingPipeline(
                        camera_parameters=dataset.camera_parameters,
                        debayer=args.sp_debayer,
                        sharpening=args.sp_sharpening,
                        denoising=args.sp_denoising,
                    ),
                    T.Normalize(mean, std),
                ])  # XXX: Not clean

                processor = nn.Identity()

            if args.processor_uri is not None and args.processing_mode != 'none':
                print('Fetching processor: ', end='')
                model = fetch_from_mlflow(args.processor_uri, use_cache=args.cache_downloaded_models)
                processor = model.processor
                for param in processor.parameters():
                    param.requires_grad = True
                model.processor = None
                del model
            else:
                print(f'processing_mode: {args.processing_mode}')

                normalize_mosaic = None  # normalize after raw has been passed to raw2rgb
                if args.dataset == "Microscopy":
                    mosaic_mean = [0.5663, 0.1401, 0.0731]
                    mosaic_std = [0.097, 0.0423, 0.008]
                    normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)

                track_stages = args.track_processing or args.track_processing_gradients
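                # Select the raw-to-RGB processor module: 'parametrized' builds a
                # ParametrizedProcessing pipeline from the dataset's camera parameters,
                # 'neural_network' uses the learned NNProcessing model, and 'none' falls
                # back to RawToRGB, which only packs the raw mosaic into RGB channels.
                # (The 'static' pipeline was already attached as a dataset transform above.)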
                if args.processing_mode == 'parametrized':
                    processor = ParametrizedProcessing(
                        camera_parameters=dataset.camera_parameters,
                        track_stages=track_stages,
                        batch_norm_output=True,
                        noise_layer=args.adv_training,  # XXX: Remove?
                    )
                elif args.processing_mode == 'neural_network':
                    processor = NNProcessing(track_stages=track_stages,
                                             normalize_mosaic=normalize_mosaic,
                                             batch_norm_output=True)
                elif args.processing_mode == 'none':
                    processor = RawToRGB(reduce_size=True, out_channels=3,
                                         track_stages=track_stages,
                                         normalize_mosaic=normalize_mosaic)

            if args.classifier_uri:  # fetch classifier
                print('Fetching classifier: ', end='')
                model = fetch_from_mlflow(args.classifier_uri, use_cache=args.cache_downloaded_models)
                classifier = model.classifier
                model.classifier = None
                del model
            else:
                if dataset.task == 'classification':
                    classifier = resnet_model(
                        model=resnet18,
                        pretrained=args.classifier_pretrained,
                        in_channels=3,
                        fc_out_features=len(dataset.classes),
                    )
                else:
                    # XXX: add other network choices to args.smp_network (FPN) and args.network
                    classifier = smp.UnetPlusPlus(
                        encoder_name=args.smp_encoder,
                        encoder_depth=5,
                        encoder_weights='imagenet',
                        in_channels=3,
                        classes=1,
                        activation=None,
                    )

            if args.freeze_processor and len(list(processor.parameters())) == 0:
                print('Note: freezing processor without parameters.')

            assert not (args.freeze_processor and args.freeze_classifier), 'Likely no parameters to train.'

            if dataset.task == 'classification':
                loss = nn.CrossEntropyLoss()
                metrics = [accuracy]
            else:
                # loss = utils.base.smp_get_loss(args.smp_loss)  # XXX: add other losses to args.smp_loss
                loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
                metrics = [smp.utils.metrics.IoU()]

            loss_aux = None

            if args.adv_training:
                assert args.processing_mode == 'parametrized', \
                    f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
                assert args.freeze_classifier, "Classifier should be frozen for adversarial training"
                assert not args.freeze_processor, "Processor should not be frozen for adversarial training"

                # keep a frozen copy of the processor as the reference for the auxiliary loss
                processor_default = copy.deepcopy(processor)
                processor_default.track_stages = False
                processor_default.eval()
                processor_default.to(DEVICE)
                # debug(processor_default)

                def l2_regularization(x, y):
                    return (x - y).norm()

                if args.adv_aux_loss == 'l2':
                    regularization = l2_regularization
                elif args.adv_aux_loss == 'ssim':
                    regularization = SSIM(window_size=11)
                else:
                    raise NotImplementedError(args.adv_aux_loss)

                class AuxLoss(nn.Module):
                    """Auxiliary regularizer comparing the current processor's output to the default processor's output."""

                    def __init__(self, loss_aux, weight=1):
                        super().__init__()
                        self.loss_aux = loss_aux
                        self.weight = weight

                    def forward(self, x):
                        x_reference = processor_default(x)
                        x_processed = processor.buffer['processed_rgb']
                        return self.weight * self.loss_aux(x_reference, x_processed)

                class WeightedLoss(nn.Module):
                    def __init__(self, loss, weight=1):
                        super().__init__()
                        self.loss = loss
                        self.weight = weight

                    def forward(self, x, y):
                        return self.weight * self.loss(x, y)

                    def __repr__(self):
                        return f'{self.weight} * {get_name(self.loss)}'

                # negative weight: with the classifier frozen, the processor is trained to
                # increase the classification loss, while loss_aux keeps its output close
                # to the default processing
                loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=-1)
                # loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
                loss_aux = AuxLoss(
                    loss_aux=regularization,
                    weight=args.adv_aux_weight,
                )

            augmentation = get_augmentation(args.augmentation)

            model = LitModel(
                classifier=classifier,
                processor=processor,
                loss=loss,
                loss_aux=loss_aux,
                adv_training=args.adv_training,
                metrics=metrics,
                augmentation=augmentation,
                is_segmentation_task=dataset.task == 'segmentation',
                freeze_classifier=args.freeze_classifier,
                freeze_processor=args.freeze_processor,
            )
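            # The run configuration and the train/validation indices are collected in a state
            # dict and logged to mlflow below, so a later run can reproduce the same split by
            # passing --state_dict_uri.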
            # get train_set_dict
            if args.state_dict_uri:
                state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
                train_indices = state_dict['train_indices']
                valid_indices = state_dict['valid_indices']
            else:
                train_indices = idxs[0]
                valid_indices = idxs[1]
                state_dict = vars(args).copy()

            track_indices = list(range(args.track_n_images))

            if dataset.task == 'classification':
                state_dict['classes'] = dataset.classes
            state_dict['device'] = DEVICE
            state_dict['train_indices'] = train_indices
            state_dict['valid_indices'] = valid_indices
            state_dict['elements in train set'] = len(train_indices)
            state_dict['elements in test set'] = len(valid_indices)

            if args.test_run:  # use a single batch for a quick test run
                train_indices = train_indices[:args.batch_size]
                valid_indices = valid_indices[:args.batch_size]

            train_set = Subset(dataset, indices=train_indices)
            valid_set = Subset(dataset, indices=valid_indices)
            track_set = Subset(dataset, indices=track_indices)

            train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=16, shuffle=True)
            valid_loader = DataLoader(valid_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
            track_loader = DataLoader(track_set, batch_size=args.batch_size, num_workers=16, shuffle=False)

            with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
                # mlflow.pytorch.autolog(silent=True)

                if k_iter == 0:
                    display_mlflow_run_info(child_run)

                mlflow.pytorch.log_state_dict(state_dict, artifact_path=None)

                hparams = {
                    'dataset': args.dataset,
                    'processing_mode': args.processing_mode,
                    'training_mode': training_mode,
                }
                if training_mode == 'adversarial':
                    hparams['adv_aux_weight'] = args.adv_aux_weight
                    hparams['adv_aux_loss'] = args.adv_aux_loss
                mlflow.log_params(hparams)

                with open('results/state_dict.txt', 'w') as f:
                    f.write('python ' + ' '.join(sys.argv) + '\n')
                    f.write('\n'.join([f'{k}={v}' for k, v in state_dict.items()]))
                mlflow.log_artifact('results/state_dict.txt', artifact_path=None)

                mlf_logger = pl.loggers.MLFlowLogger(experiment_name=args.experiment_name,
                                                     tracking_uri=args.tracking_uri)
                mlf_logger._run_id = child_run.info.run_id

                callbacks = []
                if args.track_processing:
                    callbacks += [TrackImagesCallback(track_loader,
                                                      track_every_epoch=args.track_every_epoch,
                                                      track_processing=args.track_processing,
                                                      track_gradients=args.track_processing_gradients,
                                                      track_predictions=args.track_predictions,
                                                      save_tensors=args.track_save_tensors)]

                # if True:  # args.save_best:
                #     if dataset.task == 'classification':
                #         checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode='max')
                #         # checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1,
                #         #                                       verbose=True, monitor="val_accuracy", mode="max")
                #     else:
                #         checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
                #     callbacks += [checkpoint_callback]

                trainer = pl.Trainer(
                    gpus=1 if DEVICE == 'cuda' else 0,
                    min_epochs=args.epochs,
                    max_epochs=args.epochs,
                    logger=mlf_logger,
                    callbacks=callbacks,
                    check_val_every_n_epoch=args.check_val_every_n_epoch,
                    # checkpoint_callback=True,
                )

                if args.log_model:
                    mlflow.pytorch.autolog(log_every_n_epoch=10)
                    print(f'model_uri="{mlflow.get_artifact_uri()}/model"')

                trainer.fit(
                    model,
                    train_dataloader=train_loader,
                    val_dataloaders=valid_loader,
                )

                # if args.adv_training:
                #     for (name, p1), p2 in zip(processor.named_parameters(), processor_default.cpu().parameters()):
                #         print(f"param '{name}' diff: {p2 - p1}, l2: {(p2 - p1).norm().item()}")

    return model


if __name__ == '__main__':
    model = run_train(args)
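# Example invocation (illustrative only; the script filename, tracking server and AWS
# credentials depend on your setup):
#   python train.py --dataset=Microscopy --processing_mode=parametrized \
#       --epochs=100 --augmentation=strong --lr=1e-5
# For adversarial training of the processor against a frozen classifier, something like:
#   python train.py --adv_training --freeze_classifier --processing_mode=parametrized \
#       --adv_aux_loss=l2 --adv_aux_weight=100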