TibbtechUser commited on
Commit
8b2712e
·
verified ·
1 Parent(s): 454f522

updated app.py

Browse files
Files changed (1) hide show
  1. registration/app.py +528 -0
registration/app.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import nibabel as nib
3
+ import numpy as np
4
+ import os
5
+ from PIL import Image
6
+ import pandas as pd
7
+ import nrrd
8
+ import ants
9
+ from natsort import natsorted
10
+ from scipy.ndimage import zoom, rotate
11
+ import matplotlib.pyplot as plt
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchvision.models as models
15
+ import torchvision.transforms as transforms
16
+ from sklearn.metrics.pairwise import cosine_similarity
17
+ import cv2
18
+
19
+ def square_padd(original_data, square_size=(120,152, 184), order = 1):
20
+ # e.g. square_size = 256 by default
21
+ # takes a raw image as input
22
+ # returns a square (padded) image as output
23
+
24
+ # order = [int(x-1) for x in ss.rankdata(original_data.shape)]
25
+ # # print(order)
26
+ # data = original_data.transpose(order)
27
+ data= original_data
28
+ # print(original_data.shape)
29
+ # print(data.shape)
30
+ if data.shape[1]>data.shape[0] and data.shape[1]>data.shape[2]: # width>height
31
+ scale_percent = (square_size[1]/data.shape[1])*100
32
+ # print("dim1")
33
+ elif data.shape[2]>data.shape[0] and data.shape[2]>data.shape[1]: # width>height
34
+ scale_percent = (square_size[2]/data.shape[2])*100
35
+ # print("dim2")
36
+ else: # width<height
37
+ scale_percent = (square_size[0]/data.shape[0])*100
38
+ scale_percent = int(scale_percent)
39
+ # print(scale_percent)
40
+ width = int(data.shape[0] * scale_percent / 100); height = int(data.shape[1] * scale_percent / 100); depth = int(data.shape[2] * scale_percent / 100);
41
+ dim = (width, height, depth)
42
+ # print(dim)
43
+ zoomFactors = [square_size_axis/float(data_shape) for data_shape, square_size_axis in zip(data.shape, square_size)]
44
+ sect_mask = zoom(data,zoom = zoomFactors, order = order, )
45
+ # sect_mask = zoom(data,(scale_percent/100, scale_percent/100, scale_percent/100), order = order, )
46
+ # sect_mask = cv2.resize(data, dim, interpolation = cv2.INTER_AREA)
47
+ sect_padd = (np.ones(square_size))*data[0,0,0]
48
+ sect_padd[int((square_size[0]-np.shape(sect_mask)[0])/2):int((square_size[0]-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0],
49
+ int((square_size[1]-np.shape(sect_mask)[1])/2):int((square_size[1]-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1],
50
+ int((square_size[2]-np.shape(sect_mask)[2])/2):int((square_size[2]-np.shape(sect_mask)[2])/2)+np.shape(sect_mask)[2]] = sect_mask
51
+ return sect_padd
52
+
53
+ def square_padding_RGB(single_RGB,square_size=256):
54
+ # e.g. square_size = 256 by default
55
+ # takes a raw image as input
56
+ # returns a square (padded) image as output
57
+ # input: 2D image
58
+ # output: 2D resized padded image
59
+ # example: BNI images, HMS data
60
+ if single_RGB.shape[1]>single_RGB.shape[0]: # width>height
61
+ scale_percent = (square_size/single_RGB.shape[1])*100
62
+ else: # width<height
63
+ scale_percent = (square_size/single_RGB.shape[0])*100
64
+ width = int(single_RGB.shape[1] * scale_percent / 100); height = int(single_RGB.shape[0] * scale_percent / 100); dim = (width, height)
65
+ sect_mask = cv2.resize(single_RGB, dim, interpolation = cv2.INTER_AREA)
66
+ sect_padd = (np.ones((square_size,square_size,3)))*np.mean(single_RGB[:10,:10])
67
+ sect_padd[int((square_size-np.shape(sect_mask)[0])/2):int((square_size-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0],
68
+ int((square_size-np.shape(sect_mask)[1])/2):int((square_size-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1],:] = sect_mask
69
+ return sect_padd
70
+
71
+ def square_padding(single_gray,square_size=256):
72
+ # e.g. square_size = 256 by default
73
+ # takes a raw image as input
74
+ # returns a square (padded) image as output
75
+ # input: 2D image
76
+ # output: 2D resized padded image
77
+ # example: BNI images, HMS data
78
+ if len(np.shape(single_gray))>2:
79
+ return square_padding_RGB(single_gray[:,:,:3])
80
+ else:
81
+ # print("Single gray shape:", np.shape(single_gray))
82
+ if single_gray.shape[1]>single_gray.shape[0]: # width>height
83
+ scale_percent = (square_size/single_gray.shape[1])*100
84
+ else: # width<height
85
+ scale_percent = (square_size/single_gray.shape[0])*100
86
+ width = int(single_gray.shape[1] * scale_percent / 100); height = int(single_gray.shape[0] * scale_percent / 100); dim = (width, height)
87
+ # print("Dim::", dim)
88
+ sect_mask = cv2.resize(single_gray, dim, interpolation = cv2.INTER_AREA)
89
+ sect_padd = (np.zeros((square_size,square_size)))*single_gray[-20,-20]#find a better solution for single_gray[100,-100]
90
+ sect_padd[int((square_size-np.shape(sect_mask)[0])/2):int((square_size-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0],
91
+ int((square_size-np.shape(sect_mask)[1])/2):int((square_size-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1]] = sect_mask
92
+ return sect_padd
93
+
94
+
95
+ def affine_reg(fixed_image,moving_image,gauss_param=100):
96
+ # this function takes fixed and moving images as input and return affine transformation matrix
97
+ # fixed/moving images can be 2D/3D
98
+ # todo: add an option as flag to save the transformation matrix and displacement fields at the desired location to be able to apply the transforms later
99
+ mytx = ants.registration(fixed=fixed_image,
100
+ moving=moving_image,
101
+ type_of_transform='Affine',
102
+ reg_iterations = (gauss_param,gauss_param,gauss_param,gauss_param))
103
+ print('affine registration completed')
104
+ return mytx
105
+
106
+
107
+ def nonrigid_reg(fixed_image,mytx,type_of_transform='SyN',grad_step=0.25,reg_iterations=(50,50,50, ),flow_sigma=9,total_sigma=0.2):
108
+ # this function takes fixed image and affined tx matrix as input and return non-rigid transformation matrix
109
+ # fixed/moving images can be 2D/3D
110
+ # type of transform selection: https://antspy.readthedocs.io/en/latest/registration.html
111
+ # todo: scale the function to incorporate the extended parameters for type_of_transform
112
+ # todo: scale the function to incorporate the affine+non-rigid simultaneously in case of SyNRA
113
+
114
+ transform_type = {'SyN':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma},
115
+ 'SyNRA':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma}}
116
+
117
+ mytx_non_rigid = ants.registration(fixed = fixed_image,
118
+ moving=mytx['warpedmovout'],
119
+ type_of_transform=type_of_transform,
120
+ grad_step=transform_type[type_of_transform]['grad_step'],
121
+ reg_iterations=transform_type[type_of_transform]['reg_iterations'],
122
+ flow_sigma=transform_type[type_of_transform]['flow_sigma'],
123
+ total_sigma=transform_type[type_of_transform]['total_sigma'])
124
+
125
+ print('non-rigid registration completed')
126
+ return mytx_non_rigid
127
+
128
+ def affine_reg(fixed_image,moving_image,gauss_param=100):
129
+ # this function takes fixed and moving images as input and return affine transformation matrix
130
+ # fixed/moving images can be 2D/3D
131
+ # todo: add an option as flag to save the transformation matrix and displacement fields at the desired location to be able to apply the transforms later
132
+ mytx = ants.registration(fixed=fixed_image,
133
+ moving=moving_image,
134
+ type_of_transform='Affine',
135
+ reg_iterations = (gauss_param,gauss_param,gauss_param,gauss_param))
136
+ print('affine registration completed')
137
+ return mytx
138
+
139
+
140
+ def nonrigid_reg(fixed_image,mytx,type_of_transform='SyN',grad_step=0.25,reg_iterations=(50,50,50, ),flow_sigma=9,total_sigma=0.2):
141
+ # this function takes fixed image and affined tx matrix as input and return non-rigid transformation matrix
142
+ # fixed/moving images can be 2D/3D
143
+ # type of transform selection: https://antspy.readthedocs.io/en/latest/registration.html
144
+ # todo: scale the function to incorporate the extended parameters for type_of_transform
145
+ # todo: scale the function to incorporate the affine+non-rigid simultaneously in case of SyNRA
146
+
147
+ transform_type = {'SyN':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma},
148
+ 'SyNRA':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma}}
149
+
150
+ mytx_non_rigid = ants.registration(fixed = fixed_image,
151
+ moving=mytx['warpedmovout'],
152
+ type_of_transform=type_of_transform,
153
+ grad_step=transform_type[type_of_transform]['grad_step'],
154
+ reg_iterations=transform_type[type_of_transform]['reg_iterations'],
155
+ flow_sigma=transform_type[type_of_transform]['flow_sigma'],
156
+ total_sigma=transform_type[type_of_transform]['total_sigma'])
157
+
158
+ print('non-rigid registration completed')
159
+ return mytx_non_rigid
160
+
161
+
162
+
163
+ def run_3D_registration(user_section, ):
164
+ global allen_atlas_ccf, allen_template_ccf
165
+ template_atlas = allen_atlas_ccf
166
+ template_section = allen_template_ccf
167
+ template_atlas = np.uint16(template_atlas*255)
168
+ user_section = square_padd(user_section, (60, 76, 92))
169
+
170
+ template_atlas = square_padd(template_atlas, user_section.shape)
171
+ template_section = square_padd(template_section, user_section.shape)
172
+
173
+ fixed_image = ants.from_numpy(user_section)
174
+ moving_atlas_ants = ants.from_numpy(template_atlas)
175
+ moving_image = ants.from_numpy(template_section)
176
+
177
+ mytx = affine_reg(fixed_image,moving_image)
178
+ mytx_non_rigid = nonrigid_reg(fixed_image,mytx)
179
+ affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
180
+ moving=moving_image,
181
+ transformlist=mytx['fwdtransforms'],
182
+ interpolator='nearestNeighbor')
183
+ nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
184
+ moving=affined_fixed_atlas,
185
+ transformlist=mytx_non_rigid['fwdtransforms'],
186
+ interpolator='nearestNeighbor')
187
+ gallery_images = load_gallery_images()
188
+ transformed_images = []
189
+ if not(os.path.exists("Overlaped_registered")):
190
+ os.mkdir("Overlaped_registered")
191
+ registered = nonrigid_fixed_atlas.numpy()/255
192
+ for id in list(range((registered.shape[0]//2)-15, (registered.shape[0]//2)+15, 2)):
193
+ print(id)
194
+ plt.imsave(f'Overlaped_registered/{id}.png',registered[id, :, :], cmap = 'gray' )
195
+ transformed_images.append(f'Overlaped_registered/{id}.png')
196
+
197
+ return transformed_images
198
+
199
+
200
+ def run_2D_registration(user_section, slice_idx):
201
+ global allen_atlas_ccf, allen_template_ccf, gallery_selected_data
202
+ template_atlas = allen_atlas_ccf
203
+ template_section = allen_template_ccf
204
+
205
+ template_atlas = allen_atlas_ccf[slice_idx,:,:]
206
+ template_section = allen_template_ccf[slice_idx,:,:]
207
+ # colored_atlas = colored_atlas[slice_idx,:,:]
208
+ print(np.shape(template_atlas), np.shape(template_section))
209
+ user_section = square_padding(user_section)
210
+
211
+ template_atlas = np.uint16(template_atlas*255)
212
+ template_atlas = square_padding(template_atlas)
213
+ template_section = square_padding(template_section)
214
+
215
+ fixed_image = ants.from_numpy(user_section)
216
+ moving_atlas_ants = ants.from_numpy(template_atlas)
217
+ moving_image = ants.from_numpy(template_section)
218
+
219
+ mytx = affine_reg(fixed_image,moving_image)
220
+ mytx_non_rigid = nonrigid_reg(fixed_image,mytx)
221
+ gallery_imgs = natsorted(load_gallery_images())
222
+ moving_gallery_img = ants.from_numpy(square_padding(plt.imread(gallery_imgs[gallery_selected_data])))
223
+ affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
224
+ moving=moving_image,
225
+ transformlist=mytx['fwdtransforms'],
226
+ interpolator='nearestNeighbor')
227
+ nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
228
+ moving=affined_fixed_atlas,
229
+ transformlist=mytx_non_rigid['fwdtransforms'],
230
+ interpolator='nearestNeighbor')
231
+ gallery_images = load_gallery_images()
232
+ transformed_images = []
233
+ if not(os.path.exists("Overlaped_registered")):
234
+ os.mkdir("Overlaped_registered")
235
+ plt.imsave(f'Overlaped_registered/registered_slice.png',nonrigid_fixed_atlas.numpy()/255, cmap = 'gray')
236
+
237
+ return ['Overlaped_registered/registered_slice.png']
238
+
239
+
240
+ def embeddings_classifier(user_section, atlas_embeddings,atlas_labels):
241
+ class SliceEncoder(nn.Module):
242
+ def __init__(self):
243
+ super(SliceEncoder, self).__init__()
244
+ base = models.resnet18(pretrained=True)
245
+ self.backbone = nn.Sequential(*list(base.children())[:-1]) # Remove final FC layer
246
+
247
+ def forward(self, x):
248
+ x = self.backbone(x) # Output shape: (B, 512, 1, 1)
249
+ return x.view(x.size(0), -1) # Flatten to (B, 512)
250
+
251
+ # Transform
252
+ transform = transforms.Compose([
253
+ transforms.Resize((224, 224)),
254
+ transforms.ToTensor(),
255
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
256
+ std=[0.229, 0.224, 0.225]),
257
+ ])
258
+
259
+ # Feature extraction utility
260
+ def extract_embedding(img_array, encoder, transform):
261
+ img = Image.fromarray(((img_array) * 255).astype(np.uint8)).convert('RGB')
262
+ img_tensor = transform(img).unsqueeze(0).to(device)
263
+ with torch.no_grad():
264
+ embedding = encoder(img_tensor)
265
+ return embedding.cpu().numpy().flatten()
266
+
267
+ # Prepare device and model
268
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
269
+ encoder = SliceEncoder().to(device).eval()
270
+
271
+ # Precompute atlas embeddings
272
+
273
+
274
+ query_emb = extract_embedding(user_section, encoder, transform).reshape(1, -1)
275
+ sims = cosine_similarity(query_emb, atlas_embeddings)[0]
276
+
277
+ pred_idx = np.argmax(sims)
278
+ pred_gt = atlas_labels[pred_idx]
279
+
280
+ return int(pred_gt)
281
+
282
+
283
+ def gray_scale(image):
284
+ # input: a 2D RGB image (x,y,z)
285
+ # output: a grayscale image (x,y)
286
+ # todo: fix the depth issue of pixels
287
+ if len(np.shape(image))>2:
288
+ return cv2.cvtColor(image[:,:,:3], cv2.COLOR_RGB2GRAY)
289
+ else:
290
+ return image
291
+
292
+ def atlas_slice_prediction(user_section, axis = 'coronal'):
293
+
294
+ user_section = gray_scale(square_padding(gray_scale(user_section)))
295
+ user_section = gray_scale(user_section)
296
+ user_section = square_padding(user_section, 224)
297
+ user_section = (user_section - np.min(user_section))/((np.max(user_section) - np.min(user_section)))
298
+ print("Loading model")
299
+ atlas_embeddings = np.load(f"registration/atlas_embeddings_{axis}.npy")
300
+ atlas_labels = np.load(f"registration/atlas_labels_{axis}.npy")
301
+ idx = embeddings_classifier(user_section, atlas_embeddings,atlas_labels)
302
+
303
+ return idx
304
+
305
+
306
+
307
+
308
+
309
+
310
+ example_files = [
311
+ ["./resampled_green_25.nii.gz", "CCF registered Sample", "3D"],
312
+ ["./Brain_1.png", "Custom Sample", "2D"],
313
+ # ["examples/sample3.nii.gz"]
314
+ ]
315
+
316
+ # Global variables
317
+ coronal_slices = []
318
+ last_probabilities = []
319
+ prob_df = pd.DataFrame()
320
+ vol = None
321
+ slice_idx = None
322
+
323
+ # Target cell types
324
+ cell_types = [
325
+ "ABC.NN", "Astro.TE.NN", "CLA.EPd.CTX.Car3.Glut", "Endo.NN", "L2.3.IT.CTX.Glut",
326
+ "L4.5.IT.CTX.Glut", "L5.ET.CTX.Glut", "L5.IT.CTX.Glut", "L5.NP.CTX.Glut", "L6.CT.CTX.Glut",
327
+ "L6.IT.CTX.Glut", "L6b.CTX.Glut", "Lamp5.Gaba", "Lamp5.Lhx6.Gaba", "Lymphoid.NN", "Microglia.NN",
328
+ "OPC.NN", "Oligo.NN", "Peri.NN", "Pvalb.Gaba", "Pvalb.chandelier.Gaba", "SMC.NN", "Sncg.Gaba",
329
+ "Sst.Chodl.Gaba", "Sst.Gaba", "VLMC.NN", "Vip.Gaba"
330
+ ]
331
+
332
+ actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509]
333
+ gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67]
334
+ # gallery_ids.reverse()
335
+
336
+ allen_atlas_ccf, header = nrrd.read('./registration/annotation_25.nrrd')
337
+ allen_template_ccf, _ = nrrd.read("./registration/average_template_25.nrrd")
338
+ # colored_atlas,_ = nrrd.read('./registration/colored_atlas_turbo.nrrd')
339
+ gallery_selected_data = None
340
+ def load_nifti_or_png(file, sample_type, data_type):
341
+ global coronal_slices, vol, slice_idx, gallery_selected_data
342
+ if file.name.endswith(".nii") or file.name.endswith(".nii.gz"):
343
+ img = nib.load(file.name)
344
+ vol = img.get_fdata()
345
+ coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
346
+ if data_type == "2D":
347
+ mid_index = vol.shape[0] // 2
348
+ slice_img = Image.fromarray((coronal_slices[mid_index] / np.max(coronal_slices[mid_index]) * 255).astype(np.uint8))
349
+ gallery_images = load_gallery_images()
350
+ return (
351
+ slice_img,
352
+ gr.update(visible=False),
353
+ gallery_images,
354
+ gr.update(visible=True),
355
+ gr.update(visible=True),
356
+ gr.update(visible=(sample_type == "Custom Sample"))
357
+ )
358
+ else: # 3D with actual_ids only
359
+ coronal_slices = [vol[i, :, :] for i in actual_ids]
360
+ idx = len(actual_ids) // 2 # Mid of actual_ids
361
+ slice_img = Image.fromarray((coronal_slices[idx] / np.max(coronal_slices[idx]) * 255).astype(np.uint8))
362
+ gallery_images = load_gallery_images()
363
+ gallery_images = natsorted(gallery_images)
364
+ return (
365
+ slice_img,
366
+ gr.update(visible=True, minimum=0, maximum=len(coronal_slices)-1, value=idx),
367
+ gallery_images,
368
+ gr.update(visible=True),
369
+ gr.update(visible=True),
370
+ gr.update(visible=(sample_type == "Custom Sample"))
371
+ )
372
+
373
+
374
+ else:
375
+ img = Image.open(file.name).convert("L")
376
+ vol = np.array(img)
377
+ coronal_slices = [np.array(img)]
378
+ gallery_images = natsorted(load_gallery_images())
379
+ idx = atlas_slice_prediction(np.array(img))
380
+ slice_idx = idx
381
+ closest_actual_idx = min(actual_ids, key=lambda x: abs(x - idx))
382
+ gallery_index = actual_ids.index(closest_actual_idx)
383
+ print(gallery_index, len(actual_ids) -(gallery_index))
384
+ gallery_selected_data = len(actual_ids) -(gallery_index)
385
+
386
+ return (
387
+ img,
388
+ gr.update(visible=False),
389
+ gr.update(selected_index=len(actual_ids) -(gallery_index) if gallery_index < len(gallery_ids) else 0, visible = True),
390
+ # gr.update(value=gallery_images, selected_index=len(actual_ids) -(gallery_index)), # gallery
391
+ gr.update(visible=True),
392
+ gr.update(visible=True),
393
+ gr.update(visible=(sample_type == "Custom Sample"))
394
+ )
395
+
396
+
397
+
398
+
399
+ def update_slice(index):
400
+ if not coronal_slices:
401
+ return None, None, None
402
+ slice_img = Image.fromarray((coronal_slices[index] / np.max(coronal_slices[index]) * 255).astype(np.uint8))
403
+ gallery_selection = gr.update(selected_index=len(gallery_ids) - index if index < len(gallery_ids) else 0)
404
+
405
+ if last_probabilities:
406
+ noise = np.random.normal(0, 0.01, size=len(last_probabilities))
407
+ new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
408
+ new_probs /= new_probs.sum()
409
+ else:
410
+ new_probs = []
411
+ return slice_img, plot_probabilities(new_probs), gallery_selection
412
+
413
+
414
+ def load_gallery_images():
415
+ folder = "Overlapped_updated"
416
+ images = []
417
+ if os.path.exists(folder):
418
+ for fname in sorted(os.listdir(folder)):
419
+ if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
420
+ images.append(os.path.join(folder, fname))
421
+ return images
422
+
423
+ def generate_random_probabilities():
424
+ probs = np.random.rand(len(cell_types))
425
+ low_indices = np.random.choice(len(probs), size=5, replace=False)
426
+ for idx in low_indices:
427
+ probs[idx] = np.random.rand() * 0.01
428
+ probs /= probs.sum()
429
+ return probs.tolist()
430
+
431
+ def plot_probabilities(probabilities):
432
+ if len(probabilities) < 1:
433
+ return None
434
+ prob_df = pd.DataFrame({"labels": cell_types, "values": probabilities})
435
+ prob_df.to_csv('Cell_types_predictions.csv', index=False)
436
+ return prob_df
437
+
438
+ def run_mapping():
439
+ global last_probabilities
440
+ last_probabilities = generate_random_probabilities()
441
+ return plot_probabilities(last_probabilities), gr.update(visible=True)
442
+
443
+ def run_registration(data_type, selected_idx):
444
+ global vol, slice_idx
445
+ print("Running registration logic here..., Vol shape::", vol.shape)
446
+ if data_type == "3D":
447
+ gallery_images = run_3D_registration(vol)
448
+
449
+ else:
450
+ gallery_images = run_2D_registration(vol, slice_idx)
451
+ return gallery_images
452
+
453
+
454
+
455
+
456
+ return "Registration complete!"
457
+
458
+ def download_csv():
459
+ return 'Cell_types_predictions.csv'
460
+
461
+
462
+ def handle_data_type_change(dt):
463
+ if dt == "2D":
464
+ return gr.update(visible=False)
465
+ else:
466
+ return gr.update(visible=True, minimum=0, maximum=len(actual_ids)-1, value=len(actual_ids)//2)
467
+
468
+ def on_select(evt: gr.SelectData):
469
+ gallery_selected_data = evt.index
470
+
471
+ gallery_images = natsorted(load_gallery_images())
472
+ with gr.Blocks() as demo:
473
+ gr.Markdown("# Map My Sections")
474
+
475
+ gr.Markdown("### Step 1: Upload your sample and choose type")
476
+ with gr.Row():
477
+ nifti_file = gr.File(label="File Upload")
478
+ with gr.Column():
479
+ sample_type = gr.Dropdown(choices=["CCF registered Sample", "Custom Sample"], value="CCF registered Sample", label="Sample Type")
480
+ data_type = gr.Radio(choices=["2D", "3D"], value="3D", label="Data Type")
481
+
482
+ gr.Examples(examples=example_files, inputs=[nifti_file, sample_type, data_type], label="Try one of our example samples")
483
+
484
+ with gr.Row(visible=False) as slice_row:
485
+ with gr.Column(scale=1):
486
+ gr.Markdown("### Step 2: Visualizing your uploaded sample")
487
+ image_display = gr.Image(height=450)
488
+ slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Slices", visible=False)
489
+ with gr.Column(scale=1):
490
+ gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
491
+ gallery = gr.Gallery(label="ABC Atlas", value = gallery_images,)
492
+ gr.Markdown("**Step 4: Run cell type mapping**")
493
+ with gr.Row():
494
+ run_button = gr.Button("Run Mapping")
495
+ reg_button = gr.Button("Run Registration", visible=False)
496
+
497
+ with gr.Column(visible=False) as plot_row:
498
+ gr.Markdown("### Step 5: Quantitative results of the mapping model.")
499
+ prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90, height=400)
500
+ gr.Markdown("### Step 6: Download Results.")
501
+ download_button = gr.DownloadButton(label="Download Results", value='Cell_types_predictions.csv')
502
+
503
+ nifti_file.change(
504
+ load_nifti_or_png,
505
+ inputs=[nifti_file, sample_type, data_type],
506
+ outputs=[image_display, slice_slider, gallery, slice_row, plot_row, reg_button]
507
+ )
508
+
509
+ sample_type.change(
510
+ lambda s: (gr.update(visible=True), gr.update(visible=(s == "Custom Sample"))),
511
+ inputs=sample_type,
512
+ outputs=[slice_row, reg_button]
513
+ )
514
+
515
+
516
+ data_type.change(
517
+ handle_data_type_change,
518
+ inputs=data_type,
519
+ outputs=slice_slider
520
+ )
521
+
522
+ gallery.select(on_select, None, None)
523
+
524
+ slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery])
525
+ run_button.click(run_mapping, outputs=[prob_plot, plot_row])
526
+ reg_button.click(run_registration,inputs = [data_type], outputs=[gallery])
527
+
528
+ demo.launch()