TibbtechUser commited on
Commit
454f522
·
verified ·
1 Parent(s): 202ae2b

Updated app.py

Browse files
.gitattributes CHANGED
@@ -40,3 +40,6 @@ Overlapped_updated/38.png filter=lfs diff=lfs merge=lfs -text
40
  Overlapped_updated/39.png filter=lfs diff=lfs merge=lfs -text
41
  Overlapped_updated/40.png filter=lfs diff=lfs merge=lfs -text
42
  Overlapped_updated/43.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
40
  Overlapped_updated/39.png filter=lfs diff=lfs merge=lfs -text
41
  Overlapped_updated/40.png filter=lfs diff=lfs merge=lfs -text
42
  Overlapped_updated/43.png filter=lfs diff=lfs merge=lfs -text
43
+ registration/annotation_25.nrrd filter=lfs diff=lfs merge=lfs -text
44
+ registration/average_template_25.nrrd filter=lfs diff=lfs merge=lfs -text
45
+ registration/CCFv3OntologyStructure_u16.xlsx filter=lfs diff=lfs merge=lfs -text
Brain_1.png ADDED
app.py CHANGED
@@ -4,10 +4,312 @@ import numpy as np
4
  import os
5
  from PIL import Image
6
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  example_files = [
9
- ["./resampled_green_25.nii.gz"],
10
- # ["examples/sample2.nii.gz"],
11
  # ["examples/sample3.nii.gz"]
12
  ]
13
 
@@ -15,6 +317,8 @@ example_files = [
15
  coronal_slices = []
16
  last_probabilities = []
17
  prob_df = pd.DataFrame()
 
 
18
 
19
  # Target cell types
20
  cell_types = [
@@ -27,27 +331,77 @@ cell_types = [
27
 
28
  actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509]
29
  gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67]
30
- gallery_ids.reverse()
31
- def load_nifti(file):
32
- global coronal_slices
33
- img = nib.load(file.name)
34
- vol = img.get_fdata()
35
- coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
36
- mid_index = vol.shape[0] // 2
37
- slice_img = Image.fromarray((coronal_slices[mid_index] / np.max(coronal_slices[mid_index]) * 255).astype(np.uint8))
38
- gallery_images = load_gallery_images()
39
- return slice_img, gr.update(visible=True, maximum=len(coronal_slices)-1, value=mid_index), gallery_images, gr.update(visible=True), gr.update(visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def update_slice(index):
42
  if not coronal_slices:
43
  return None, None, None
44
  slice_img = Image.fromarray((coronal_slices[index] / np.max(coronal_slices[index]) * 255).astype(np.uint8))
45
-
46
- # Find closest gallery index
47
- closest_idx = min(range(len(actual_ids)), key=lambda i: abs(actual_ids[i] - index))
48
- gallery_selection = gr.update(selected_index=closest_idx)
49
-
50
- # Slight variation to probabilities
51
  if last_probabilities:
52
  noise = np.random.normal(0, 0.01, size=len(last_probabilities))
53
  new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
@@ -56,9 +410,10 @@ def update_slice(index):
56
  new_probs = []
57
  return slice_img, plot_probabilities(new_probs), gallery_selection
58
 
 
59
  def load_gallery_images():
60
- images = []
61
  folder = "Overlapped_updated"
 
62
  if os.path.exists(folder):
63
  for fname in sorted(os.listdir(folder)):
64
  if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
@@ -78,8 +433,6 @@ def plot_probabilities(probabilities):
78
  return None
79
  prob_df = pd.DataFrame({"labels": cell_types, "values": probabilities})
80
  prob_df.to_csv('Cell_types_predictions.csv', index=False)
81
- print("CSV saved!!")
82
-
83
  return prob_df
84
 
85
  def run_mapping():
@@ -87,41 +440,89 @@ def run_mapping():
87
  last_probabilities = generate_random_probabilities()
88
  return plot_probabilities(last_probabilities), gr.update(visible=True)
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def download_csv():
91
- # prob_df.to_csv('Cell_types_predictions.csv', index=False)
92
  return 'Cell_types_predictions.csv'
93
 
94
 
 
 
 
 
 
 
 
 
 
 
95
  with gr.Blocks() as demo:
96
  gr.Markdown("# Map My Sections")
97
 
98
- gr.Markdown("### Step 1: Upload your CCF registered data")
99
- nifti_file = gr.File(label="File Upload")
100
- gr.Examples(
101
- examples=example_files,
102
- inputs=nifti_file,
103
- label="Try one of our example samples"
104
- )
 
105
 
106
  with gr.Row(visible=False) as slice_row:
107
  with gr.Column(scale=1):
108
  gr.Markdown("### Step 2: Visualizing your uploaded sample")
109
- image_display = gr.Image(height = 400)
110
- slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Browse Slices", visible=False)
111
  with gr.Column(scale=1):
112
  gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
113
- gallery = gr.Gallery(label="ABC Atlas", height = 400)
114
  gr.Markdown("**Step 4: Run cell type mapping**")
115
- run_button = gr.Button("Run Mapping")
 
 
116
 
117
  with gr.Column(visible=False) as plot_row:
118
  gr.Markdown("### Step 5: Quantitative results of the mapping model.")
119
- prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90, height = 400)
120
  gr.Markdown("### Step 6: Download Results.")
121
  download_button = gr.DownloadButton(label="Download Results", value='Cell_types_predictions.csv')
122
 
123
- nifti_file.change(load_nifti, inputs=nifti_file, outputs=[image_display, slice_slider, gallery, slice_row, plot_row])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery])
125
  run_button.click(run_mapping, outputs=[prob_plot, plot_row])
 
126
 
127
- demo.launch()
 
4
  import os
5
  from PIL import Image
6
  import pandas as pd
7
+ import nrrd
8
+ import ants
9
+ from natsort import natsorted
10
+ from scipy.ndimage import zoom, rotate
11
+ import matplotlib.pyplot as plt
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchvision.models as models
15
+ import torchvision.transforms as transforms
16
+ from sklearn.metrics.pairwise import cosine_similarity
17
+ import cv2
18
+
19
+ def square_padd(original_data, square_size=(120,152, 184), order = 1):
20
+ # e.g. square_size = 256 by default
21
+ # takes a raw image as input
22
+ # returns a square (padded) image as output
23
+
24
+ # order = [int(x-1) for x in ss.rankdata(original_data.shape)]
25
+ # # print(order)
26
+ # data = original_data.transpose(order)
27
+ data= original_data
28
+ # print(original_data.shape)
29
+ # print(data.shape)
30
+ if data.shape[1]>data.shape[0] and data.shape[1]>data.shape[2]: # width>height
31
+ scale_percent = (square_size[1]/data.shape[1])*100
32
+ # print("dim1")
33
+ elif data.shape[2]>data.shape[0] and data.shape[2]>data.shape[1]: # width>height
34
+ scale_percent = (square_size[2]/data.shape[2])*100
35
+ # print("dim2")
36
+ else: # width<height
37
+ scale_percent = (square_size[0]/data.shape[0])*100
38
+ scale_percent = int(scale_percent)
39
+ # print(scale_percent)
40
+ width = int(data.shape[0] * scale_percent / 100); height = int(data.shape[1] * scale_percent / 100); depth = int(data.shape[2] * scale_percent / 100);
41
+ dim = (width, height, depth)
42
+ # print(dim)
43
+ zoomFactors = [square_size_axis/float(data_shape) for data_shape, square_size_axis in zip(data.shape, square_size)]
44
+ sect_mask = zoom(data,zoom = zoomFactors, order = order, )
45
+ # sect_mask = zoom(data,(scale_percent/100, scale_percent/100, scale_percent/100), order = order, )
46
+ # sect_mask = cv2.resize(data, dim, interpolation = cv2.INTER_AREA)
47
+ sect_padd = (np.ones(square_size))*data[0,0,0]
48
+ sect_padd[int((square_size[0]-np.shape(sect_mask)[0])/2):int((square_size[0]-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0],
49
+ int((square_size[1]-np.shape(sect_mask)[1])/2):int((square_size[1]-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1],
50
+ int((square_size[2]-np.shape(sect_mask)[2])/2):int((square_size[2]-np.shape(sect_mask)[2])/2)+np.shape(sect_mask)[2]] = sect_mask
51
+ return sect_padd
52
+
53
+ def square_padding_RGB(single_RGB,square_size=256):
54
+ # e.g. square_size = 256 by default
55
+ # takes a raw image as input
56
+ # returns a square (padded) image as output
57
+ # input: 2D image
58
+ # output: 2D resized padded image
59
+ # example: BNI images, HMS data
60
+ if single_RGB.shape[1]>single_RGB.shape[0]: # width>height
61
+ scale_percent = (square_size/single_RGB.shape[1])*100
62
+ else: # width<height
63
+ scale_percent = (square_size/single_RGB.shape[0])*100
64
+ width = int(single_RGB.shape[1] * scale_percent / 100); height = int(single_RGB.shape[0] * scale_percent / 100); dim = (width, height)
65
+ sect_mask = cv2.resize(single_RGB, dim, interpolation = cv2.INTER_AREA)
66
+ sect_padd = (np.ones((square_size,square_size,3)))*np.mean(single_RGB[:10,:10])
67
+ sect_padd[int((square_size-np.shape(sect_mask)[0])/2):int((square_size-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0],
68
+ int((square_size-np.shape(sect_mask)[1])/2):int((square_size-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1],:] = sect_mask
69
+ return sect_padd
70
+
71
+ def square_padding(single_gray,square_size=256):
72
+ # e.g. square_size = 256 by default
73
+ # takes a raw image as input
74
+ # returns a square (padded) image as output
75
+ # input: 2D image
76
+ # output: 2D resized padded image
77
+ # example: BNI images, HMS data
78
+ if len(np.shape(single_gray))>2:
79
+ return square_padding_RGB(single_gray[:,:,:3])
80
+ else:
81
+ # print("Single gray shape:", np.shape(single_gray))
82
+ if single_gray.shape[1]>single_gray.shape[0]: # width>height
83
+ scale_percent = (square_size/single_gray.shape[1])*100
84
+ else: # width<height
85
+ scale_percent = (square_size/single_gray.shape[0])*100
86
+ width = int(single_gray.shape[1] * scale_percent / 100); height = int(single_gray.shape[0] * scale_percent / 100); dim = (width, height)
87
+ # print("Dim::", dim)
88
+ sect_mask = cv2.resize(single_gray, dim, interpolation = cv2.INTER_AREA)
89
+ sect_padd = (np.zeros((square_size,square_size)))*single_gray[-20,-20]#find a better solution for single_gray[100,-100]
90
+ sect_padd[int((square_size-np.shape(sect_mask)[0])/2):int((square_size-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0],
91
+ int((square_size-np.shape(sect_mask)[1])/2):int((square_size-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1]] = sect_mask
92
+ return sect_padd
93
+
94
+
95
+ def affine_reg(fixed_image,moving_image,gauss_param=100):
96
+ # this function takes fixed and moving images as input and return affine transformation matrix
97
+ # fixed/moving images can be 2D/3D
98
+ # todo: add an option as flag to save the transformation matrix and displacement fields at the desired location to be able to apply the transforms later
99
+ mytx = ants.registration(fixed=fixed_image,
100
+ moving=moving_image,
101
+ type_of_transform='Affine',
102
+ reg_iterations = (gauss_param,gauss_param,gauss_param,gauss_param))
103
+ print('affine registration completed')
104
+ return mytx
105
+
106
+
107
+ def nonrigid_reg(fixed_image,mytx,type_of_transform='SyN',grad_step=0.25,reg_iterations=(50,50,50, ),flow_sigma=9,total_sigma=0.2):
108
+ # this function takes fixed image and affined tx matrix as input and return non-rigid transformation matrix
109
+ # fixed/moving images can be 2D/3D
110
+ # type of transform selection: https://antspy.readthedocs.io/en/latest/registration.html
111
+ # todo: scale the function to incorporate the extended parameters for type_of_transform
112
+ # todo: scale the function to incorporate the affine+non-rigid simultaneously in case of SyNRA
113
+
114
+ transform_type = {'SyN':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma},
115
+ 'SyNRA':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma}}
116
+
117
+ mytx_non_rigid = ants.registration(fixed = fixed_image,
118
+ moving=mytx['warpedmovout'],
119
+ type_of_transform=type_of_transform,
120
+ grad_step=transform_type[type_of_transform]['grad_step'],
121
+ reg_iterations=transform_type[type_of_transform]['reg_iterations'],
122
+ flow_sigma=transform_type[type_of_transform]['flow_sigma'],
123
+ total_sigma=transform_type[type_of_transform]['total_sigma'])
124
+
125
+ print('non-rigid registration completed')
126
+ return mytx_non_rigid
127
+
128
+ def affine_reg(fixed_image,moving_image,gauss_param=100):
129
+ # this function takes fixed and moving images as input and return affine transformation matrix
130
+ # fixed/moving images can be 2D/3D
131
+ # todo: add an option as flag to save the transformation matrix and displacement fields at the desired location to be able to apply the transforms later
132
+ mytx = ants.registration(fixed=fixed_image,
133
+ moving=moving_image,
134
+ type_of_transform='Affine',
135
+ reg_iterations = (gauss_param,gauss_param,gauss_param,gauss_param))
136
+ print('affine registration completed')
137
+ return mytx
138
+
139
+
140
+ def nonrigid_reg(fixed_image,mytx,type_of_transform='SyN',grad_step=0.25,reg_iterations=(50,50,50, ),flow_sigma=9,total_sigma=0.2):
141
+ # this function takes fixed image and affined tx matrix as input and return non-rigid transformation matrix
142
+ # fixed/moving images can be 2D/3D
143
+ # type of transform selection: https://antspy.readthedocs.io/en/latest/registration.html
144
+ # todo: scale the function to incorporate the extended parameters for type_of_transform
145
+ # todo: scale the function to incorporate the affine+non-rigid simultaneously in case of SyNRA
146
+
147
+ transform_type = {'SyN':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma},
148
+ 'SyNRA':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma}}
149
+
150
+ mytx_non_rigid = ants.registration(fixed = fixed_image,
151
+ moving=mytx['warpedmovout'],
152
+ type_of_transform=type_of_transform,
153
+ grad_step=transform_type[type_of_transform]['grad_step'],
154
+ reg_iterations=transform_type[type_of_transform]['reg_iterations'],
155
+ flow_sigma=transform_type[type_of_transform]['flow_sigma'],
156
+ total_sigma=transform_type[type_of_transform]['total_sigma'])
157
+
158
+ print('non-rigid registration completed')
159
+ return mytx_non_rigid
160
+
161
+
162
+
163
+ def run_3D_registration(user_section, ):
164
+ global allen_atlas_ccf, allen_template_ccf
165
+ template_atlas = allen_atlas_ccf
166
+ template_section = allen_template_ccf
167
+ template_atlas = np.uint16(template_atlas*255)
168
+ user_section = square_padd(user_section, (60, 76, 92))
169
+
170
+ template_atlas = square_padd(template_atlas, user_section.shape)
171
+ template_section = square_padd(template_section, user_section.shape)
172
+
173
+ fixed_image = ants.from_numpy(user_section)
174
+ moving_atlas_ants = ants.from_numpy(template_atlas)
175
+ moving_image = ants.from_numpy(template_section)
176
+
177
+ mytx = affine_reg(fixed_image,moving_image)
178
+ mytx_non_rigid = nonrigid_reg(fixed_image,mytx)
179
+ affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
180
+ moving=moving_image,
181
+ transformlist=mytx['fwdtransforms'],
182
+ interpolator='nearestNeighbor')
183
+ nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
184
+ moving=affined_fixed_atlas,
185
+ transformlist=mytx_non_rigid['fwdtransforms'],
186
+ interpolator='nearestNeighbor')
187
+ gallery_images = load_gallery_images()
188
+ transformed_images = []
189
+ if not(os.path.exists("Overlaped_registered")):
190
+ os.mkdir("Overlaped_registered")
191
+ registered = nonrigid_fixed_atlas.numpy()/255
192
+ for id in list(range((registered.shape[0]//2)-15, (registered.shape[0]//2)+15, 2)):
193
+ print(id)
194
+ plt.imsave(f'Overlaped_registered/{id}.png',registered[id, :, :], cmap = 'gray' )
195
+ transformed_images.append(f'Overlaped_registered/{id}.png')
196
+
197
+ return transformed_images
198
+
199
+
200
+ def run_2D_registration(user_section, slice_idx):
201
+ global allen_atlas_ccf, allen_template_ccf, gallery_selected_data
202
+ template_atlas = allen_atlas_ccf
203
+ template_section = allen_template_ccf
204
+
205
+ template_atlas = allen_atlas_ccf[slice_idx,:,:]
206
+ template_section = allen_template_ccf[slice_idx,:,:]
207
+ # colored_atlas = colored_atlas[slice_idx,:,:]
208
+ print(np.shape(template_atlas), np.shape(template_section))
209
+ user_section = square_padding(user_section)
210
+
211
+ template_atlas = np.uint16(template_atlas*255)
212
+ template_atlas = square_padding(template_atlas)
213
+ template_section = square_padding(template_section)
214
+
215
+ fixed_image = ants.from_numpy(user_section)
216
+ moving_atlas_ants = ants.from_numpy(template_atlas)
217
+ moving_image = ants.from_numpy(template_section)
218
+
219
+ mytx = affine_reg(fixed_image,moving_image)
220
+ mytx_non_rigid = nonrigid_reg(fixed_image,mytx)
221
+ gallery_imgs = natsorted(load_gallery_images())
222
+ moving_gallery_img = ants.from_numpy(square_padding(plt.imread(gallery_imgs[gallery_selected_data])))
223
+ affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
224
+ moving=moving_image,
225
+ transformlist=mytx['fwdtransforms'],
226
+ interpolator='nearestNeighbor')
227
+ nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
228
+ moving=affined_fixed_atlas,
229
+ transformlist=mytx_non_rigid['fwdtransforms'],
230
+ interpolator='nearestNeighbor')
231
+ gallery_images = load_gallery_images()
232
+ transformed_images = []
233
+ if not(os.path.exists("Overlaped_registered")):
234
+ os.mkdir("Overlaped_registered")
235
+ plt.imsave(f'Overlaped_registered/registered_slice.png',nonrigid_fixed_atlas.numpy()/255, cmap = 'gray')
236
+
237
+ return ['Overlaped_registered/registered_slice.png']
238
+
239
+
240
+ def embeddings_classifier(user_section, atlas_embeddings,atlas_labels):
241
+ class SliceEncoder(nn.Module):
242
+ def __init__(self):
243
+ super(SliceEncoder, self).__init__()
244
+ base = models.resnet18(pretrained=True)
245
+ self.backbone = nn.Sequential(*list(base.children())[:-1]) # Remove final FC layer
246
+
247
+ def forward(self, x):
248
+ x = self.backbone(x) # Output shape: (B, 512, 1, 1)
249
+ return x.view(x.size(0), -1) # Flatten to (B, 512)
250
+
251
+ # Transform
252
+ transform = transforms.Compose([
253
+ transforms.Resize((224, 224)),
254
+ transforms.ToTensor(),
255
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
256
+ std=[0.229, 0.224, 0.225]),
257
+ ])
258
+
259
+ # Feature extraction utility
260
+ def extract_embedding(img_array, encoder, transform):
261
+ img = Image.fromarray(((img_array) * 255).astype(np.uint8)).convert('RGB')
262
+ img_tensor = transform(img).unsqueeze(0).to(device)
263
+ with torch.no_grad():
264
+ embedding = encoder(img_tensor)
265
+ return embedding.cpu().numpy().flatten()
266
+
267
+ # Prepare device and model
268
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
269
+ encoder = SliceEncoder().to(device).eval()
270
+
271
+ # Precompute atlas embeddings
272
+
273
+
274
+ query_emb = extract_embedding(user_section, encoder, transform).reshape(1, -1)
275
+ sims = cosine_similarity(query_emb, atlas_embeddings)[0]
276
+
277
+ pred_idx = np.argmax(sims)
278
+ pred_gt = atlas_labels[pred_idx]
279
+
280
+ return int(pred_gt)
281
+
282
+
283
+ def gray_scale(image):
284
+ # input: a 2D RGB image (x,y,z)
285
+ # output: a grayscale image (x,y)
286
+ # todo: fix the depth issue of pixels
287
+ if len(np.shape(image))>2:
288
+ return cv2.cvtColor(image[:,:,:3], cv2.COLOR_RGB2GRAY)
289
+ else:
290
+ return image
291
+
292
+ def atlas_slice_prediction(user_section, axis = 'coronal'):
293
+
294
+ user_section = gray_scale(square_padding(gray_scale(user_section)))
295
+ user_section = gray_scale(user_section)
296
+ user_section = square_padding(user_section, 224)
297
+ user_section = (user_section - np.min(user_section))/((np.max(user_section) - np.min(user_section)))
298
+ print("Loading model")
299
+ atlas_embeddings = np.load(f"registration/template_id_checkpoints/atlas_embeddings_{axis}.npy")
300
+ atlas_labels = np.load(f"registration/template_id_checkpoints/atlas_labels_{axis}.npy")
301
+ idx = embeddings_classifier(user_section, atlas_embeddings,atlas_labels)
302
+
303
+ return idx
304
+
305
+
306
+
307
+
308
+
309
 
310
  example_files = [
311
+ ["./resampled_green_25.nii.gz", "CCF registered Sample", "3D"],
312
+ ["./Brain_1.png", "Custom Sample", "2D"],
313
  # ["examples/sample3.nii.gz"]
314
  ]
315
 
 
317
  coronal_slices = []
318
  last_probabilities = []
319
  prob_df = pd.DataFrame()
320
+ vol = None
321
+ slice_idx = None
322
 
323
  # Target cell types
324
  cell_types = [
 
331
 
332
  actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509]
333
  gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67]
334
+ # gallery_ids.reverse()
335
+
336
+ allen_atlas_ccf, header = nrrd.read('./registration/annotation_25.nrrd')
337
+ allen_template_ccf, _ = nrrd.read("./registration/average_template_25.nrrd")
338
+ # colored_atlas,_ = nrrd.read('./registration/colored_atlas_turbo.nrrd')
339
+ gallery_selected_data = None
340
+ def load_nifti_or_png(file, sample_type, data_type):
341
+ global coronal_slices, vol, slice_idx, gallery_selected_data
342
+ if file.name.endswith(".nii") or file.name.endswith(".nii.gz"):
343
+ img = nib.load(file.name)
344
+ vol = img.get_fdata()
345
+ coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
346
+ if data_type == "2D":
347
+ mid_index = vol.shape[0] // 2
348
+ slice_img = Image.fromarray((coronal_slices[mid_index] / np.max(coronal_slices[mid_index]) * 255).astype(np.uint8))
349
+ gallery_images = load_gallery_images()
350
+ return (
351
+ slice_img,
352
+ gr.update(visible=False),
353
+ gallery_images,
354
+ gr.update(visible=True),
355
+ gr.update(visible=True),
356
+ gr.update(visible=(sample_type == "Custom Sample"))
357
+ )
358
+ else: # 3D with actual_ids only
359
+ coronal_slices = [vol[i, :, :] for i in actual_ids]
360
+ idx = len(actual_ids) // 2 # Mid of actual_ids
361
+ slice_img = Image.fromarray((coronal_slices[idx] / np.max(coronal_slices[idx]) * 255).astype(np.uint8))
362
+ gallery_images = load_gallery_images()
363
+ gallery_images = natsorted(gallery_images)
364
+ return (
365
+ slice_img,
366
+ gr.update(visible=True, minimum=0, maximum=len(coronal_slices)-1, value=idx),
367
+ gallery_images,
368
+ gr.update(visible=True),
369
+ gr.update(visible=True),
370
+ gr.update(visible=(sample_type == "Custom Sample"))
371
+ )
372
+
373
+
374
+ else:
375
+ img = Image.open(file.name).convert("L")
376
+ vol = np.array(img)
377
+ coronal_slices = [np.array(img)]
378
+ gallery_images = natsorted(load_gallery_images())
379
+ idx = atlas_slice_prediction(np.array(img))
380
+ slice_idx = idx
381
+ closest_actual_idx = min(actual_ids, key=lambda x: abs(x - idx))
382
+ gallery_index = actual_ids.index(closest_actual_idx)
383
+ print(gallery_index, len(actual_ids) -(gallery_index))
384
+ gallery_selected_data = len(actual_ids) -(gallery_index)
385
+
386
+ return (
387
+ img,
388
+ gr.update(visible=False),
389
+ gr.update(selected_index=len(actual_ids) -(gallery_index) if gallery_index < len(gallery_ids) else 0, visible = True),
390
+ # gr.update(value=gallery_images, selected_index=len(actual_ids) -(gallery_index)), # gallery
391
+ gr.update(visible=True),
392
+ gr.update(visible=True),
393
+ gr.update(visible=(sample_type == "Custom Sample"))
394
+ )
395
+
396
+
397
+
398
 
399
  def update_slice(index):
400
  if not coronal_slices:
401
  return None, None, None
402
  slice_img = Image.fromarray((coronal_slices[index] / np.max(coronal_slices[index]) * 255).astype(np.uint8))
403
+ gallery_selection = gr.update(selected_index=len(gallery_ids) - index if index < len(gallery_ids) else 0)
404
+
 
 
 
 
405
  if last_probabilities:
406
  noise = np.random.normal(0, 0.01, size=len(last_probabilities))
407
  new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
 
410
  new_probs = []
411
  return slice_img, plot_probabilities(new_probs), gallery_selection
412
 
413
+
414
  def load_gallery_images():
 
415
  folder = "Overlapped_updated"
416
+ images = []
417
  if os.path.exists(folder):
418
  for fname in sorted(os.listdir(folder)):
419
  if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
 
433
  return None
434
  prob_df = pd.DataFrame({"labels": cell_types, "values": probabilities})
435
  prob_df.to_csv('Cell_types_predictions.csv', index=False)
 
 
436
  return prob_df
437
 
438
  def run_mapping():
 
440
  last_probabilities = generate_random_probabilities()
441
  return plot_probabilities(last_probabilities), gr.update(visible=True)
442
 
443
+ def run_registration(data_type, selected_idx):
444
+ global vol, slice_idx
445
+ print("Running registration logic here..., Vol shape::", vol.shape)
446
+ if data_type == "3D":
447
+ gallery_images = run_3D_registration(vol)
448
+
449
+ else:
450
+ gallery_images = run_2D_registration(vol, slice_idx)
451
+ return gallery_images
452
+
453
+
454
+
455
+
456
+ return "Registration complete!"
457
+
458
  def download_csv():
 
459
  return 'Cell_types_predictions.csv'
460
 
461
 
462
+ def handle_data_type_change(dt):
463
+ if dt == "2D":
464
+ return gr.update(visible=False)
465
+ else:
466
+ return gr.update(visible=True, minimum=0, maximum=len(actual_ids)-1, value=len(actual_ids)//2)
467
+
468
+ def on_select(evt: gr.SelectData):
469
+ gallery_selected_data = evt.index
470
+
471
+ gallery_images = natsorted(load_gallery_images())
472
  with gr.Blocks() as demo:
473
  gr.Markdown("# Map My Sections")
474
 
475
+ gr.Markdown("### Step 1: Upload your sample and choose type")
476
+ with gr.Row():
477
+ nifti_file = gr.File(label="File Upload")
478
+ with gr.Row():
479
+ sample_type = gr.Dropdown(choices=["CCF registered Sample", "Custom Sample"], value="CCF registered Sample", label="Sample Type")
480
+ data_type = gr.Radio(choices=["2D", "3D"], value="3D", label="Data Type")
481
+
482
+ gr.Examples(examples=example_files, inputs=[nifti_file, sample_type, data_type], label="Try one of our example samples")
483
 
484
  with gr.Row(visible=False) as slice_row:
485
  with gr.Column(scale=1):
486
  gr.Markdown("### Step 2: Visualizing your uploaded sample")
487
+ image_display = gr.Image(height=450)
488
+ slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Slices", visible=False)
489
  with gr.Column(scale=1):
490
  gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
491
+ gallery = gr.Gallery(label="ABC Atlas", value = gallery_images,)
492
  gr.Markdown("**Step 4: Run cell type mapping**")
493
+ with gr.Row():
494
+ run_button = gr.Button("Run Mapping")
495
+ reg_button = gr.Button("Run Registration", visible=False)
496
 
497
  with gr.Column(visible=False) as plot_row:
498
  gr.Markdown("### Step 5: Quantitative results of the mapping model.")
499
+ prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90, height=400)
500
  gr.Markdown("### Step 6: Download Results.")
501
  download_button = gr.DownloadButton(label="Download Results", value='Cell_types_predictions.csv')
502
 
503
+ nifti_file.change(
504
+ load_nifti_or_png,
505
+ inputs=[nifti_file, sample_type, data_type],
506
+ outputs=[image_display, slice_slider, gallery, slice_row, plot_row, reg_button]
507
+ )
508
+
509
+ sample_type.change(
510
+ lambda s: (gr.update(visible=True), gr.update(visible=(s == "Custom Sample"))),
511
+ inputs=sample_type,
512
+ outputs=[slice_row, reg_button]
513
+ )
514
+
515
+
516
+ data_type.change(
517
+ handle_data_type_change,
518
+ inputs=data_type,
519
+ outputs=slice_slider
520
+ )
521
+
522
+ gallery.select(on_select, None, None)
523
+
524
  slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery])
525
  run_button.click(run_mapping, outputs=[prob_plot, plot_row])
526
+ reg_button.click(run_registration,inputs = [data_type], outputs=[gallery])
527
 
528
+ demo.launch()
registration/CCFv3OntologyStructure_u16.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:377bfc25589dcd7194eead21f9cc5a788b6aafe5ba026351b4d73d01c30a0a1d
3
+ size 138855
registration/annotation_25.nrrd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c620cbcc562183e4dcd40250d440130501781f74b41de35b1c1bdabace290c42
3
+ size 4035363
registration/atlas_embeddings_coronal.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59f66f53b266a0a4581bd6517a8304b74e7dc426f0365d5de70fee5b0766a9e0
3
+ size 4325504
registration/atlas_labels_coronal.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7e28eaf41f3134430a72108b750d628fb42fcd24e615a4d3e4763d6c1f5c2b9
3
+ size 17024
registration/average_template_25.nrrd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4a2b483e842b4c8c1b5452d940ea59e14bc1ebaa38fe6a9c3bacac6db2a8f4b
3
+ size 32998960
requirements.txt CHANGED
@@ -2,4 +2,13 @@ gradio
2
  nibabel
3
  numpy
4
  pillow
5
- pandas
 
 
 
 
 
 
 
 
 
 
2
  nibabel
3
  numpy
4
  pillow
5
+ pandas
6
+ pynrrd
7
+ antspyx
8
+ natsort
9
+ scipy
10
+ matplotlib
11
+ torch
12
+ torchvision
13
+ scikit-learn
14
+ opencv-python