File size: 2,731 Bytes
2458333 a27d55f 2458333 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import os, sys
# Put the repository root (the parent of this script's directory) on sys.path so that the
# `ddmr` imports resolve when this file is executed directly as a script: a top-level
# script has no parent package, so explicit relative imports are not available to it.
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.split(currentdir)[0]
sys.path.append(parentdir)
# True only when the PYCHARM_EXEC environment variable is exactly the string 'True'.
PYCHARM_EXEC = os.environ.get('PYCHARM_EXEC') == 'True'
import numpy as np
import tensorflow as tf
import voxelmorph as vxm
import neurite as ne
from datetime import datetime
import ddmr.utils.constants as C
from ddmr.data_generator import DataGeneratorManager
from ddmr.utils.misc import try_mkdir
from ddmr.utils.nifti_utils import save_nifti
from ddmr.networks import WeaklySupervised
from ddmr.losses import HausdorffDistanceErosion
from ddmr.layers import UncertaintyWeighting
# Select the GPU for this run; the device ordering comes from the project constants.
os.environ.update(
    CUDA_DEVICE_ORDER=C.DEV_ORDER,
    CUDA_VISIBLE_DEVICES='0',  # Check availability before running using 'nvidia-smi'
)

# Run-specific overrides of the project-wide settings held in ddmr.utils.constants (C).
C.EPOCHS = 10000
C.LIMIT_NUM_SAMPLES = None
C.BATCH_SIZE = 2
C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_vessels'
# Build the train/validation generators over the configured dataset. Positional arguments
# are passed through as-is: dataset path, batch size, a boolean flag (presumably shuffle),
# the sample limit and 1 - C.TRAINING_PERC (presumably the validation fraction) -- TODO
# confirm against DataGeneratorManager's signature.
data_generator = DataGeneratorManager(
    C.TRAINING_DATASET,
    C.BATCH_SIZE,
    True,
    C.LIMIT_NUM_SAMPLES,
    1 - C.TRAINING_PERC,
    voxelmorph=True,
    segmentations=True,
)
train_generator, validation_generator = (
    data_generator.get_generator(split) for split in ('train', 'validation')
)
# Output folder of the trained run to inspect. NOTE: this path is relative to the current
# working directory, not to this script's location.
data_folder = '../train_3d_multiloss_segm_haus_dice_ncc_grad_203925-29012021'
# Build the registration network and restore its trained weights.
# Spatial input shape: drop the first and last axes of the generator's input shape
# (presumably the batch and channel axes -- TODO confirm against get_input_shape()).
in_shape = train_generator.get_input_shape()[1:-1]
# U-Net filter counts, presumably chosen to match the checkpoint loaded below.
# NOTE: the decoder list is not simply the encoder list reversed (9 vs 6 entries).
enc_features = [16] + [32] * 5
dec_features = [32] * 7 + [16] * 2
nb_features = [enc_features, dec_features]
vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=nb_features, int_steps=5)
# by_name=True matches stored weights to layers by layer name; per the Keras docs, layers
# without a matching name in the file are not loaded, so verify the checkpoint fits.
vxm_model.load_weights(os.path.join(data_folder, 'checkpoints', 'best_model.h5'), by_name=True)
# Register one validation sample and dump its inputs and the predictions as NIfTI volumes
# into data_folder for visual inspection (nothing is plotted here).
sample = validation_generator[0]
samp_id = 1
# sample[0] holds the model inputs; every entry is batched along axis 0.
model_inputs = sample[0]
# Feed the first three inputs for a single sample, re-adding a batch axis of length 1.
pred_img, pred_seg, pred_flow = vxm_model.predict(
    [model_inputs[i][samp_id, ...][np.newaxis, ...] for i in range(3)])

# The predictions carry the batch axis of length 1 (and presumably a trailing channel axis);
# squeeze them down to plain volumes.
save_nifti(np.squeeze(pred_img), os.path.join(data_folder, 'pred_img.nii.gz'))
save_nifti(np.squeeze(pred_seg), os.path.join(data_folder, 'pred_seg.nii.gz'))

# Squeeze the input volumes the same way, so every dumped file has the same layout as the
# predictions instead of keeping a trailing singleton axis.
# NOTE(review): the file names below look inconsistent with the input order used above.
# model_inputs[-2] is model_inputs[1] when sample[0] has three entries (model_inputs[2] when
# it has four), so two of these files would hold the same array. If WeaklySupervised follows
# voxelmorph's [moving image, fixed image, moving segmentation] input order, entries 0 and 2
# are also mislabelled -- verify against DataGeneratorManager before trusting these names.
save_nifti(np.squeeze(model_inputs[0][samp_id, ...]), os.path.join(data_folder, 'mov_seg.nii.gz'))
save_nifti(np.squeeze(model_inputs[1][samp_id, ...]), os.path.join(data_folder, 'fix_seg.nii.gz'))
save_nifti(np.squeeze(model_inputs[2][samp_id, ...]), os.path.join(data_folder, 'mov_img.nii.gz'))
save_nifti(np.squeeze(model_inputs[-2][samp_id, ...]), os.path.join(data_folder, 'fix_img.nii.gz'))
|