import matplotlib matplotlib.use('WebAgg') import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import matplotlib.colors as mcolors from matplotlib.lines import Line2D from matplotlib import cm from mpl_toolkits.axes_grid1 import make_axes_locatable import tensorflow as tf import numpy as np import DeepDeformationMapRegistration.utils.constants as C from skimage.exposure import rescale_intensity import scipy.misc as scpmisc import os THRES = 0.9 # COLOR MAPS chunks = np.linspace(0, 1, 10) cmap1 = plt.get_cmap('hsv', 30) # cmaplist = [cmap1(i) for i in range(cmap1.N)] cmaplist = [(1, 1, 1, 1), (0, 0, 1, 1), (230 / 255, 97 / 255, 1 / 255, 1), (128 / 255, 0 / 255, 32 / 255, 1)] cmaplist[0] = (1, 1, 1, 1.0) cmap1 = mcolors.LinearSegmentedColormap.from_list('custom', cmaplist, cmap1.N) colors = [(0, 0, 1, i) for i in chunks] cmap2 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100) colors = [(230 / 255, 97 / 255, 1 / 255, i) for i in chunks] cmap3 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100) colors = [(128 / 255, 0 / 255, 32 / 255, i) for i in chunks] cmap4 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100) cmap_bin = cm.get_cmap('viridis', 3) # viridis is the default colormap cmap_segs = np.asarray([mcolors.to_rgba(mcolors.CSS4_COLORS[c], 1) for c in mcolors.CSS4_COLORS.keys()]) cmap_segs.sort() # rnd_idxs = [30, 17, 72, 90, 74, 39, 120, 63, 52, 79, 140, 68, 131, 109, 57, 49, 11, 132, 29, 46, 51, 26, 53, 7, 89, 47, 43, 121, 31, 28, 106, 92, 130, 117, 91, 118, 61, 5, 80, 93, 58, 133, 14, 98, 116, 76, 113, 111, 136, 142, 95, 122, 86, 77, 36, 97, 141, 115, 18, 81, 88, 87, 44, 146, 103, 67, 147, 48, 42, 83, 128, 65, 139, 69, 27, 135, 94, 134, 50, 19, 114, 0, 96, 10, 138, 75, 13, 12, 102, 32, 66, 16, 8, 73, 85, 145, 54, 37, 70, 143] # cmap_segs = cmap_segs[rnd_idxs] np.random.shuffle(cmap_segs) cmap_segs[0, -1] = 0 cmap_segs = mcolors.ListedColormap(cmap_segs) def view_centerline_sample(sample: np.ndarray, dimensionality: int, ax=None, c=None, name=None): if dimensionality == 2: _plot_2d(sample, ax, c, name=name) elif dimensionality == 3: _plot_3d(sample, ax, c, name=name) else: raise ValueError('Invalid valud for dimensionality. Expected int 2 or 3') def matrix_to_orthographicProjection(matrix: np.ndarray, ret_list=False): """ Given a 3D matrix, returns the three orthographic projections: top, front, right. Top corresponds to dimensions 1 and 2 Front corresponds to dimensions 0 and 1 Right corresponds to dimensions 0 and 2 :param matrix: 3D matrix :param ret_list: return a list instead of an array (optional) :return: list or array with the three views [top, front, right] """ top = _getProjection(matrix, dim=0) # YZ front = _getProjection(matrix, dim=2) # XY right = _getProjection(matrix, dim=1) # XZ if ret_list: return top, front, right else: return np.asarray([top, front, right]) def _getProjection(matrix: np.ndarray, dim: int): orth_view = matrix.sum(axis=dim, dtype=float) orth_view = orth_view > 0.0 orth_view.astype(np.float) return orth_view def orthographicProjection_to_matrix(top: np.ndarray, front: np.ndarray, right: np.ndarray): """ Given the three orthographic projections, it returns a 3D-view of the object based on back projection :param top: 2D view top view :param front: 2D front view :param right: 2D right view :return: matrix with the 3D-view """ top_mat = np.tile(top, (front.shape[0], 1, 1)) front_mat = np.tile(top, (right.shape[1], 1, 1)) right_mat = np.tile(top, (top.shape[0], 1, 1)) reconstruction = np.zeros((front.shape[0], right.shape[1], top.shape[0])) iter = np.nditer([top_mat, front_mat, right_mat, reconstruction], flags=['multi_index'], op_flags=['readwrite']) while not iter.finished: if iter[0] and iter[1] and iter[2]: iter[3] = 1 iter.iternext() return reconstruction def _plot_2d(sample: np.ndarray, ax=None, c=None, name=None): if isinstance(sample, tf.Tensor): sample = sample.eval(session=tf.Session()) x_range = list() y_range = list() marker_size = list() for idx, val in np.ndenumerate(sample): if val >= THRES: x_range.append(idx[0]) y_range.append(idx[1]) marker_size.append(val ** 2) if not ax: fig = plt.figure() ax = fig.add_subplot(111) if c: ax.scatter(x_range, y_range, c=c, s=marker_size) else: ax.scatter(x_range, y_range, s=marker_size) ax.set_xlabel('X') ax.set_ylabel('Y') if name: ax.set_title(name) return ax def _plot_3d(sample: np.ndarray, ax=None, c=None, name=None): from mpl_toolkits.mplot3d import Axes3D if isinstance(sample, tf.Tensor): sample = sample.eval(session=tf.Session()) x_range = list() y_range = list() z_range = list() marker_size = list() for idx, val in np.ndenumerate(sample): if val >= THRES: x_range.append(idx[0]) y_range.append(idx[1]) z_range.append(idx[2]) marker_size.append(val ** 2) print('Found ', len(x_range), ' points') # x_range = np.linspace(start=0, stop=sample.shape[0], num=sample.shape[0]) # y_range = np.linspace(start=0, stop=sample.shape[1], num=sample.shape[1]) # z_range = np.linspace(start=0, stop=sample.shape[2], num=sample.shape[2]) # # sample_flat = sample.flatten(order='C') if len(x_range): if not ax: fig = plt.figure() ax = fig.add_subplot(111, projection='3d') if c: ax.scatter(x_range, y_range, z_range, c=c, s=marker_size) else: ax.scatter(x_range, y_range, z_range, s=marker_size) # ax.scatter(x_range, y_range, z_range, s=marker_size)#, c=sample_flat) ax.set_xlabel('X') ax.set_ylabel('Y') ax.set_zlabel('Z') if name: ax.set_title(name) return ax else: print('Nothing to plot') return None def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None): if fig is not None: fig.clear() plt.figure(fig.number) else: fig = plt.figure(dpi=C.DPI) ax_fix = fig.add_subplot(231) im_fix = ax_fix.imshow(list_imgs[0][:, :, 0]) ax_fix.set_title('Fix image', fontsize=C.FONT_SIZE) ax_fix.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_mov = fig.add_subplot(232) im_mov = ax_mov.imshow(list_imgs[1][:, :, 0]) ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE) ax_mov.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_im = fig.add_subplot(233) im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0]) ax_pred_im.set_title('Prediction', fontsize=C.FONT_SIZE) ax_pred_im.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_disp = fig.add_subplot(234) if affine_transf: fake_bg = np.array([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]) bottom = np.asarray([0, 0, 0, 1]) transf_mat = np.reshape(list_imgs[3], (2, 3)) transf_mat = np.stack([transf_mat, bottom], axis=0) im_pred_disp = ax_pred_disp.imshow(fake_bg) for i in range(4): for j in range(4): ax_pred_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b") ax_pred_disp.set_title('Affine transformation matrix') else: cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3]) im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal') ax_pred_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale) ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE) ax_pred_disp.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_gt_disp = fig.add_subplot(235) if affine_transf: fake_bg = np.array([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]) bottom = np.asarray([0, 0, 0, 1]) transf_mat = np.reshape(list_imgs[4], (2, 3)) transf_mat = np.stack([transf_mat, bottom], axis=0) im_gt_disp = ax_pred_disp.imshow(fake_bg) for i in range(4): for j in range(4): ax_pred_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b") ax_pred_disp.set_title('Affine transformation matrix') else: cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[4]) im_gt_disp = ax_gt_disp.imshow(s, interpolation='none', aspect='equal') ax_gt_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale) ax_gt_disp.set_title('GT disp map', fontsize=C.FONT_SIZE) ax_gt_disp.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) cb_fix = _set_colorbar(fig, ax_fix, im_fix, False) cb_mov = _set_colorbar(fig, ax_mov, im_mov, False) cb_pred = _set_colorbar(fig, ax_pred_im, im_pred_im, False) cb_pred_disp = _set_colorbar(fig, ax_pred_disp, im_pred_disp, False) cd_gt_disp = _set_colorbar(fig, ax_gt_disp, im_gt_disp, False) if filename is not None: plt.savefig(filename, format='png') # Call before show if not C.REMOTE: plt.show() else: plt.close() return fig def save_centreline_img(img, title, filename, fig=None): if fig is not None: fig.clear() plt.figure(fig.number) else: fig = plt.figure(dpi=C.DPI) dim = len(img.shape[:-1]) if dim == 2: ax = fig.add_subplot(111) fig.suptitle(title) im = ax.imshow(img[..., 0], cmap=cmap_bin) ax.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) #cb = _set_colorbar(fig, ax, im, False) else: ax = fig.add_subplot(111, projection='3d') fig.suptitle(title) im = ax.voxels(img[0, ..., 0] > 0.0) _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax) ax.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) plt.savefig(filename, format='png') plt.close() def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False, step=1): if fig is not None: fig.clear() plt.figure(fig.number) else: fig = plt.figure(dpi=C.DPI) dim_h, dim_w, dim_d = disp_map.shape[1:-1] dim = disp_map.shape[-1] if dim == 2: ax_x = fig.add_subplot(131) ax_x.set_title('H displacement') im_x = ax_x.imshow(disp_map[..., ::step, ::step, C.H_DISP]) ax_x.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) cb_x = _set_colorbar(fig, ax_x, im_x, False) ax_y = fig.add_subplot(132) ax_y.set_title('W displacement') im_y = ax_y.imshow(disp_map[..., ::step, ::step, C.W_DISP]) ax_y.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) cb_y = _set_colorbar(fig, ax_y, im_y, False) ax = fig.add_subplot(133) if affine_transf: fake_bg = np.array([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]) bottom = np.asarray([0, 0, 0, 1]) transf_mat = np.reshape(disp_map, (2, 3)) transf_mat = np.stack([transf_mat, bottom], axis=0) im = ax.imshow(fake_bg) for i in range(4): for j in range(4): ax.text(i, j, transf_mat[i, j], ha="center", va="center", color="b") else: c, d, s = _prepare_quiver_map(disp_map, dim=dim, spc=step) im = ax.imshow(s, interpolation='none', aspect='equal') ax.quiver(c[C.H_DISP], c[C.W_DISP], d[C.H_DISP], d[C.W_DISP], scale=C.QUIVER_PARAMS.arrow_scale) cb = _set_colorbar(fig, ax, im, False) ax.set_title('Displacement map') ax.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) fig.suptitle(title) else: ax = fig.add_subplot(111, projection='3d') c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim, spc=step) ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP]) _square_3d_plot(np.arange(0, dim_h-1), np.arange(0, dim_w-1), np.arange(0, dim_d-1), ax) fig.suptitle('Displacement map') ax.tick_params(axis='both', # Same parameters as in 2D https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html which='both', bottom=False, left=False, labelleft=False, labelbottom=False) add_axes_arrows_3d(ax, xyz_label=['R', 'A', 'S']) fig.suptitle(title) plt.savefig(filename, format='png') if show: plt.show() plt.close() return fig def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None, title_first_row='TRAINING', title_second_row='VALIDATION'): if fig is not None: fig.clear() plt.figure(fig.number) else: fig = plt.figure(dpi=C.DPI) dim = len(list_imgs[0].shape[:-1]) if dim == 2: # TRAINING ax_input = fig.add_subplot(241) ax_input.set_ylabel(title_first_row, fontsize=C.FONT_SIZE) im_fix = ax_input.imshow(list_imgs[0][:, :, 0]) ax_input.set_title('Fix image', fontsize=C.FONT_SIZE) ax_input.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_mov = fig.add_subplot(242) im_mov = ax_mov.imshow(list_imgs[1][:, :, 0]) ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE) ax_mov.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_im = fig.add_subplot(244) im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0]) ax_pred_im.set_title('Predicted fix image', fontsize=C.FONT_SIZE) ax_pred_im.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_disp = fig.add_subplot(243) if affine_transf: fake_bg = np.array([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]) bottom = np.asarray([0, 0, 0, 1]) transf_mat = np.reshape(list_imgs[3], (2, 3)) transf_mat = np.stack([transf_mat, bottom], axis=0) im_pred_disp = ax_pred_disp.imshow(fake_bg) for i in range(4): for j in range(4): ax_pred_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b") ax_pred_disp.set_title('Affine transformation matrix') else: cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3], dim=dim) im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal') ax_pred_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale) ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE) ax_pred_disp.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) # VALIDATION axinput_val = fig.add_subplot(245) axinput_val.set_ylabel(title_second_row, fontsize=C.FONT_SIZE) im_fix_val = axinput_val.imshow(list_imgs[4][:, :, 0]) axinput_val.set_title('Fix image', fontsize=C.FONT_SIZE) axinput_val.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_mov_val = fig.add_subplot(246) im_mov_val = ax_mov_val.imshow(list_imgs[5][:, :, 0]) ax_mov_val.set_title('Moving image', fontsize=C.FONT_SIZE) ax_mov_val.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_im_val = fig.add_subplot(248) im_pred_im_val = ax_pred_im_val.imshow(list_imgs[6][:, :, 0]) ax_pred_im_val.set_title('Predicted fix image', fontsize=C.FONT_SIZE) ax_pred_im_val.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_disp_val = fig.add_subplot(247) if affine_transf: fake_bg = np.array([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]) bottom = np.asarray([0, 0, 0, 1]) transf_mat = np.reshape(list_imgs[7], (2, 3)) transf_mat = np.stack([transf_mat, bottom], axis=0) im_pred_disp_val = ax_pred_disp_val.imshow(fake_bg) for i in range(4): for j in range(4): ax_pred_disp_val.text(i, j, transf_mat[i, j], ha="center", va="center", color="b") ax_pred_disp_val.set_title('Affine transformation matrix') else: c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim) im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal') ax_pred_disp_val.quiver(c[0], c[1], d[0], d[1], scale=C.QUIVER_PARAMS.arrow_scale) ax_pred_disp_val.set_title('Pred disp map', fontsize=C.FONT_SIZE) ax_pred_disp_val.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) cb_fix = _set_colorbar(fig, ax_input, im_fix, False) cb_mov = _set_colorbar(fig, ax_mov, im_mov, False) cb_pred = _set_colorbar(fig, ax_pred_im, im_pred_im, False) cb_pred_disp = _set_colorbar(fig, ax_pred_disp, im_pred_disp, False) cd_fix_val = _set_colorbar(fig, axinput_val, im_fix_val, False) cb_mov_val = _set_colorbar(fig, ax_mov_val, im_mov_val, False) cb_pred_val = _set_colorbar(fig, ax_pred_im_val, im_pred_im_val, False) cb_pred_disp_val = _set_colorbar(fig, ax_pred_disp_val, im_pred_disp_val, False) else: # 3D # TRAINING ax_input = fig.add_subplot(231, projection='3d') ax_input.set_ylabel(title_first_row, fontsize=C.FONT_SIZE) im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed') im_mov = ax_input.voxels(list_imgs[1][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving') ax_input.set_title('Fix image', fontsize=C.FONT_SIZE) ax_input.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_im = fig.add_subplot(232, projection='3d') im_pred_im = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction') im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed') ax_pred_im.set_title('Predicted fix image', fontsize=C.FONT_SIZE) ax_pred_im.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_disp = fig.add_subplot(233, projection='3d') c, d, s = _prepare_quiver_map(list_imgs[3], dim=dim) im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal') ax_pred_disp.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP], scale=C.QUIVER_PARAMS.arrow_scale) ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE) ax_pred_disp.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) # VALIDATION axinput_val = fig.add_subplot(234, projection='3d') axinput_val.set_ylabel(title_second_row, fontsize=C.FONT_SIZE) im_fix_val = ax_input.voxels(list_imgs[4][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)') im_mov_val = ax_input.voxels(list_imgs[5][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving (val)') axinput_val.set_title('Fix image', fontsize=C.FONT_SIZE) axinput_val.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_im_val = fig.add_subplot(235, projection='3d') im_pred_im_val = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction (val)') im_fix_val = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)') ax_pred_im_val.set_title('Predicted fix image', fontsize=C.FONT_SIZE) ax_pred_im_val.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_pred_disp_val = fig.add_subplot(236, projection='3d') c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim) im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal') ax_pred_disp_val.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP], scale=C.QUIVER_PARAMS.arrow_scale) ax_pred_disp_val.set_title('Pred disp map', fontsize=C.FONT_SIZE) ax_pred_disp_val.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) if filename is not None: plt.savefig(filename, format='png') # Call before show if not C.REMOTE: plt.show() else: plt.close() return fig def _set_colorbar(fig, ax, im, drawedges=True): div = make_axes_locatable(ax) im_cax = div.append_axes('right', size='5%', pad=0.05) im_cb = fig.colorbar(im, cax=im_cax, drawedges=drawedges, shrink=0.5, orientation='vertical') im_cb.ax.tick_params(labelsize=5) return im_cb def _prepare_quiver_map(disp_map: np.ndarray, dim=2, spc=C.QUIVER_PARAMS.spacing): if isinstance(disp_map, tf.Tensor): if tf.executing_eagerly(): disp_map = disp_map.numpy() else: disp_map = disp_map.eval() dx = disp_map[..., C.H_DISP] dy = disp_map[..., C.W_DISP] if dim > 2: dz = disp_map[..., C.D_DISP] img_size_x = disp_map.shape[C.H_DISP] img_size_y = disp_map.shape[C.W_DISP] if dim > 2: img_size_z = disp_map.shape[C.D_DISP] if dim > 2: s = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz)) s = np.reshape(s, [img_size_x, img_size_y, img_size_z]) cx, cy, cz = np.meshgrid(list(range(0, img_size_x)), list(range(0, img_size_y)), list(range(0, img_size_z)), indexing='ij') c = [cx[::spc, ::spc, ::spc], cy[::spc, ::spc, ::spc], cz[::spc, ::spc, ::spc]] d = [dx[::spc, ::spc, ::spc], dy[::spc, ::spc, ::spc], dz[::spc, ::spc, ::spc]] else: s = np.sqrt(np.square(dx) + np.square(dy)) s = np.reshape(s, [img_size_x, img_size_y]) cx, cy = np.meshgrid(list(range(0, img_size_x)), list(range(0, img_size_y))) c = [cx[::spc, ::spc], cy[::spc, ::spc]] d = [dx[::spc, ::spc], dy[::spc, ::spc]] return c, d, s def _prepare_colormap(disp_map: np.ndarray): if isinstance(disp_map, tf.Tensor): disp_map = disp_map.eval() dx = disp_map[:, :, 0] dy = disp_map[:, :, 1] mod_img = np.zeros_like(dx) for i in range(dx.shape[0]): for j in range(dx.shape[1]): vec = np.asarray([dx[i, j], dy[i, j]]) mod_img[i, j] = np.linalg.norm(vec, ord=2) p_l, p_h = np.percentile(mod_img, (0, 100)) mod_img = rescale_intensity(mod_img, in_range=(p_l, p_h), out_range=(0, 255)) return mod_img def plot_input_data(fix_img, mov_img, img_size=(64, 64), title=None, filename=None): num_samples = fix_img.shape[0] if num_samples != 16 and num_samples != 32: raise ValueError('Only batches of 16 or 32 samples!') fig, ax = plt.subplots(nrows=4 if num_samples == 16 else 8, ncols=4) ncol = 0 nrow = 0 black_col = np.ones([img_size[0], 0]) for sample in range(num_samples): combined_img = np.hstack([fix_img[sample, :, :, 0], black_col, mov_img[sample, :, :, 0]]) ax[nrow, ncol].imshow(combined_img, cmap='Greys') ax[nrow, ncol].set_ylabel('#{}'.format(sample)) ax[nrow, ncol].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ncol += 1 if ncol >= 4: ncol = 0 nrow += 1 if title is not None: fig.suptitle(title) if filename is not None: plt.savefig(filename, format='png') # Call before show if not C.REMOTE: plt.show() else: plt.close() return fig def plot_dataset_orthographic_views(view_sets: [[np.ndarray]]): """ :param views_fix: Expected order: top, front, left :param views_mov: Expected order: top, front, left :return: """ nrows = len(view_sets) fig, ax = plt.subplots(nrows=nrows, ncols=3) labels = ['top', 'front', 'left'] for nrow in range(nrows): for ncol in range(3): if nrows == 1: ax[ncol].imshow(view_sets[nrow][ncol][:, :, 0]) ax[ncol].set_title('Fix ' + labels[ncol]) ax[ncol].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) else: ax[nrow, ncol].imshow(view_sets[nrow][ncol][:, :, 0]) ax[nrow, ncol].set_title('Fix ' + labels[ncol]) ax[nrow, ncol].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) plt.show() return fig def plot_compare_2d_images(img1, img2, img1_name='img1', img2_name='img2'): fig, ax = plt.subplots(nrows=1, ncols=2) ax[0].imshow(img1[:, :, 0]) ax[0].set_title(img1_name) ax[0].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax[1].imshow(img2[:, :, 0]) ax[1].set_title(img2_name) ax[1].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) plt.show() return fig def plot_dataset_3d(img_sets): from mpl_toolkits.mplot3d import Axes3D fig = plt.figure() ax = fig.add_subplot(111, projection='3d') for idx in range(len(img_sets)): ax = _plot_3d(img_sets[idx], ax=ax, name='Set {}'.format(idx)) plt.show() return fig def plot_predictions(img_batches, disp_map_batch, seg_batches=None, step=1, filename='predictions', fig=None, show=False): fix_img_batch, mov_img_batch, pred_img_batch = img_batches if seg_batches != None: fix_seg_batch, mov_seg_batch, pred_seg_batch = seg_batches else: fix_seg_batch = mov_seg_batch = pred_seg_batch = None num_rows = fix_img_batch.shape[0] img_dim = len(fix_img_batch.shape) - 2 img_size = fix_img_batch.shape[1:-1] if fig is not None: fig.clear() plt.figure(fig.number) ax = fig.add_subplot(nrows=num_rows, ncols=5, figsize=(10, 3*num_rows), dpi=C.DPI) else: fig, ax = plt.subplots(nrows=num_rows, ncols=5, figsize=(10, 3*num_rows), dpi=C.DPI) if num_rows == 1: ax = ax[np.newaxis, ...] if img_dim == 3: # Extract slices from the images selected_slice = img_size[0] // 2 fix_img_batch = fix_img_batch[:, selected_slice, ...] mov_img_batch = mov_img_batch[:, selected_slice, ...] pred_img_batch = pred_img_batch[:, selected_slice, ...] if seg_batches != None: fix_seg_batch = fix_seg_batch[:, selected_slice, ...] mov_seg_batch = mov_seg_batch[:, selected_slice, ...] pred_seg_batch = pred_seg_batch[:, selected_slice, ...] disp_map_batch = disp_map_batch[:, selected_slice, ..., 1:] # Only the sagittal and longitudinal axes img_size = fix_img_batch.shape[1:-1] elif img_dim != 2: raise ValueError('Images have a bad shape: {}'.format(fix_img_batch.shape)) for row in range(num_rows): fix_img = fix_img_batch[row, :, :, 0].transpose() mov_img = mov_img_batch[row, :, :, 0].transpose() pred_img = pred_img_batch[row, :, :, 0].transpose() if seg_batches != None: fix_seg = fix_seg_batch[row, :, :, 0].transpose() mov_seg= mov_seg_batch[row, :, :, 0].transpose() pred_seg = pred_seg_batch[row, :, :, 0].transpose() disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2)) ax[row, 0].imshow(fix_img, origin='lower', cmap='gray') if seg_batches != None: ax[row, 0].imshow(fix_seg, origin='lower', cmap=cmap_segs) ax[row, 0].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax[row, 1].imshow(mov_img, origin='lower', cmap='gray') if seg_batches != None: ax[row, 1].imshow(mov_seg, origin='lower', cmap=cmap_segs) ax[row, 1].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) c, d, s = _prepare_quiver_map(disp_map, spc=step) cx, cy = c dx, dy = d disp_map_color = _prepare_colormap(disp_map) ax[row, 2].imshow(disp_map_color, interpolation='none', aspect='equal', origin='lower') ax[row, 2].quiver(cx, cy, dx, dy, units='dots', scale=1) ax[row, 2].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax[row, 3].imshow(mov_img, origin='lower', cmap='gray') if seg_batches != None: ax[row, 3].imshow(mov_seg, origin='lower', cmap=cmap_segs) ax[row, 3].quiver(cx, cy, dx, dy, units='dots', scale=1, color='w') ax[row, 3].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax[row, 4].imshow(pred_img, origin='lower', cmap='gray') if seg_batches != None: ax[row, 4].imshow(pred_seg, origin='lower', cmap=cmap_segs) ax[row, 4].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) plt.axis('off') ax[0, 0].set_title('Fixed img ($I_f$)', fontsize=C.FONT_SIZE) ax[0, 1].set_title('Moving img ($I_m$)', fontsize=C.FONT_SIZE) ax[0, 2].set_title('Backwards\ndisp .map ($\delta$)', fontsize=C.FONT_SIZE) ax[0, 3].set_title('Disp. map over $I_m$', fontsize=C.FONT_SIZE) ax[0, 4].set_title('Predicted $I_m$', fontsize=C.FONT_SIZE) plt.tight_layout() if filename is not None: plt.savefig(filename, format='png') # Call before show if show: plt.show() plt.close() return fig def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=None): if fig is not None: fig.clear() plt.figure(fig.number) else: fig = plt.figure(dpi=C.DPI) ax0 = fig.add_subplot(221) im0 = ax0.imshow(fix_img[..., 0]) ax0.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax1 = fig.add_subplot(222) im1 = ax1.imshow(mov_img[..., 0]) ax1.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) cx, cy, dx, dy, s = _prepare_quiver_map(disp_map) disp_map_color = _prepare_colormap(disp_map) ax2 = fig.add_subplot(223) im2 = ax2.imshow(s, interpolation='none', aspect='equal') ax2.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale) # ax2.figure.set_size_inches(img_size) ax2.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax3 = fig.add_subplot(224) dif = fix_img[..., 0] - mov_img[..., 0] im3 = ax3.imshow(dif) ax3.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale) ax3.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) plt.axis('off') ax0.set_title('Fixed img ($I_f$)', fontsize=C.FONT_SIZE) ax1.set_title('Moving img ($I_m$)', fontsize=C.FONT_SIZE) ax2.set_title('Displacement map', fontsize=C.FONT_SIZE) ax3.set_title('Fix - Mov', fontsize=C.FONT_SIZE) im0_cb = _set_colorbar(fig, ax0, im0, False) im1_cb = _set_colorbar(fig, ax1, im1, False) disp_cb = _set_colorbar(fig, ax2, im2, False) im3_cb = _set_colorbar(fig, ax3, im3, False) if filename is not None: plt.savefig(filename, format='png') # Call before show if not C.REMOTE: plt.show() else: plt.close() return fig def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coords, disp_map, mask, fix_img, mov_img, filename=None, fig=None): if fig is not None: fig.clear() plt.figure(fig.number) else: fig = plt.figure() ax_grid = fig.add_subplot(231) ax_grid.set_title('Grids', fontsize=C.FONT_SIZE) ax_grid.scatter(ctrl_coords[:, 0], ctrl_coords[:, 1], marker='+', c='r', s=20) ax_grid.scatter(dense_coords[:, 0], dense_coords[:, 1], marker='.', c='r', s=1) ax_grid.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_grid.scatter(target_coords[:, 0], target_coords[:, 1], marker='+', c='b', s=20) ax_grid.scatter(disp_coords[:, 0], disp_coords[:, 1], marker='.', c='b', s=1) ax_grid.set_aspect('equal') ax_disp = fig.add_subplot(232) ax_disp.set_title('Displacement map', fontsize=C.FONT_SIZE) cx, cy, dx, dy, s = _prepare_quiver_map(disp_map) ax_disp.imshow(s, interpolation='none', aspect='equal') ax_disp.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_mask = fig.add_subplot(233) ax_mask.set_title('Mask', fontsize=C.FONT_SIZE) ax_mask.imshow(mask) ax_mask.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_fix = fig.add_subplot(234) ax_fix.set_title('Fix image', fontsize=C.FONT_SIZE) ax_fix.imshow(fix_img[..., 0]) ax_fix.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_mov = fig.add_subplot(235) ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE) ax_mov.imshow(mov_img[..., 0]) ax_mov.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_dif = fig.add_subplot(236) ax_dif.set_title('Fix - Moving image', fontsize=C.FONT_SIZE) ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin) ax_dif.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) legend_elems = [Line2D([0], [0], color=cmap_bin(0), lw=2), Line2D([0], [0], color=cmap_bin(2), lw=2)] ax_dif.legend(legend_elems, ['Mov', 'Fix'], loc='upper left', bbox_to_anchor=(0., 0., 1., 0.), ncol=2, mode='expand') if filename is not None: plt.savefig(filename, format='png') # Call before show if not C.REMOTE: plt.show() return fig def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=None): if fig is not None: fig.clear() plt.figure(fig.number) else: fig = plt.figure() ax_d_m_f = fig.add_subplot(131) ax_d_m_f.set_title('Disp M->F', fontsize=C.FONT_SIZE) cx, cy, dx, dy, s = _prepare_quiver_map(disp_m_f) ax_d_m_f.imshow(s, interpolation='none', aspect='equal') ax_d_m_f.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale) ax_d_m_f.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_d_f_m = fig.add_subplot(132) ax_d_f_m.set_title('Disp F->M', fontsize=C.FONT_SIZE) cx, cy, dx, dy, s = _prepare_quiver_map(disp_f_m) ax_d_f_m.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale) ax_d_f_m.imshow(s, interpolation='none', aspect='equal') ax_d_f_m.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) ax_dif = fig.add_subplot(133) ax_dif.set_title('Fix - Moving image', fontsize=C.FONT_SIZE) ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin) ax_dif.tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False) legend_elems = [Line2D([0], [0], color=cmap_bin(0), lw=2), Line2D([0], [0], color=cmap_bin(2), lw=2)] ax_dif.legend(legend_elems, ['Mov', 'Fix'], loc='upper left', bbox_to_anchor=(0., 0., 1., 0.), ncol=2, mode='expand') if filename is not None: plt.savefig(filename, format='png') # Call before show if not C.REMOTE: plt.show() else: plt.close() return fig def plot_train_step(list_imgs: [np.ndarray], fig_title='TRAINING', dest_folder='.', save_file=True): # list_imgs[0]: fix image # list_imgs[1]: moving image # list_imgs[2]: prediction scale 1 # list_imgs[3]: prediction scale 2 # list_imgs[4]: prediction scale 3 # list_imgs[5]: disp map scale 1 # list_imgs[6]: disp map scale 2 # list_imgs[7]: disp map scale 3 num_imgs = len(list_imgs) num_preds = (num_imgs - 2) // 2 num_cols = num_preds + 1 # 3D # TRAINING fig = plt.figure(figsize=(12.8, 10.24)) fig.tight_layout(pad=5.0) ax = fig.add_subplot(2, num_cols, 1, projection='3d') ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed') ax.set_title('Fix image', fontsize=C.FONT_SIZE) _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax) for i in range(2, num_preds+2): ax = fig.add_subplot(2, num_cols, i, projection='3d') ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed') ax.voxels(list_imgs[i][0, ..., 0] > 0.0, facecolors='green', edgecolors='green', label='Pred_{}'.format(i - 1)) ax.set_title('Pred. #{}'.format(i - 1)) _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax) ax = fig.add_subplot(2, num_cols, num_preds+2, projection='3d') ax.voxels(list_imgs[1][0, ..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving') ax.set_title('Fix image', fontsize=C.FONT_SIZE) _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax) for i in range(num_preds+2, 2 * num_preds + 2): ax = fig.add_subplot(2, num_cols, i + 1, projection='3d') c, d, s = _prepare_quiver_map(list_imgs[i][0, ...], dim=3) ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP], norm=True) ax.set_title('Disp. #{}'.format(i - 5)) _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax) fig.suptitle(fig_title, fontsize=C.FONT_SIZE) if save_file: plt.savefig(os.path.join(dest_folder, fig_title+'.png'), format='png') # Call before show if not C.REMOTE: plt.show() else: plt.close() return fig def _square_3d_plot(X, Y, Z, ax): max_range = np.array([X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min()]).max() / 2.0 mid_x = (X.max() + X.min()) * 0.5 mid_y = (Y.max() + Y.min()) * 0.5 mid_z = (Z.max() + Z.min()) * 0.5 ax.set_xlim(mid_x - max_range, mid_x + max_range) ax.set_ylim(mid_y - max_range, mid_y + max_range) ax.set_zlim(mid_z - max_range, mid_z + max_range) def remove_tick_labels(ax, project_3d=False): ax.set_xticklabels([]) ax.set_yticklabels([]) if project_3d: ax.set_zticklabels([]) return ax def add_axes_arrows_3d(ax, arrow_length=10, xyz_colours=['r', 'g', 'b'], xyz_label=['X', 'Y', 'Z'], dist_arrow_text=3): x_limits = ax.get_xlim3d() y_limits = ax.get_ylim3d() z_limits = ax.get_zlim3d() x_len = x_limits[1] - x_limits[0] y_len = y_limits[1] - y_limits[0] z_len = z_limits[1] - z_limits[0] ax.quiver(x_limits[0], y_limits[0], z_limits[0], x_len, 0, 0, color=xyz_colours[0], arrow_length_ratio=0) # (init_loc, end_loc, params) ax.quiver(x_limits[0], y_limits[0], z_limits[0], 0, y_len, 0, color=xyz_colours[1], arrow_length_ratio=0) ax.quiver(x_limits[0], y_limits[0], z_limits[0], 0, 0, z_len, color=xyz_colours[2], arrow_length_ratio=0) # X axis ax.quiver(x_limits[1], y_limits[0], z_limits[0], arrow_length, 0, 0, color=xyz_colours[0]) ax.text(x_limits[1] + arrow_length + dist_arrow_text, y_limits[0], z_limits[0], xyz_label[0], fontsize=20, ha='right', va='top') # Y axis ax.quiver(x_limits[0], y_limits[1], z_limits[0], 0, arrow_length, 0, color=xyz_colours[1]) ax.text(x_limits[0], y_limits[1] + arrow_length + dist_arrow_text, z_limits[0], xyz_label[0], fontsize=20, ha='left', va='top') # Z axis ax.quiver(x_limits[0], y_limits[0], z_limits[1], 0, 0, arrow_length, color=xyz_colours[2]) ax.text(x_limits[0], y_limits[0], z_limits[1] + arrow_length + dist_arrow_text, xyz_label[0], fontsize=20, ha='center', va='bottom') return ax def add_axes_arrows_2d(ax, arrow_length=10, xy_colour=['r', 'g'], xy_label=['X', 'Y']): x_limits = list(ax.get_xlim()) y_limits = list(ax.get_ylim()) origin = (x_limits[0], y_limits[1]) ax.annotate('', xy=(origin[0] + arrow_length, origin[1]), xytext=origin, arrowprops=dict(headlength=8., headwidth=10., width=5., color=xy_colour[0])) ax.annotate('', xy=(origin[0], origin[1] + arrow_length), xytext=origin, arrowprops=dict(headlength=8., headwidth=10., width=5., color=xy_colour[0])) ax.text(origin[0] + arrow_length, origin[1], xy_label[0], fontsize=25, ha='left', va='bottom') ax.text(origin[0] - 1, origin[1] + arrow_length, xy_label[1], fontsize=25, ha='right', va='top') return ax def set_axes_size(w,h, ax=None): """ w, h: width, height in inches """ if not ax: ax=plt.gca() l = ax.figure.subplotpars.left r = ax.figure.subplotpars.right t = ax.figure.subplotpars.top b = ax.figure.subplotpars.bottom figw = float(w)/(r-l) figh = float(h)/(t-b) ax.figure.set_size_inches(figw, figh)