updated app.py
- registration/app.py +528 -0
registration/app.py
ADDED
@@ -0,0 +1,528 @@
import gradio as gr
import nibabel as nib
import numpy as np
import os
from PIL import Image
import pandas as pd
import nrrd
import ants
from natsort import natsorted
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.metrics.pairwise import cosine_similarity
import cv2

def square_padd(original_data, square_size=(120, 152, 184), order=1):
    # Resamples a 3D volume to (approximately) the target shape and centers it
    # inside a padded volume filled with the corner voxel value, used here as a
    # background estimate.
    # input: 3D volume; output: 3D volume of shape `square_size`
    data = original_data
    zoomFactors = [square_size_axis / float(data_shape)
                   for data_shape, square_size_axis in zip(data.shape, square_size)]
    sect_mask = zoom(data, zoom=zoomFactors, order=order)
    sect_padd = np.ones(square_size) * data[0, 0, 0]
    x0 = int((square_size[0] - sect_mask.shape[0]) / 2)
    y0 = int((square_size[1] - sect_mask.shape[1]) / 2)
    z0 = int((square_size[2] - sect_mask.shape[2]) / 2)
    sect_padd[x0:x0 + sect_mask.shape[0],
              y0:y0 + sect_mask.shape[1],
              z0:z0 + sect_mask.shape[2]] = sect_mask
    return sect_padd

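# Usage sketch (illustrative, not called by the app): square_padd resamples a
# volume onto the target grid and backfills with the corner voxel value, so the
# output shape always matches `square_size`. The input volume below is made up.
#
#   demo_vol = np.random.rand(100, 120, 140)
#   padded = square_padd(demo_vol, square_size=(60, 76, 92))
#   assert padded.shape == (60, 76, 92)
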
def square_padding_RGB(single_RGB, square_size=256):
    # Resizes a 2D RGB image so its longer side equals `square_size`, then
    # centers it on a square canvas filled with the mean of the top-left
    # 10x10 patch (background estimate).
    # input: 2D RGB image; output: square_size x square_size x 3 image
    # example: BNI images, HMS data
    if single_RGB.shape[1] > single_RGB.shape[0]:  # width > height
        scale_percent = (square_size / single_RGB.shape[1]) * 100
    else:  # height >= width
        scale_percent = (square_size / single_RGB.shape[0]) * 100
    width = int(single_RGB.shape[1] * scale_percent / 100)
    height = int(single_RGB.shape[0] * scale_percent / 100)
    dim = (width, height)
    sect_mask = cv2.resize(single_RGB, dim, interpolation=cv2.INTER_AREA)
    sect_padd = np.ones((square_size, square_size, 3)) * np.mean(single_RGB[:10, :10])
    y0 = int((square_size - sect_mask.shape[0]) / 2)
    x0 = int((square_size - sect_mask.shape[1]) / 2)
    sect_padd[y0:y0 + sect_mask.shape[0], x0:x0 + sect_mask.shape[1], :] = sect_mask
    return sect_padd

def square_padding(single_gray, square_size=256):
    # Grayscale counterpart of square_padding_RGB: resize so the longer side
    # equals `square_size`, then center on a square canvas.
    # input: 2D image; output: square_size x square_size image
    # example: BNI images, HMS data
    if len(np.shape(single_gray)) > 2:
        return square_padding_RGB(single_gray[:, :, :3], square_size)
    if single_gray.shape[1] > single_gray.shape[0]:  # width > height
        scale_percent = (square_size / single_gray.shape[1]) * 100
    else:  # height >= width
        scale_percent = (square_size / single_gray.shape[0]) * 100
    width = int(single_gray.shape[1] * scale_percent / 100)
    height = int(single_gray.shape[0] * scale_percent / 100)
    dim = (width, height)
    sect_mask = cv2.resize(single_gray, dim, interpolation=cv2.INTER_AREA)
    # np.ones (not np.zeros) so the corner sample actually sets the pad value;
    # a corner pixel is only a rough background estimate.
    sect_padd = np.ones((square_size, square_size)) * single_gray[-20, -20]
    y0 = int((square_size - sect_mask.shape[0]) / 2)
    x0 = int((square_size - sect_mask.shape[1]) / 2)
    sect_padd[y0:y0 + sect_mask.shape[0], x0:x0 + sect_mask.shape[1]] = sect_mask
    return sect_padd

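# Usage sketch (illustrative): the 2D padder letterboxes a non-square image
# onto a square canvas; the 100x200 input below is made up.
#
#   demo_img = np.random.rand(100, 200)
#   padded = square_padding(demo_img, square_size=256)
#   assert padded.shape == (256, 256)
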
def affine_reg(fixed_image, moving_image, gauss_param=100):
    # Takes fixed and moving ANTs images (2D or 3D) and returns the affine
    # registration result (transform list plus the warped moving image).
    # todo: add a flag to save the transformation matrix and displacement
    # fields so the transforms can be re-applied later
    mytx = ants.registration(fixed=fixed_image,
                             moving=moving_image,
                             type_of_transform='Affine',
                             reg_iterations=(gauss_param, gauss_param, gauss_param, gauss_param))
    print('affine registration completed')
    return mytx


def nonrigid_reg(fixed_image, mytx, type_of_transform='SyN', grad_step=0.25,
                 reg_iterations=(50, 50, 50), flow_sigma=9, total_sigma=0.2):
    # Takes the fixed image and an affine registration result and returns the
    # non-rigid registration result. Images can be 2D or 3D.
    # Transform options: https://antspy.readthedocs.io/en/latest/registration.html
    # todo: expose the extended parameters of each type_of_transform
    # todo: handle affine + non-rigid in a single call for SyNRA

    transform_type = {
        'SyN':   {'grad_step': grad_step, 'reg_iterations': reg_iterations,
                  'flow_sigma': flow_sigma, 'total_sigma': total_sigma},
        'SyNRA': {'grad_step': grad_step, 'reg_iterations': reg_iterations,
                  'flow_sigma': flow_sigma, 'total_sigma': total_sigma},
    }

    mytx_non_rigid = ants.registration(fixed=fixed_image,
                                       moving=mytx['warpedmovout'],
                                       type_of_transform=type_of_transform,
                                       **transform_type[type_of_transform])

    print('non-rigid registration completed')
    return mytx_non_rigid

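# Usage sketch (illustrative; small random volumes stand in for real sections):
# the two helpers chain as affine first, then non-rigid on the affine-warped
# output, and the final image lives in the fixed space.
#
#   fixed = ants.from_numpy(np.random.rand(32, 32, 32).astype(np.float32))
#   moving = ants.from_numpy(np.random.rand(32, 32, 32).astype(np.float32))
#   tx = affine_reg(fixed, moving)
#   tx_nr = nonrigid_reg(fixed, tx)
#   warped = tx_nr['warpedmovout']  # moving image resampled into fixed space
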
def run_3D_registration(user_section):
    global allen_atlas_ccf, allen_template_ccf
    template_atlas = np.uint16(allen_atlas_ccf * 255)
    template_section = allen_template_ccf
    user_section = square_padd(user_section, (60, 76, 92))

    template_atlas = square_padd(template_atlas, user_section.shape)
    template_section = square_padd(template_section, user_section.shape)

    fixed_image = ants.from_numpy(user_section)
    moving_atlas_ants = ants.from_numpy(template_atlas)  # prepared for atlas warping; not used yet
    moving_image = ants.from_numpy(template_section)

    mytx = affine_reg(fixed_image, moving_image)
    mytx_non_rigid = nonrigid_reg(fixed_image, mytx)
    affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
                                                moving=moving_image,
                                                transformlist=mytx['fwdtransforms'],
                                                interpolator='nearestNeighbor')
    nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
                                                 moving=affined_fixed_atlas,
                                                 transformlist=mytx_non_rigid['fwdtransforms'],
                                                 interpolator='nearestNeighbor')
    transformed_images = []
    os.makedirs("Overlapped_registered", exist_ok=True)
    registered = nonrigid_fixed_atlas.numpy() / 255
    # Export every other slice from a 30-slice window around the volume center
    for idx in range((registered.shape[0] // 2) - 15, (registered.shape[0] // 2) + 15, 2):
        print(idx)
        plt.imsave(f'Overlapped_registered/{idx}.png', registered[idx, :, :], cmap='gray')
        transformed_images.append(f'Overlapped_registered/{idx}.png')

    return transformed_images

def run_2D_registration(user_section, slice_idx):
    global allen_atlas_ccf, allen_template_ccf
    template_atlas = allen_atlas_ccf[slice_idx, :, :]
    template_section = allen_template_ccf[slice_idx, :, :]
    print(np.shape(template_atlas), np.shape(template_section))

    user_section = square_padding(user_section)
    template_atlas = square_padding(np.uint16(template_atlas * 255))
    template_section = square_padding(template_section)

    fixed_image = ants.from_numpy(user_section)
    moving_atlas_ants = ants.from_numpy(template_atlas)  # prepared for atlas warping; not used yet
    moving_image = ants.from_numpy(template_section)

    mytx = affine_reg(fixed_image, moving_image)
    mytx_non_rigid = nonrigid_reg(fixed_image, mytx)
    affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
                                                moving=moving_image,
                                                transformlist=mytx['fwdtransforms'],
                                                interpolator='nearestNeighbor')
    nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image,
                                                 moving=affined_fixed_atlas,
                                                 transformlist=mytx_non_rigid['fwdtransforms'],
                                                 interpolator='nearestNeighbor')
    os.makedirs("Overlapped_registered", exist_ok=True)
    plt.imsave('Overlapped_registered/registered_slice.png',
               nonrigid_fixed_atlas.numpy() / 255, cmap='gray')

    return ['Overlapped_registered/registered_slice.png']

def embeddings_classifier(user_section, atlas_embeddings, atlas_labels):
    class SliceEncoder(nn.Module):
        def __init__(self):
            super(SliceEncoder, self).__init__()
            base = models.resnet18(pretrained=True)
            self.backbone = nn.Sequential(*list(base.children())[:-1])  # drop the final FC layer

        def forward(self, x):
            x = self.backbone(x)          # output shape: (B, 512, 1, 1)
            return x.view(x.size(0), -1)  # flatten to (B, 512)

    # ImageNet preprocessing for the ResNet backbone
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Feature extraction utility
    def extract_embedding(img_array, encoder, transform):
        img = Image.fromarray((img_array * 255).astype(np.uint8)).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            embedding = encoder(img_tensor)
        return embedding.cpu().numpy().flatten()

    # Prepare device and model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = SliceEncoder().to(device).eval()

    # Nearest atlas slice by cosine similarity between embeddings
    query_emb = extract_embedding(user_section, encoder, transform).reshape(1, -1)
    sims = cosine_similarity(query_emb, atlas_embeddings)[0]

    pred_idx = np.argmax(sims)
    pred_gt = atlas_labels[pred_idx]

    return int(pred_gt)


def gray_scale(image):
    # input: a 2D RGB image (H, W, C); output: a grayscale image (H, W)
    # todo: handle pixel bit depth explicitly
    if len(np.shape(image)) > 2:
        return cv2.cvtColor(image[:, :, :3], cv2.COLOR_RGB2GRAY)
    else:
        return image

def atlas_slice_prediction(user_section, axis='coronal'):
    # Predict which atlas slice (index along `axis`) best matches a user section.
    user_section = gray_scale(user_section)
    user_section = square_padding(user_section, 224)
    # Min-max normalize to [0, 1] before embedding
    user_section = (user_section - np.min(user_section)) / (np.max(user_section) - np.min(user_section))
    print("Loading model")
    atlas_embeddings = np.load(f"registration/atlas_embeddings_{axis}.npy")
    atlas_labels = np.load(f"registration/atlas_labels_{axis}.npy")
    idx = embeddings_classifier(user_section, atlas_embeddings, atlas_labels)

    return idx

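# Sketch of how the atlas_embeddings_*.npy / atlas_labels_*.npy files could be
# produced (an assumption -- the actual generation script is not part of this
# file): embed every atlas slice with the same ResNet-18 encoder and save the
# stacked features plus their slice indices.
#
#   embs = np.stack([extract_embedding(s, encoder, transform) for s in atlas_slices])
#   np.save("registration/atlas_embeddings_coronal.npy", embs)
#   np.save("registration/atlas_labels_coronal.npy", np.arange(len(atlas_slices)))
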
example_files = [
    ["./resampled_green_25.nii.gz", "CCF registered Sample", "3D"],
    ["./Brain_1.png", "Custom Sample", "2D"],
]

# Global state shared across the Gradio callbacks
coronal_slices = []
last_probabilities = []
prob_df = pd.DataFrame()
vol = None
slice_idx = None

# Target cell types
cell_types = [
    "ABC.NN", "Astro.TE.NN", "CLA.EPd.CTX.Car3.Glut", "Endo.NN", "L2.3.IT.CTX.Glut",
    "L4.5.IT.CTX.Glut", "L5.ET.CTX.Glut", "L5.IT.CTX.Glut", "L5.NP.CTX.Glut", "L6.CT.CTX.Glut",
    "L6.IT.CTX.Glut", "L6b.CTX.Glut", "Lamp5.Gaba", "Lamp5.Lhx6.Gaba", "Lymphoid.NN", "Microglia.NN",
    "OPC.NN", "Oligo.NN", "Peri.NN", "Pvalb.Gaba", "Pvalb.chandelier.Gaba", "SMC.NN", "Sncg.Gaba",
    "Sst.Chodl.Gaba", "Sst.Gaba", "VLMC.NN", "Vip.Gaba"
]

# Atlas slice indices with a matching gallery image, and the gallery ordering
actual_ids = [30, 52, 71, 91, 104, 109, 118, 126, 131, 137, 141, 164, 178, 182, 197, 208, 218, 226, 232, 242, 244, 248, 256, 262, 270, 282, 293, 297, 308, 323, 339, 344, 350, 355, 364, 372, 379, 389, 395, 401, 410, 415, 418, 424, 429, 434, 440, 444, 469, 479, 487, 509]
gallery_ids = [5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 64, 66, 67]

# Allen CCF 25-micron annotation and template volumes
allen_atlas_ccf, header = nrrd.read('./registration/annotation_25.nrrd')
allen_template_ccf, _ = nrrd.read("./registration/average_template_25.nrrd")
gallery_selected_data = None

def load_nifti_or_png(file, sample_type, data_type):
    global coronal_slices, vol, slice_idx, gallery_selected_data
    if file.name.endswith(".nii") or file.name.endswith(".nii.gz"):
        img = nib.load(file.name)
        vol = img.get_fdata()
        coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
        if data_type == "2D":
            mid_index = vol.shape[0] // 2
            slice_img = Image.fromarray((coronal_slices[mid_index] / np.max(coronal_slices[mid_index]) * 255).astype(np.uint8))
            gallery_images = load_gallery_images()
            return (
                slice_img,
                gr.update(visible=False),
                gallery_images,
                gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=(sample_type == "Custom Sample"))
            )
        else:  # 3D: only the slices listed in actual_ids
            coronal_slices = [vol[i, :, :] for i in actual_ids]
            idx = len(actual_ids) // 2  # middle of actual_ids
            slice_img = Image.fromarray((coronal_slices[idx] / np.max(coronal_slices[idx]) * 255).astype(np.uint8))
            gallery_images = natsorted(load_gallery_images())
            return (
                slice_img,
                gr.update(visible=True, minimum=0, maximum=len(coronal_slices) - 1, value=idx),
                gallery_images,
                gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=(sample_type == "Custom Sample"))
            )
    else:
        img = Image.open(file.name).convert("L")
        vol = np.array(img)
        coronal_slices = [np.array(img)]
        idx = atlas_slice_prediction(np.array(img))
        slice_idx = idx
        # Map the predicted atlas index to the nearest gallery entry
        closest_actual_idx = min(actual_ids, key=lambda x: abs(x - idx))
        gallery_index = actual_ids.index(closest_actual_idx)
        print(gallery_index, len(actual_ids) - gallery_index)
        gallery_selected_data = len(actual_ids) - gallery_index

        return (
            img,
            gr.update(visible=False),
            gr.update(selected_index=len(actual_ids) - gallery_index if gallery_index < len(gallery_ids) else 0, visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=(sample_type == "Custom Sample"))
        )

def update_slice(index):
    if not coronal_slices:
        return None, None, None
    slice_img = Image.fromarray((coronal_slices[index] / np.max(coronal_slices[index]) * 255).astype(np.uint8))
    gallery_selection = gr.update(selected_index=len(gallery_ids) - index if index < len(gallery_ids) else 0)

    if last_probabilities:
        # Jitter the last probabilities slightly so the plot visibly updates per slice
        noise = np.random.normal(0, 0.01, size=len(last_probabilities))
        new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
        new_probs /= new_probs.sum()
    else:
        new_probs = []
    return slice_img, plot_probabilities(new_probs), gallery_selection

def load_gallery_images():
    folder = "Overlapped_updated"
    images = []
    if os.path.exists(folder):
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                images.append(os.path.join(folder, fname))
    return images

def generate_random_probabilities():
    # Placeholder mapping output: random probabilities with a few near-zero entries
    probs = np.random.rand(len(cell_types))
    low_indices = np.random.choice(len(probs), size=5, replace=False)
    for idx in low_indices:
        probs[idx] = np.random.rand() * 0.01
    probs /= probs.sum()
    return probs.tolist()

def plot_probabilities(probabilities):
    if len(probabilities) < 1:
        return None
    prob_df = pd.DataFrame({"labels": cell_types, "values": probabilities})
    prob_df.to_csv('Cell_types_predictions.csv', index=False)
    return prob_df

def run_mapping():
    global last_probabilities
    last_probabilities = generate_random_probabilities()
    return plot_probabilities(last_probabilities), gr.update(visible=True)

def run_registration(data_type):
    # The click handler below passes only data_type; the slice index comes
    # from the module-level slice_idx set during upload.
    global vol, slice_idx
    print("Running registration logic here..., Vol shape::", vol.shape)
    if data_type == "3D":
        gallery_images = run_3D_registration(vol)
    else:
        gallery_images = run_2D_registration(vol, slice_idx)
    return gallery_images

def download_csv():
    return 'Cell_types_predictions.csv'

def handle_data_type_change(dt):
    if dt == "2D":
        return gr.update(visible=False)
    else:
        return gr.update(visible=True, minimum=0, maximum=len(actual_ids) - 1, value=len(actual_ids) // 2)

def on_select(evt: gr.SelectData):
    global gallery_selected_data  # without this, the assignment would only bind a local
    gallery_selected_data = evt.index

gallery_images = natsorted(load_gallery_images())
with gr.Blocks() as demo:
    gr.Markdown("# Map My Sections")

    gr.Markdown("### Step 1: Upload your sample and choose its type")
    with gr.Row():
        nifti_file = gr.File(label="File Upload")
        with gr.Column():
            sample_type = gr.Dropdown(choices=["CCF registered Sample", "Custom Sample"], value="CCF registered Sample", label="Sample Type")
            data_type = gr.Radio(choices=["2D", "3D"], value="3D", label="Data Type")

    gr.Examples(examples=example_files, inputs=[nifti_file, sample_type, data_type], label="Try one of our example samples")

    with gr.Row(visible=False) as slice_row:
        with gr.Column(scale=1):
            gr.Markdown("### Step 2: Visualize your uploaded sample")
            image_display = gr.Image(height=450)
            slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Slices", visible=False)
        with gr.Column(scale=1):
            gr.Markdown("### Step 3: Visualize the Allen Brain Cell Types Atlas")
            gallery = gr.Gallery(label="ABC Atlas", value=gallery_images)
    gr.Markdown("**Step 4: Run cell type mapping**")
    with gr.Row():
        run_button = gr.Button("Run Mapping")
        reg_button = gr.Button("Run Registration", visible=False)

    with gr.Column(visible=False) as plot_row:
        gr.Markdown("### Step 5: Quantitative results of the mapping model")
        prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90, height=400)
        gr.Markdown("### Step 6: Download results")
        download_button = gr.DownloadButton(label="Download Results", value='Cell_types_predictions.csv')

    nifti_file.change(
        load_nifti_or_png,
        inputs=[nifti_file, sample_type, data_type],
        outputs=[image_display, slice_slider, gallery, slice_row, plot_row, reg_button]
    )

    sample_type.change(
        lambda s: (gr.update(visible=True), gr.update(visible=(s == "Custom Sample"))),
        inputs=sample_type,
        outputs=[slice_row, reg_button]
    )

    data_type.change(
        handle_data_type_change,
        inputs=data_type,
        outputs=slice_slider
    )

    gallery.select(on_select, None, None)

    slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery])
    run_button.click(run_mapping, outputs=[prob_plot, plot_row])
    reg_button.click(run_registration, inputs=[data_type], outputs=[gallery])

demo.launch()
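
# Running locally (assumes the NRRD atlases and the .npy embedding files exist
# under ./registration/, and that the Overlapped_updated gallery folder is present):
#
#   pip install gradio nibabel numpy pillow pandas pynrrd antspyx natsort \
#       scipy matplotlib torch torchvision scikit-learn opencv-python
#   python registration/app.py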