# Spaces: Sleeping (status banner captured from the hosting web page; not part of the program)
| import warnings | |
| import iml | |
| import numpy as np | |
| from iml import Instance, Model | |
| from iml.datatypes import DenseData | |
| from iml.explanations import AdditiveExplanation | |
| from iml.links import IdentityLink | |
| from scipy.stats import gaussian_kde | |
| import matplotlib | |
# RGB ramps shared by both colormaps: #1E88E5 (blue) at 0.0 -> #F52757 (red) at 1.0.
# Each entry is (position, value-below, value-above) as expected by LinearSegmentedColormap.
_RED_BLUE_CHANNELS = {
    'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
            (1.0, 0.9607843137254902, 0.9607843137254902)),
    'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
              (1.0, 0.15294117647058825, 0.15294117647058825)),
    'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
             (1.0, 0.3411764705882353, 0.3411764705882353)),
}

# The two variants differ only in their mid-point opacity (0.3 vs. fully opaque).
_red_blue_translucent = dict(_RED_BLUE_CHANNELS, alpha=((0.0, 1, 1), (0.5, 0.3, 0.3), (1.0, 1, 1)))
# `cdict1` stays a module-level name for backward compatibility; as before it ends up
# holding the fully opaque segment data.
cdict1 = dict(_RED_BLUE_CHANNELS, alpha=((0.0, 1, 1), (0.5, 1, 1), (1.0, 1, 1)))

try:
    import matplotlib.pyplot as pl
    from matplotlib.colors import LinearSegmentedColormap
    from matplotlib.ticker import MaxNLocator

    red_blue = LinearSegmentedColormap('RedBlue', _red_blue_translucent)
    # distinct name so the two colormaps can be told apart
    red_blue_solid = LinearSegmentedColormap('RedBlueSolid', cdict1)
except ImportError:
    # Importing this module without matplotlib is still allowed, but say so instead of
    # failing later with an unexplained NameError on `pl` / `red_blue`.
    warnings.warn("matplotlib could not be loaded; the plotting functions in this module will not work.")

# Text templates for axis labels, titles and color-bar captions. Entries containing
# `%s` are filled in with feature names via the `%` operator.
labels = {
    'MAIN_EFFECT': "SHAP main effect value for\n%s",
    'INTERACTION_VALUE': "SHAP interaction value",
    'INTERACTION_EFFECT': "SHAP interaction value for\n%s and %s",
    'VALUE': "SHAP value (impact on model output)",
    'VALUE_FOR': "SHAP value for\n%s",
    'PLOT_FOR': "SHAP plot for %s",
    'FEATURE': "Feature %s",
    'FEATURE_VALUE': "Feature value",
    'FEATURE_VALUE_LOW': "Low",
    'FEATURE_VALUE_HIGH': "High",
    'JOINT_VALUE': "Joint SHAP value"
}
def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
                      color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
                      color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
    """Create a SHAP summary plot, colored by feature values when they are provided.

    A new matplotlib figure is created for every call and returned to the caller.

    Parameters
    ----------
    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features). The last column is left out of
        the plot. A 3-D array is treated as SHAP interaction values and drawn as one
        panel per top-ranked feature.
    features : numpy.array or pandas.DataFrame or list
        Matrix of feature values (# samples x # features) or a feature_names list as shorthand
    feature_names : list
        Names of the features (length # features)
    max_display : int
        How many top features to include in the plot (default is 20, or 7 for interaction plots)
    plot_type : "dot" (default), "violin" or "layered_violin"
        What type of summary plot to produce
    color : str
        Point/violin color, or a colormap name for "layered_violin". Defaults to
        "coolwarm" for "layered_violin" and "#ff0052" otherwise.
    axis_color : str
        Color of the axis ticks and tick labels.
    title : str
        Accepted for interface compatibility; not used by this function.
    alpha : float
        Opacity of the plotted points / violins.
    show : bool
        Accepted for interface compatibility; the figure is returned, never shown.
    sort : bool
        Order the features by the sum of their absolute SHAP values.
    color_bar : bool
        Draw the Low/High feature-value color bar when feature values are available.
    auto_size_plot : bool
        Resize the figure to 8 inches wide and 0.4 inches per displayed feature.
    layered_violin_max_num_bins : int
        Maximum number of feature-value layers per row for "layered_violin".

    Returns
    -------
    matplotlib.figure.Figure
        The figure the plot was drawn on.
    """
    assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."

    # default color:
    if color is None:
        color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"

    # convert from a DataFrame or other types (the type is compared by name so that
    # pandas does not have to be imported here)
    if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    elif isinstance(features, list):
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None

    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]

    # Resolve max_display BEFORE sizing the figure: it defaults to None and the figure
    # size is computed from it.
    interaction_mode = len(shap_values.shape) == 3
    if max_display is None:
        max_display = 7 if interaction_mode else 20
    if interaction_mode:
        # one panel per feature, so never ask for more panels than there are features
        max_display = min(len(feature_names), max_display)

    mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))

    if interaction_mode:
        _draw_interaction_summary(shap_values, features, feature_names, max_display)
    else:
        _draw_summary_axes(
            shap_values, features, feature_names, max_display=max_display, plot_type=plot_type,
            color=color, axis_color=axis_color, alpha=alpha, sort=sort, color_bar=color_bar,
            auto_size_plot=auto_size_plot, layered_violin_max_num_bins=layered_violin_max_num_bins
        )
    return mpl_fig


def _draw_interaction_summary(shap_values, features, feature_names, max_display):
    """Draw a row of dot summary panels, one per top-ranked feature, from a 3-D interaction array."""
    # rank features by their summed absolute values; the last row/column is left out
    sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))
    # ranked feature columns followed by the last column
    column_inds = np.hstack((sort_inds, len(sort_inds)))
    sorted_features = None if features is None else features[:, sort_inds]
    # element-wise lookup so plain lists work as well as arrays / pandas indexes
    sorted_names = [feature_names[j] for j in sort_inds]
    blank_names = [""] * len(sort_inds)

    # NOTE(review): the x-limits are fixed at +/-0.2 for every panel and do not depend on
    # the data -- confirm this is intended for the values being plotted.
    slow, shigh = -0.2, 0.2
    title_length_limit = 11

    for i in range(max_display):
        ind = sort_inds[i]
        pl.subplot(1, max_display, i + 1)
        # fancy indexing returns a copy, so the caller's array is not modified below
        proj_shap_values = shap_values[:, ind, column_inds]
        proj_shap_values *= 2
        proj_shap_values[:, i] /= 2  # because only off diag effects are split in half
        _draw_summary_axes(
            proj_shap_values, sorted_features,
            sorted_names if i == 0 else blank_names,  # only the first panel shows row labels
            max_display=max_display, plot_type="dot", color="#ff0052", axis_color="#333333",
            alpha=1, sort=False, color_bar=False, auto_size_plot=False,
            layered_violin_max_num_bins=20
        )
        pl.xlim((slow, shigh))
        # a single shared x-label, placed under the middle panel (never the first one)
        pl.xlabel(labels['INTERACTION_VALUE'] if i > 0 and i == max_display // 2 else "")
        # `shorten_text` is not defined in this block; it is expected to be provided
        # elsewhere in this module.
        pl.title(shorten_text(feature_names[ind], title_length_limit))
    pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
    pl.subplots_adjust(hspace=0, wspace=0.1)


def _summary_color_limits(values):
    """Return (vmin, vmax) for coloring: the 5-95 percentile range, widened when it collapses."""
    vmin = np.nanpercentile(values, 5)
    vmax = np.nanpercentile(values, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
        if vmin == vmax:
            vmin = np.min(values)
            vmax = np.max(values)
    return vmin, vmax


def _summary_dot_rows(shap_values, features, feature_order, row_height, color, alpha):
    """Draw one jittered row of points per feature in `feature_order` on the current axes."""
    if features is not None:
        assert features.shape[0] == shap_values.shape[0], \
            "Feature and SHAP matrices must have the same number of rows!"

    for pos, i in enumerate(feature_order):
        pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
        shaps = shap_values[:, i]
        values = None if features is None else features[:, i]
        # shuffle so that the drawing order does not depend on the sample order
        inds = np.arange(len(shaps))
        np.random.shuffle(inds)
        if values is not None:
            values = values[inds]
        shaps = shaps[inds]
        colored_feature = True
        try:
            values = np.array(values, dtype=np.float64)  # make sure this can be numeric
        except (TypeError, ValueError):
            colored_feature = False

        # Bucket the SHAP values into 100 bins and stack the points of each bin
        # alternately above and below the row's center line.
        N = len(shaps)
        nbins = 100
        quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
        inds = np.argsort(quant + np.random.randn(N) * 1e-6)  # tiny noise breaks ties randomly
        layer = 0
        last_bin = -1
        ys = np.zeros(N)
        for ind in inds:
            if quant[ind] != last_bin:
                layer = 0
            ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
            layer += 1
            last_bin = quant[ind]
        # scale the stack so the tallest one stays within the row
        ys *= 0.9 * (row_height / np.max(ys + 1))

        if features is not None and colored_feature:
            # trim the color range, but prevent the color range from collapsing
            vmin, vmax = _summary_color_limits(values)
            nan_mask = np.isnan(values)
            valid = np.invert(nan_mask)
            # samples with a missing feature value are drawn in grey
            pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin,
                       vmax=vmax, s=16, alpha=alpha, linewidth=0,
                       zorder=3, rasterized=len(shaps) > 500)
            pl.scatter(shaps[valid], pos + ys[valid],
                       cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
                       c=values[valid], alpha=alpha, linewidth=0,
                       zorder=3, rasterized=len(shaps) > 500)
        else:
            pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
                       color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)


def _summary_violin_rows(shap_values, features, feature_names, feature_order, color, alpha):
    """Draw one violin per feature, colored by smoothed feature value when features are given."""
    for pos, i in enumerate(feature_order):
        pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

    if features is None:
        parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
                              widths=0.7,
                              showmeans=False, showextrema=False, showmedians=False)
        for pc in parts['bodies']:
            pc.set_facecolor(color)
            pc.set_edgecolor('none')
            pc.set_alpha(alpha)
        return

    global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
    global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
    for pos, i in enumerate(feature_order):
        shaps = shap_values[:, i]
        shap_min, shap_max = np.min(shaps), np.max(shaps)
        rng = shap_max - shap_min
        xs = np.linspace(shap_min - rng * 0.2, shap_max + rng * 0.2, 100)
        if np.std(shaps) < (global_high - global_low) / 100:
            # almost constant values: add noise before the density estimate
            ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
        else:
            ds = gaussian_kde(shaps)(xs)
        ds /= np.max(ds) * 3
        values = features[:, i]

        # Running mean of the feature value over the (at most 20) most recently passed
        # samples while sweeping xs from left to right.
        smooth_values = np.zeros(len(xs) - 1)
        sort_inds = np.argsort(shaps)
        trailing_pos = 0
        leading_pos = 0
        running_sum = 0
        back_fill = 0
        for j in range(len(xs) - 1):
            while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
                running_sum += values[sort_inds[leading_pos]]
                leading_pos += 1
                if leading_pos - trailing_pos > 20:
                    running_sum -= values[sort_inds[trailing_pos]]
                    trailing_pos += 1
            if leading_pos - trailing_pos > 0:
                smooth_values[j] = running_sum / (leading_pos - trailing_pos)
                # NOTE(review): back_fill is never reset, so this loop keeps overwriting
                # the preceding entries for every later j as well -- confirm intended.
                for k in range(back_fill):
                    smooth_values[j - k - 1] = smooth_values[j]
            else:
                back_fill += 1

        vmin, vmax = _summary_color_limits(values)
        pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
                   c=values, alpha=alpha, linewidth=0, zorder=1)

        # rescale the smoothed values for the colormap lookup
        smooth_values -= vmin
        if vmax - vmin > 0:
            smooth_values /= vmax - vmin
        for seg in range(len(xs) - 1):
            if ds[seg] > 0.05 or ds[seg + 1] > 0.05:
                pl.fill_between([xs[seg], xs[seg + 1]], [pos + ds[seg], pos + ds[seg + 1]],
                                [pos - ds[seg], pos - ds[seg + 1]], color=red_blue_solid(smooth_values[seg]),
                                zorder=2)


def _summary_layered_violin_rows(shap_values, features, feature_names, feature_order, color,
                                 layered_violin_max_num_bins):
    """Draw one violin per feature, built from stacked density layers of feature-value bins."""
    # courtesy of @kodonnell
    if features is None:
        raise ValueError("layered_violin summary plots need the matrix of feature values.")

    num_x_points = 200
    # the indices of the feature data corresponding to each bin
    bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype('int')
    shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
    x_points = np.linspace(shap_min, shap_max, num_x_points)
    use_cmap = color in pl.cm.datad  # if color is a cmap, use it, otherwise use a color

    # loop through each feature and plot:
    for pos, ind in enumerate(feature_order):
        # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value,
        # otherwise use bins/percentiles. To keep simpler code, in the case of uniques, we just
        # adjust the bins to align with the unique counts.
        feature = features[:, ind]
        unique, counts = np.unique(feature, return_counts=True)
        if unique.shape[0] <= layered_violin_max_num_bins:
            order = np.argsort(unique)
            thesebins = np.cumsum(counts[order])
            thesebins = np.insert(thesebins, 0, 0)
        else:
            thesebins = bins
        nbins = thesebins.shape[0] - 1
        # order the feature data so we can apply percentiling
        order = np.argsort(feature)
        # calculate kdes:
        ys = np.zeros((nbins, num_x_points))
        for i in range(nbins):
            # get shap values in this bin:
            shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
            # a density cannot be estimated from a single element
            if shaps.shape[0] == 1:
                warnings.warn(
                    "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
                    % (i, feature_names[ind]))
                # reuse the previous bin's curve; ys is already 0, so there is nothing
                # to do if i == 0
                if i > 0:
                    ys[i, :] = ys[i - 1, :]
                continue
            # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
            ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
            # scale it up so that the 'size' of each y represents the size of the bin. For continuous
            # data this will do nothing, but when we've gone with the unique option, this will
            # matter - e.g. if 99% are male and 1% female, we want the 1% to appear a lot smaller.
            size = thesebins[i + 1] - thesebins[i]
            bin_size_if_even = features.shape[0] / nbins
            relative_bin_size = size / bin_size_if_even
            ys[i, :] *= relative_bin_size
        # now plot 'em. We don't plot the individual strips, as this can leave whitespace between
        # them. Instead, we plot the full kde, then remove outer strip and plot over it, etc., to
        # ensure no whitespace
        ys = np.cumsum(ys, axis=0)
        width = 0.8
        scale = ys.max() * 2 / width  # 2 is here as we plot both sides of x axis
        for i in range(nbins - 1, -1, -1):
            y = ys[i, :] / scale
            # max(..., 1) avoids a division by zero when the feature has a single bin
            c = pl.get_cmap(color)(i / max(nbins - 1, 1)) if use_cmap else color
            pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
    pl.xlim(shap_min, shap_max)


def _draw_summary_axes(shap_values, features, feature_names, max_display, plot_type, color,
                       axis_color, alpha, sort, color_bar, auto_size_plot,
                       layered_violin_max_num_bins):
    """Draw one summary plot for a 2-D `shap_values` matrix onto the current matplotlib axes."""
    if sort:
        # order features by the sum of their effect magnitudes (the last column is left out)
        feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
        feature_order = feature_order[-min(max_display, len(feature_order)):]
    else:
        feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)

    row_height = 0.4
    if auto_size_plot:
        pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
    pl.axvline(x=0, color="#999999", zorder=-1)

    if plot_type == "dot":
        _summary_dot_rows(shap_values, features, feature_order, row_height, color, alpha)
    elif plot_type == "violin":
        _summary_violin_rows(shap_values, features, feature_names, feature_order, color, alpha)
    elif plot_type == "layered_violin":
        _summary_layered_violin_rows(shap_values, features, feature_names, feature_order, color,
                                     layered_violin_max_num_bins)

    # draw the color bar
    if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
        import matplotlib.cm as cm
        m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
        m.set_array([0, 1])
        # the mappable is not attached to any axes, so name the axes to take space from
        cb = pl.colorbar(m, ticks=[0, 1], aspect=1000, ax=pl.gca())
        cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
        cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.9) * 20)

    # axis cosmetics: only the bottom spine stays visible, feature names label the rows
    ax = pl.gca()
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('none')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color)
    pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
    ax.tick_params('y', length=20, width=0.5, which='major')
    ax.tick_params('x', labelsize=11)
    pl.ylim(-1, len(feature_order))
    pl.xlabel(labels['VALUE'], fontsize=13)
    pl.tight_layout()
def approx_interactions(index, shap_values, X):
    """Order other features by how much interaction they seem to have with the feature at the given index.

    This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction
    index values for SHAP see the interaction_contribs option implemented in XGBoost.

    Parameters
    ----------
    index : int
        Column of the feature whose interactions are estimated.
    shap_values : numpy.array
        Matrix of SHAP values; only column `index` is read.
    X : numpy.array
        Matrix of feature values (# samples x # features); must be convertible to float.

    Returns
    -------
    numpy.array
        Column indices of `X`, ordered from the highest to the lowest interaction score.
        The feature at `index` itself and all-zero columns get a score of 0.
    """
    # work on a random subsample of at most 10000 rows
    if X.shape[0] > 10000:
        a = np.arange(X.shape[0])
        np.random.shuffle(a)
        inds = a[:10000]
    else:
        inds = np.arange(X.shape[0])

    # sort the samples by the value of the reference feature
    x = X[inds, index]
    srt = np.argsort(x)
    shap_ref = shap_values[inds, index][srt]

    # window length: a tenth of the samples, at least 1 and at most 50
    inc = max(min(int(len(x) / 10.0), 50), 1)
    interactions = []
    for i in range(X.shape[1]):
        # `np.float` was removed from NumPy; np.float64 is the same dtype
        val_other = X[inds, i][srt].astype(np.float64)
        v = 0.0
        if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
            # sum the absolute correlation between the reference SHAP values and the other
            # feature inside each window; windows where either side is constant are skipped
            for j in range(0, len(x), inc):
                shap_window = shap_ref[j:j + inc]
                other_window = val_other[j:j + inc]
                if np.std(other_window) > 0 and np.std(shap_window) > 0:
                    v += abs(np.corrcoef(shap_window, other_window)[0, 1])
        interactions.append(v)

    return np.argsort(-np.abs(interactions))
def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
                         interaction_index="auto", color="#1E88E5", axis_color="#333333",
                         dot_size=16, alpha=1, title=None, show=True):
    """
    Create a SHAP dependence plot, colored by an interaction feature.

    The plot is drawn on the current matplotlib figure.

    Parameters
    ----------
    ind : int, str, or pair of those
        Index (or name) of the feature to plot. When `shap_values` is a 3-D array of
        interaction values, pass a pair to plot the interaction between two features.
    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features)
    features : numpy.array or pandas.DataFrame
        Matrix of feature values (# samples x # features)
    feature_names : list
        Names of the features (length # features)
    display_features : numpy.array or pandas.DataFrame
        Matrix of feature values for visual display (such as strings instead of coded values)
    interaction_index : "auto", None, or int
        The index of the feature used to color the plot. "auto" picks the first feature
        returned by `approx_interactions`; None disables the coloring.
    color : str
        Color of the points when no interaction feature is used.
    axis_color : str
        Color of the axis labels, ticks, spines and title.
    dot_size : int
        Marker size of the points.
    alpha : float
        Opacity of the points.
    title : str
        Optional plot title.
    show : bool
        Accepted for interface compatibility; the figure is returned, never shown.

    Returns
    -------
    (matplotlib.figure.Figure, int or None)
        The figure that was drawn on and the interaction index that was used for coloring.
    """
    # convert from DataFrames if we got any (the type is compared by name so that pandas
    # does not have to be imported here)
    if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = display_features.columns
        display_features = display_features.values
    elif display_features is None:
        display_features = features

    # allow vectors to be passed (the target shape has to be given as a tuple)
    if len(shap_values.shape) == 1:
        shap_values = np.reshape(shap_values, (len(shap_values), 1))
    if len(features.shape) == 1:
        features = np.reshape(features, (len(features), 1))
    if len(display_features.shape) == 1:
        display_features = np.reshape(display_features, (len(display_features), 1))

    # allow a single feature name to be passed alone
    if isinstance(feature_names, str):
        feature_names = [feature_names]
    if feature_names is None:
        # one generated name per feature column
        feature_names = [labels['FEATURE'] % str(i) for i in range(features.shape[1])]

    def convert_name(ind):
        """Translate a feature name into its column index; anything else is passed through."""
        if isinstance(ind, str):
            nzinds = np.where(np.asarray(feature_names) == ind)[0]
            if len(nzinds) == 0:
                print("Could not find feature named: " + ind)
                return None
            return nzinds[0]
        return ind

    ind = convert_name(ind)
    if ind is None:
        raise ValueError("The feature to plot could not be resolved to a column index.")

    mpl_fig = pl.gcf()

    # plotting SHAP interaction values
    if len(shap_values.shape) == 3 and len(ind) == 2:
        ind1 = convert_name(ind[0])
        ind2 = convert_name(ind[1])
        if ind1 == ind2:
            proj_shap_values = shap_values[:, ind2, :]
        else:
            proj_shap_values = shap_values[:, ind2, :] * 2  # off-diag values are split in half

        # TODO: remove recursion; generally the functions should be shorter for more maintainable code
        result = shap_dependence_plot(
            ind1, proj_shap_values, features, feature_names=feature_names,
            interaction_index=ind2, display_features=display_features,
            color=color, axis_color=axis_color, dot_size=dot_size, alpha=alpha,
            title=title, show=False
        )
        # replace the generic y-label set by the nested call with the interaction-specific one
        if ind1 == ind2:
            pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1])
        else:
            pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2]))
        return result

    # get both the raw and display feature values
    xv = features[:, ind]
    xd = display_features[:, ind]
    s = shap_values[:, ind]
    string_x = isinstance(xd[0], str)
    if string_x:
        # display label -> raw value, used to place the x ticks
        name_map = dict(zip(xd, xv))
        xnames = list(name_map.keys())

    name = feature_names[ind]

    # guess what other feature has the strongest interaction with the plotted feature
    if isinstance(interaction_index, str) and interaction_index == "auto":
        interaction_index = approx_interactions(ind, shap_values, features)[0]
    interaction_index = convert_name(interaction_index)
    categorical_interaction = False

    # get both the raw and display color values
    string_color = False
    if interaction_index is not None:
        cv = features[:, interaction_index]
        cd = display_features[:, interaction_index]
        cv_float = cv.astype(np.float64)
        clow = np.nanpercentile(cv_float, 5)
        chigh = np.nanpercentile(cv_float, 95)
        string_color = isinstance(cd[0], str)
        if string_color:
            cname_map = dict(zip(cd, cv))
            cnames = list(cname_map.keys())
            categorical_interaction = True
        elif clow % 1 == 0 and chigh % 1 == 0 and len(set(cv)) < 50:
            # whole-number percentiles and few distinct values: treat as categorical
            categorical_interaction = True

    # discretize colors for categorical features
    color_norm = None
    if categorical_interaction and clow != chigh:
        # np.linspace needs an integer number of samples
        bounds = np.linspace(clow, chigh, int(chigh - clow + 2))
        color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)

    # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
    if interaction_index is not None:
        # matplotlib rejects `norm` combined with vmin/vmax; the norm already spans
        # clow..chigh, so pass exactly one of the two
        if color_norm is not None:
            color_range = {"norm": color_norm}
        else:
            color_range = {"vmin": clow, "vmax": chigh}
        pl.scatter(xv, s, s=dot_size, linewidth=0, c=cv, cmap=red_blue,
                   alpha=alpha, rasterized=len(xv) > 500, **color_range)
    else:
        pl.scatter(xv, s, s=dot_size, linewidth=0, color=color,
                   alpha=alpha, rasterized=len(xv) > 500)

    if interaction_index != ind and interaction_index is not None:
        # draw the color bar
        if string_color:
            tick_positions = [cname_map[n] for n in cnames]
            if len(tick_positions) == 2:
                tick_positions[0] -= 0.25
                tick_positions[1] += 0.25
            cb = pl.colorbar(ticks=tick_positions)
            cb.set_ticklabels(cnames)
        else:
            cb = pl.colorbar()

        cb.set_label(feature_names[interaction_index], size=13)
        cb.ax.tick_params(labelsize=11)
        if categorical_interaction:
            cb.ax.tick_params(length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.7) * 20)

    # make the plot more readable (wider when the color bar column is present)
    if interaction_index != ind:
        pl.gcf().set_size_inches(7.5, 5)
    else:
        pl.gcf().set_size_inches(6, 5)
    pl.xlabel(name, color=axis_color, fontsize=13)
    pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
    if title is not None:
        pl.title(title, color=axis_color, fontsize=13)

    ax = pl.gca()
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
    for spine in ax.spines.values():
        spine.set_edgecolor(axis_color)
    if string_x:
        pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)

    return mpl_fig, interaction_index