jpterry committed on
Commit
086820c
·
1 Parent(s): ce86ea7

initial commit

Files changed (2)
  1. app.py +192 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,192 @@
+ from matplotlib import cm
+ import matplotlib.pyplot as plt
+ # from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+ import numpy as np
+
+ # import onnx
+ import onnxruntime as ort
+ # from onnx import helper
+
+ import pandas as pd
+
+ from scipy import special
+
+ # import torch
+ # import torch.utils.data
+
+ import gradio as gr
+ # from transformers import pipeline
+
+
+ model_path = 'chlab/planet_detection_models/'
+
+ # plotting parameters
+ labels = 20
+ ticks = 14
+ legends = 14
+ text = 14
+ titles = 22
+ lw = 3
+ ps = 200
+ cmap = 'magma'
+
+
+ def normalize_array(x: list):
+     '''Rescales an array so its values run from 0 to 1'''
+     x = np.array(x)
+     return (x - np.min(x)) / np.max(x - np.min(x))
+
+
+ def load_model(model: str, activation: bool = True):
+     '''Loads an ONNX model, by default the variant exported with activation outputs'''
+     if activation:
+         model += '_w_activation'
+     ort_session = ort.InferenceSession(model_path + '%s.onnx' % (model))
+     return ort_session
+
+
+ def get_activations(intermediate_model, image: list,
+                     layer=None, vmax=2.5, sub_mean=True):
+     '''Gets the class probabilities and intermediate activations for an input image'''
+
+     input_name = intermediate_model.get_inputs()[0].name
+     outputs = intermediate_model.run(None, {input_name: image})
+
+     output_1 = outputs[1]
+     output_2 = outputs[2]
+
+     # softmax the raw network output into class probabilities
+     output = outputs[0]
+     output = special.softmax(output)
+
+     # origin = 'lower'
+
+     # plt.rcParams['xtick.labelsize'] = ticks
+     # plt.rcParams['ytick.labelsize'] = ticks
+
+     # fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(28, 8))
+     # ax1, ax2, ax3 = axs[0], axs[1], axs[2]
+
+     in_image = np.sum(image[0, :, :, :], axis=0)
+     in_image = normalize_array(in_image)
+
+     # im1 = ax1.imshow(in_image, cmap=cmap, vmin=0, vmax=vmax, origin=origin)
+
+     # either sum each activation map over its channels or pick a single layer
+     if layer is None:
+         activation_1 = np.sum(output_1[0, :, :, :], axis=0)
+         activation_2 = np.sum(output_2[0, :, :, :], axis=0)
+     else:
+         activation_1 = output_1[0, layer, :, :]
+         activation_2 = output_2[0, layer, :, :]
+
+     # optionally subtract the mean and keep the magnitude of the residual
+     if sub_mean:
+         activation_1 -= np.mean(activation_1)
+         activation_1 = np.abs(activation_1)
+
+         activation_2 -= np.mean(activation_2)
+         activation_2 = np.abs(activation_2)
+
+     # im2 = ax2.imshow(activation_1, cmap=cmap,  # vmin=0, vmax=1,
+     #                  origin=origin)
+     # im3 = ax3.imshow(activation_2, cmap=cmap,  # vmin=0, vmax=1,
+     #                  origin=origin)
+     # ims = [im1, im2, im3]
+
+     # for (i, ax) in enumerate(axs):
+     #     divider = make_axes_locatable(ax)
+     #     cax = divider.append_axes('right', size='5%', pad=0.05)
+     #     fig.colorbar(ims[i], cax=cax, orientation='vertical')
+
+     # ax1.set_title('Input', fontsize=titles)
+
+     # plt.show()
+
+     # return the softmaxed probabilities rather than the raw first output
+     return output, activation_1, activation_2
+
+
+ def predict_and_analyze(model_name, num_channels, dim, image):
+     '''Loads a model with activations, passes the image through, and returns the activations
+
+     The uploaded file must hold the raw bytes of a (C, W, W) float64 numpy
+     array, e.g. a cube saved with arr.astype(np.float64).tobytes(); it is
+     read back with np.frombuffer and reshaped to (C, W, W). A sketch of
+     preparing such a file follows this diff.
+     '''
+     num_channels = int(num_channels)
+     W = int(dim)
+
+     # read the uploaded bytes back into a (C, W, W) cube
+     # (np.frombuffer defaults to float64 when no dtype is given)
+     image = image.read()
+     image = np.frombuffer(image)
+     image = image.reshape((num_channels, W, W))
+
+     # cast to float32, which the exported ONNX models are assumed to expect
+     image = image.astype(np.float32)
+
+     # the network expects a batch dimension: (N, C, W, W)
+     if len(image.shape) != 4:
+         image = image[np.newaxis, :, :, :]
+
+     input_image = np.sum(image[0, :, :, :], axis=0)
+
+     # model files are stored per channel count, e.g. 'efficientnet_61'
+     model_name += '_%i' % (num_channels)
+
+     model = load_model(model_name, activation=True)
+
+     output, activation_1, activation_2 = get_activations(model, image, sub_mean=True)
+
+     # output holds softmaxed class probabilities; index 1 is taken to be
+     # the planet class
+     output = 'Planet prediction with %.1f percent confidence' % (100 * output[0, 1])
+
+     return output, input_image, activation_1, activation_2
+
+
+ demo = gr.Interface(
+     fn=predict_and_analyze,
+     inputs=[gr.Dropdown(["regnet", "efficientnet"],
+                         value="efficientnet",
+                         label="Model Selection",
+                         show_label=True),
+             gr.Dropdown(["45", "61", "75"],
+                         value="61",
+                         label="Number of Velocity Channels",
+                         show_label=True),
+             gr.Dropdown(["600"],
+                         value="600",
+                         label="Image Dimensions",
+                         show_label=True),
+             gr.File(label="Input Data", show_label=True)],
+     outputs=[gr.Textbox(lines=1, label="Prediction", show_label=True),
+              gr.Image(label="Input Image", show_label=True),
+              gr.Image(label="Activation 1", show_label=True),
+              gr.Image(label="Activation 2", show_label=True)],
+     title="Kinematic Planet Detector"
+ )
+ demo.launch(share=True)
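
For reference, predict_and_analyze recovers the uploaded cube with np.frombuffer, which defaults to float64, so the file fed to the "Input Data" field needs to be the raw bytes of a float64 array. A minimal sketch of preparing such an input, assuming an illustrative random (61, 600, 600) cube and the hypothetical filename input_cube.bin (neither is part of this commit):

import numpy as np

# illustrative (C, W, W) velocity cube: 61 channels of 600x600 pixels,
# matching the app's default dropdown values (hypothetical data)
cube = np.random.rand(61, 600, 600)

# write raw float64 bytes; app.py reads them back with np.frombuffer
with open('input_cube.bin', 'wb') as f:
    f.write(cube.astype(np.float64).tobytes())

Uploading input_cube.bin then lets app.py reconstruct the cube as np.frombuffer(data).reshape((61, 600, 600)).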
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ pandas
+ numpy
+ matplotlib
+ scipy
+ onnx
+ onnxruntime
+ gradio