jpterry commited on
Commit
68affc8
·
1 Parent(s): 6b4e6b7

fixed model shape

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -142,15 +142,9 @@ def predict_and_analyze(model_name, num_channels, dim, image):
142
  # im = f.readlines()
143
  # image = np.frombuffer(image)
144
 
145
- print(image)
146
- print(type(image))
147
- print(image.name)
148
-
149
  image = np.load(image.name, allow_pickle=True)
150
 
151
- image = image.reshape((num_channels, W, W))
152
- print(image)
153
- print(type(image))
154
 
155
  # W = int(np.sqrt(image.shape[1]))
156
 
@@ -159,6 +153,8 @@ def predict_and_analyze(model_name, num_channels, dim, image):
159
  if len(image.shape) != 4:
160
  image = image[np.newaxis, :, :, :]
161
 
 
 
162
  input_image = np.sum(image[0, :, :, :], axis=0)
163
 
164
  model_name += '_%i' % (num_channels)
 
142
  # im = f.readlines()
143
  # image = np.frombuffer(image)
144
 
 
 
 
 
145
  image = np.load(image.name, allow_pickle=True)
146
 
147
+ # image = image.reshape((num_channels, W, W))
 
 
148
 
149
  # W = int(np.sqrt(image.shape[1]))
150
 
 
153
  if len(image.shape) != 4:
154
  image = image[np.newaxis, :, :, :]
155
 
156
+ assert image.shape == (1, num_channels, W, W), "Data is the wrong shape"
157
+
158
  input_image = np.sum(image[0, :, :, :], axis=0)
159
 
160
  model_name += '_%i' % (num_channels)