import functools

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import pandas as pd
from matplotlib import cm
from scipy import special

model_path = 'chlab/planet_detection_models/'

# plotting parameters
labels = 20
ticks = 14
legends = 14
text = 14
titles = 22
lw = 3
ps = 200
cmap = 'magma'

# ONNX Runtime reports input element types as strings such as 'tensor(float)';
# map the common ones to the numpy dtype the session will accept.
_ORT_INPUT_DTYPES = {
    'tensor(float)': np.float32,
    'tensor(double)': np.float64,
    'tensor(float16)': np.float16,
}


def normalize_array(x: list):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    A constant array (zero range) is returned as all zeros instead of the
    NaNs a 0/0 division would produce.
    """
    x = np.asarray(x)
    shifted = x - np.min(x)
    span = np.max(shifted)
    if span == 0:
        return np.zeros(shifted.shape)
    return shifted / span


@functools.lru_cache(maxsize=8)
def load_model(model: str, activation: bool = True):
    """Create (and cache) an ONNX Runtime session for ``model``.

    The file loaded is ``model_path + model + '.onnx'``, with
    ``'_w_activation'`` appended to the name when ``activation`` is true.
    Sessions are cached so repeated requests for the same model do not
    re-read the file from disk.
    """
    if activation:
        model += '_w_activation'
    return ort.InferenceSession(model_path + f'{model}.onnx')


def get_activations(intermediate_model, image: list, layer=None, vmax=2.5, sub_mean=True):
    """Run ``image`` through the model and return predictions plus two activation maps.

    Args:
        intermediate_model: onnxruntime session with at least three outputs.
            Output 0 is passed through a softmax (presumably class logits --
            confirm against the export code). Outputs 1 and 2 are 4-D feature
            maps indexed here as ``[0, layer, :, :]``.
        image: model input; it is cast to the dtype the session declares for
            its first input when that dtype is recognized.
        layer: channel of the feature maps to return; ``None`` sums over all
            channels of the first batch element.
        vmax: unused; kept only for backward compatibility of the signature.
        sub_mean: if true, each map is replaced by ``|map - mean(map)|``.

    Returns:
        ``(probabilities, activation_1, activation_2)`` where
        ``probabilities`` is the softmax of output 0 along its last axis and
        the activations are 2-D arrays.
    """
    model_input = intermediate_model.get_inputs()[0]
    expected_dtype = _ORT_INPUT_DTYPES.get(model_input.type)
    if expected_dtype is not None:
        image = np.asarray(image, dtype=expected_dtype)

    outputs = intermediate_model.run(None, {model_input.name: image})
    # The original code computed this softmax but returned the raw output.
    probabilities = special.softmax(outputs[0], axis=-1)

    maps = []
    for feature_map in (outputs[1], outputs[2]):
        first = feature_map[0]
        activation = np.sum(first, axis=0) if layer is None else first[layer]
        if sub_mean:
            # Not in-place: ``first[layer]`` is a view into the ORT output.
            activation = np.abs(activation - np.mean(activation))
        maps.append(activation)
    activation_1, activation_2 = maps

    return probabilities, activation_1, activation_2


def _planet_confidence(probabilities):
    """Return the planet-class probability from the softmax output.

    NOTE(review): assumes a binary classifier whose class index 1 means
    "planet"; a single-valued output falls back to that value. Confirm
    against the training code.
    """
    flat = np.ravel(probabilities)
    return float(flat[1] if flat.size > 1 else flat[0])


def predict_and_analyze(model_name, num_channels, dim, image):
    """Classify one uploaded data cube and return its activation maps.

    The uploaded file is read with ``np.load`` and must hold exactly
    ``num_channels * dim * dim`` values. It is reshaped to
    ``(1, num_channels, dim, dim)`` before inference, so an array saved with
    ``np.save`` in shape ``(C, W, W)`` or ``(C, W*W)`` is accepted.

    Args:
        model_name: base model name, suffixed with ``_<num_channels>``.
        num_channels: number of velocity channels (string or int).
        dim: spatial width/height of each channel (string or int).
        image: value of the ``gr.File`` input -- a path, or an object whose
            ``.name`` is a path (which one depends on the Gradio version).

    Returns:
        ``(prediction_text, input_image, activation_1, activation_2)`` with
        the three images rescaled to [0, 1] so ``gr.Image`` can display them.

    Raises:
        ValueError: if no file was uploaded or its size does not match
            ``num_channels * dim * dim``.
    """
    if image is None:
        raise ValueError('No input file was uploaded.')
    num_channels = int(num_channels)
    W = int(dim)

    path = getattr(image, 'name', image)
    # SECURITY(review): allow_pickle=True lets a crafted upload execute
    # arbitrary code while unpickling; consider allow_pickle=False if inputs
    # are always plain numeric arrays.
    cube = np.load(path, allow_pickle=True)
    expected = num_channels * W * W
    if cube.size != expected:
        raise ValueError(
            f'Input holds {cube.size} values, but {num_channels} channels '
            f'of {W}x{W} need {expected}.'
        )
    batch = cube.reshape((1, num_channels, W, W))
    input_image = np.sum(batch[0], axis=0)

    model = load_model(f'{model_name}_{num_channels}', activation=True)
    probabilities, activation_1, activation_2 = get_activations(model, batch, sub_mean=True)

    output = 'Planet prediction with %f percent confidence' % (100 * _planet_confidence(probabilities))
    return (
        output,
        normalize_array(input_image),
        normalize_array(activation_1),
        normalize_array(activation_2),
    )


demo = gr.Interface(
    fn=predict_and_analyze,
    inputs=[
        gr.Dropdown(["regnet", "efficientnet"], value="efficientnet", label="Model Selection", show_label=True),
        gr.Dropdown(["45", "61", "75"], value="61", label="Number of Velocity Channels", show_label=True),
        gr.Dropdown(["600"], value="600", label="Image Dimensions", show_label=True),
        gr.File(label="Input Data", show_label=True),
    ],
    outputs=[
        gr.Textbox(lines=1, label="Prediction", show_label=True),
        gr.Image(label="Input Image", show_label=True),
        gr.Image(label="Activation 1", show_label=True),
        gr.Image(label="Activation 2", show_label=True),
    ],
    title="Kinematic Planet Detector",
)

demo.launch()