jpterry commited on
Commit
3cf0e0f
·
1 Parent(s): 688e887

casting to float32

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -79,10 +79,10 @@ def get_activations(intermediate_model, image: list,
79
 
80
  def predict_and_analyze(model_name, num_channels, dim, image):
81
 
82
- '''Loads a model with activations, passes through image and shows activations
 
83
 
84
  The image must be a numpy array of shape (C, W, W) or (1, C, W, W)
85
-
86
  '''
87
 
88
  num_channels = int(num_channels)
@@ -90,6 +90,7 @@ def predict_and_analyze(model_name, num_channels, dim, image):
90
 
91
  print("Loading data")
92
  image = np.load(image.name, allow_pickle=True)
 
93
 
94
  if len(image.shape) != 4:
95
  image = image[np.newaxis, :, :, :]
@@ -97,7 +98,7 @@ def predict_and_analyze(model_name, num_channels, dim, image):
97
  assert image.shape == (1, num_channels, W, W), "Data is the wrong shape"
98
 
99
  model_name += '_%i' % (num_channels)
100
-
101
  print("Loading model")
102
  model = load_model(model_name, activation=True)
103
  print("Model loaded")
@@ -153,7 +154,7 @@ if __name__ == "__main__":
153
 
154
  demo = gr.Interface(
155
  fn=predict_and_analyze,
156
- inputs=[gr.Dropdown(["regnet", "efficientnet"],
157
  value="efficientnet",
158
  label="Model Selection",
159
  show_label=True),
 
79
 
80
  def predict_and_analyze(model_name, num_channels, dim, image):
81
 
82
+ '''
83
+ Loads a model with activations, passes through image and shows activations
84
 
85
  The image must be a numpy array of shape (C, W, W) or (1, C, W, W)
 
86
  '''
87
 
88
  num_channels = int(num_channels)
 
90
 
91
  print("Loading data")
92
  image = np.load(image.name, allow_pickle=True)
93
+ image = image.astype(np.float32)
94
 
95
  if len(image.shape) != 4:
96
  image = image[np.newaxis, :, :, :]
 
98
  assert image.shape == (1, num_channels, W, W), "Data is the wrong shape"
99
 
100
  model_name += '_%i' % (num_channels)
101
+
102
  print("Loading model")
103
  model = load_model(model_name, activation=True)
104
  print("Model loaded")
 
154
 
155
  demo = gr.Interface(
156
  fn=predict_and_analyze,
157
+ inputs=[gr.Dropdown(["efficientnet", "regnet"],
158
  value="efficientnet",
159
  label="Model Selection",
160
  show_label=True),