Spaces:
Running
Running
import gradio as gr | |
import nibabel as nib | |
import numpy as np | |
import os | |
from PIL import Image | |
import pandas as pd | |
import nrrd | |
import ants | |
from natsort import natsorted | |
from scipy.ndimage import zoom, rotate | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
import torchvision.transforms as transforms | |
from sklearn.metrics.pairwise import cosine_similarity | |
from skimage.transform import resize | |
import cv2 | |
def square_padd(original_data, square_size=(120,152, 184), order = 1): | |
# e.g. square_size = 256 by default | |
# takes a raw image as input | |
# returns a square (padded) image as output | |
# order = [int(x-1) for x in ss.rankdata(original_data.shape)] | |
# # print(order) | |
# data = original_data.transpose(order) | |
data= original_data | |
# print(original_data.shape) | |
# print(data.shape) | |
if data.shape[1]>data.shape[0] and data.shape[1]>data.shape[2]: # width>height | |
scale_percent = (square_size[1]/data.shape[1])*100 | |
# print("dim1") | |
elif data.shape[2]>data.shape[0] and data.shape[2]>data.shape[1]: # width>height | |
scale_percent = (square_size[2]/data.shape[2])*100 | |
# print("dim2") | |
else: # width<height | |
scale_percent = (square_size[0]/data.shape[0])*100 | |
scale_percent = int(scale_percent) | |
# print(scale_percent) | |
width = int(data.shape[0] * scale_percent / 100); height = int(data.shape[1] * scale_percent / 100); depth = int(data.shape[2] * scale_percent / 100); | |
dim = (width, height, depth) | |
# print(dim) | |
zoomFactors = [square_size_axis/float(data_shape) for data_shape, square_size_axis in zip(data.shape, square_size)] | |
sect_mask = zoom(data,zoom = zoomFactors, order = order, ) | |
# sect_mask = zoom(data,(scale_percent/100, scale_percent/100, scale_percent/100), order = order, ) | |
# sect_mask = cv2.resize(data, dim, interpolation = cv2.INTER_AREA) | |
sect_padd = (np.ones(square_size))*data[0,0,0] | |
sect_padd[int((square_size[0]-np.shape(sect_mask)[0])/2):int((square_size[0]-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0], | |
int((square_size[1]-np.shape(sect_mask)[1])/2):int((square_size[1]-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1], | |
int((square_size[2]-np.shape(sect_mask)[2])/2):int((square_size[2]-np.shape(sect_mask)[2])/2)+np.shape(sect_mask)[2]] = sect_mask | |
return sect_padd | |
def square_padding_RGB(single_RGB,square_size=256): | |
# e.g. square_size = 256 by default | |
# takes a raw image as input | |
# returns a square (padded) image as output | |
# input: 2D image | |
# output: 2D resized padded image | |
# example: BNI images, HMS data | |
if single_RGB.shape[1]>single_RGB.shape[0]: # width>height | |
scale_percent = (square_size/single_RGB.shape[1])*100 | |
else: # width<height | |
scale_percent = (square_size/single_RGB.shape[0])*100 | |
width = int(single_RGB.shape[1] * scale_percent / 100); height = int(single_RGB.shape[0] * scale_percent / 100); dim = (width, height) | |
sect_mask = cv2.resize(single_RGB, dim, interpolation = cv2.INTER_AREA) | |
sect_padd = (np.ones((square_size,square_size,3)))*np.mean(single_RGB[:10,:10]) | |
sect_padd[int((square_size-np.shape(sect_mask)[0])/2):int((square_size-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0], | |
int((square_size-np.shape(sect_mask)[1])/2):int((square_size-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1],:] = sect_mask | |
return sect_padd | |
def square_padding(single_gray,square_size=256): | |
# e.g. square_size = 256 by default | |
# takes a raw image as input | |
# returns a square (padded) image as output | |
# input: 2D image | |
# output: 2D resized padded image | |
# example: BNI images, HMS data | |
if len(np.shape(single_gray))>2: | |
return square_padding_RGB(single_gray[:,:,:3]) | |
else: | |
# print("Single gray shape:", np.shape(single_gray)) | |
if single_gray.shape[1]>single_gray.shape[0]: # width>height | |
scale_percent = (square_size/single_gray.shape[1])*100 | |
else: # width<height | |
scale_percent = (square_size/single_gray.shape[0])*100 | |
width = int(single_gray.shape[1] * scale_percent / 100); height = int(single_gray.shape[0] * scale_percent / 100); dim = (width, height) | |
# print("Dim::", dim) | |
sect_mask = cv2.resize(single_gray, dim, interpolation = cv2.INTER_AREA) | |
sect_padd = (np.zeros((square_size,square_size)))*single_gray[-20,-20]#find a better solution for single_gray[100,-100] | |
sect_padd[int((square_size-np.shape(sect_mask)[0])/2):int((square_size-np.shape(sect_mask)[0])/2)+np.shape(sect_mask)[0], | |
int((square_size-np.shape(sect_mask)[1])/2):int((square_size-np.shape(sect_mask)[1])/2)+np.shape(sect_mask)[1]] = sect_mask | |
return sect_padd | |
def affine_reg(fixed_image,moving_image,gauss_param=100): | |
# this function takes fixed and moving images as input and return affine transformation matrix | |
# fixed/moving images can be 2D/3D | |
# todo: add an option as flag to save the transformation matrix and displacement fields at the desired location to be able to apply the transforms later | |
mytx = ants.registration(fixed=fixed_image, | |
moving=moving_image, | |
type_of_transform='Affine', | |
reg_iterations = (gauss_param,gauss_param,gauss_param,gauss_param)) | |
print('affine registration completed') | |
return mytx | |
def nonrigid_reg(fixed_image,mytx,type_of_transform='SyN',grad_step=0.25,reg_iterations=(50,50,50, ),flow_sigma=9,total_sigma=0.2): | |
# this function takes fixed image and affined tx matrix as input and return non-rigid transformation matrix | |
# fixed/moving images can be 2D/3D | |
# type of transform selection: https://antspy.readthedocs.io/en/latest/registration.html | |
# todo: scale the function to incorporate the extended parameters for type_of_transform | |
# todo: scale the function to incorporate the affine+non-rigid simultaneously in case of SyNRA | |
transform_type = {'SyN':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma}, | |
'SyNRA':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma}} | |
mytx_non_rigid = ants.registration(fixed = fixed_image, | |
moving=mytx['warpedmovout'], | |
type_of_transform=type_of_transform, | |
grad_step=transform_type[type_of_transform]['grad_step'], | |
reg_iterations=transform_type[type_of_transform]['reg_iterations'], | |
flow_sigma=transform_type[type_of_transform]['flow_sigma'], | |
total_sigma=transform_type[type_of_transform]['total_sigma']) | |
print('non-rigid registration completed') | |
return mytx_non_rigid | |
def affine_reg(fixed_image,moving_image,gauss_param=100): | |
# this function takes fixed and moving images as input and return affine transformation matrix | |
# fixed/moving images can be 2D/3D | |
# todo: add an option as flag to save the transformation matrix and displacement fields at the desired location to be able to apply the transforms later | |
mytx = ants.registration(fixed=fixed_image, | |
moving=moving_image, | |
type_of_transform='Affine', | |
reg_iterations = (gauss_param,gauss_param,gauss_param,gauss_param)) | |
print('affine registration completed') | |
return mytx | |
def nonrigid_reg(fixed_image,mytx,type_of_transform='SyN',grad_step=0.25,reg_iterations=(50,50,50, ),flow_sigma=9,total_sigma=0.2): | |
# this function takes fixed image and affined tx matrix as input and return non-rigid transformation matrix | |
# fixed/moving images can be 2D/3D | |
# type of transform selection: https://antspy.readthedocs.io/en/latest/registration.html | |
# todo: scale the function to incorporate the extended parameters for type_of_transform | |
# todo: scale the function to incorporate the affine+non-rigid simultaneously in case of SyNRA | |
transform_type = {'SyN':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma}, | |
'SyNRA':{'grad_step':grad_step,'reg_iterations':reg_iterations,'flow_sigma':flow_sigma,'total_sigma':total_sigma}} | |
mytx_non_rigid = ants.registration(fixed = fixed_image, | |
moving=mytx['warpedmovout'], | |
type_of_transform=type_of_transform, | |
grad_step=transform_type[type_of_transform]['grad_step'], | |
reg_iterations=transform_type[type_of_transform]['reg_iterations'], | |
flow_sigma=transform_type[type_of_transform]['flow_sigma'], | |
total_sigma=transform_type[type_of_transform]['total_sigma']) | |
print('non-rigid registration completed') | |
return mytx_non_rigid | |
def run_3D_registration(user_section, ): | |
global allen_atlas_ccf, allen_template_ccf | |
template_atlas = allen_atlas_ccf | |
template_section = allen_template_ccf | |
template_atlas = np.uint16(template_atlas*255) | |
user_section = square_padd(user_section, (60, 76, 92)) | |
template_atlas = square_padd(template_atlas, user_section.shape) | |
template_section = square_padd(template_section, user_section.shape) | |
fixed_image = ants.from_numpy(user_section) | |
moving_atlas_ants = ants.from_numpy(template_atlas) | |
moving_image = ants.from_numpy(template_section) | |
mytx = affine_reg(fixed_image,moving_image) | |
mytx_non_rigid = nonrigid_reg(fixed_image,mytx) | |
affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image, | |
moving=moving_image, | |
transformlist=mytx['fwdtransforms'], | |
interpolator='nearestNeighbor') | |
nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image, | |
moving=affined_fixed_atlas, | |
transformlist=mytx_non_rigid['fwdtransforms'], | |
interpolator='nearestNeighbor') | |
gallery_images = natsorted(load_gallery_images()) | |
transformed_images = [] | |
if not(os.path.exists("Overlaped_registered")): | |
os.mkdir("Overlaped_registered") | |
# registered = nonrigid_fixed_atlas.numpy()/255 | |
# for id in list(range((registered.shape[0]//2)-15, (registered.shape[0]//2)+15, 2)): | |
# print(id) | |
# plt.imsave(f'Overlaped_registered/{id}.png',registered[id, :, :], cmap = 'gray' ) | |
# transformed_images.append(f'Overlaped_registered/{id}.png') | |
for i in range(len(gallery_images)-10): | |
im = plt.imread(gallery_images[i]) | |
fname = os.path.split(gallery_images[i])[-1] | |
moving_image_slice = ants.from_numpy(square_padding(gray_scale(im))) | |
affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image, | |
moving=moving_image, | |
transformlist=mytx['fwdtransforms'], | |
interpolator='nearestNeighbor') | |
nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image, | |
moving=affined_fixed_atlas, | |
transformlist=mytx_non_rigid['fwdtransforms'], | |
interpolator='nearestNeighbor') | |
# print(im.shape, nonrigid_fixed_atlas.numpy().shape) | |
reconverted_img = reconvert_to_rgb(im[:,:,:3], nonrigid_fixed_atlas.numpy()[i,:,:]) | |
plt.imsave(f'Overlaped_registered/{fname}',(reconverted_img * 255).astype(np.uint8)) | |
transformed_images.append(f'Overlaped_registered/{fname}') | |
transformed_images = natsorted(load_gallery_images()) | |
return transformed_images | |
def run_2D_registration(user_section, slice_idx): | |
global allen_atlas_ccf, allen_template_ccf, gallery_selected_data | |
template_atlas = allen_atlas_ccf | |
template_section = allen_template_ccf | |
template_atlas = allen_atlas_ccf[slice_idx,:,:] | |
template_section = allen_template_ccf[slice_idx,:,:] | |
# colored_atlas = colored_atlas[slice_idx,:,:] | |
print(np.shape(template_atlas), np.shape(template_section)) | |
user_section = square_padding(user_section) | |
template_atlas = np.uint16(template_atlas*255) | |
template_atlas = square_padding(template_atlas) | |
template_section = square_padding(template_section) | |
fixed_image = ants.from_numpy(user_section) | |
moving_atlas_ants = ants.from_numpy(template_atlas) | |
moving_image = ants.from_numpy(template_section) | |
mytx = affine_reg(fixed_image,moving_image) | |
mytx_non_rigid = nonrigid_reg(fixed_image,mytx) | |
gallery_imgs = natsorted(load_gallery_images()) | |
im = plt.imread(gallery_imgs[gallery_selected_data]) | |
print(im.shape) | |
moving_gallery_img = ants.from_numpy(square_padding(gray_scale(im))) | |
affined_fixed_atlas = ants.apply_transforms(fixed=fixed_image, | |
moving=moving_gallery_img, | |
transformlist=mytx['fwdtransforms'], | |
interpolator='nearestNeighbor') | |
nonrigid_fixed_atlas = ants.apply_transforms(fixed=fixed_image, | |
moving=affined_fixed_atlas, | |
transformlist=mytx_non_rigid['fwdtransforms'], | |
interpolator='nearestNeighbor') | |
gallery_images = load_gallery_images() | |
transformed_images = [] | |
if not(os.path.exists("Overlaped_registered")): | |
os.mkdir("Overlaped_registered") | |
print("Reconverting Image") | |
reconverted_img = reconvert_to_rgb(im[:,:,:3], nonrigid_fixed_atlas.numpy()) | |
plt.imsave(f'Overlaped_registered/registered_slice_reconverted_1.png',(reconverted_img * 255).astype(np.uint8)) | |
return ['Overlaped_registered/registered_slice_reconverted_1.png'] | |
def reconvert_to_rgb(img_rgb, img_gray_processed): | |
# 3. Resize original RGB to match processed grayscale shape | |
original_shape = img_gray_processed.shape | |
img_rgb_resized = resize(img_rgb, (original_shape[0], original_shape[1]), preserve_range=True) | |
# 4. Convert resized RGB to grayscale | |
gray_resized = np.mean(img_rgb_resized, axis=2) + 1e-8 # avoid divide-by-zero | |
# 5. Compute ratio of new_gray / old_gray, apply to RGB channels | |
ratio = img_gray_processed / gray_resized | |
img_recolored = img_rgb_resized * ratio[..., np.newaxis] | |
# 6. Clip values to [0, 1] if image is in float format (common for imread) | |
# img_recolored = np.clip(img_recolored, 0, 1) | |
return img_recolored | |
def embeddings_classifier(user_section, atlas_embeddings,atlas_labels): | |
class SliceEncoder(nn.Module): | |
def __init__(self): | |
super(SliceEncoder, self).__init__() | |
base = models.resnet18(pretrained=True) | |
self.backbone = nn.Sequential(*list(base.children())[:-1]) # Remove final FC layer | |
def forward(self, x): | |
x = self.backbone(x) # Output shape: (B, 512, 1, 1) | |
return x.view(x.size(0), -1) # Flatten to (B, 512) | |
# Transform | |
transform = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
std=[0.229, 0.224, 0.225]), | |
]) | |
# Feature extraction utility | |
def extract_embedding(img_array, encoder, transform): | |
img = Image.fromarray(((img_array) * 255).astype(np.uint8)).convert('RGB') | |
img_tensor = transform(img).unsqueeze(0).to(device) | |
with torch.no_grad(): | |
embedding = encoder(img_tensor) | |
return embedding.cpu().numpy().flatten() | |
# Prepare device and model | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
encoder = SliceEncoder().to(device).eval() | |
# Precompute atlas embeddings | |
query_emb = extract_embedding(user_section, encoder, transform).reshape(1, -1) | |
sims = cosine_similarity(query_emb, atlas_embeddings)[0] | |
pred_idx = np.argmax(sims) | |
pred_gt = atlas_labels[pred_idx] | |
return int(pred_gt) | |
def gray_scale(image): | |
# input: a 2D RGB image (x,y,z) | |
# output: a grayscale image (x,y) | |
# todo: fix the depth issue of pixels | |
if len(np.shape(image))>2: | |
return cv2.cvtColor(image[:,:,:3], cv2.COLOR_RGB2GRAY) | |
else: | |
return image | |
def atlas_slice_prediction(user_section, axis = 'coronal'): | |
user_section = gray_scale(square_padding(gray_scale(user_section))) | |
user_section = gray_scale(user_section) | |
user_section = square_padding(user_section, 224) | |
user_section = (user_section - np.min(user_section))/((np.max(user_section) - np.min(user_section))) | |
print("Loading model") | |
atlas_embeddings = np.load(f"registration/atlas_embeddings_{axis}.npy") | |
atlas_labels = np.load(f"registration/atlas_labels_{axis}.npy") | |
idx = embeddings_classifier(user_section, atlas_embeddings,atlas_labels) | |
return idx | |
example_files = [ | |
["./resampled_green_25.nii.gz", "CCF registered Sample", "3D"], | |
["./Brain_1.png", "Custom Sample", "2D"], | |
# ["examples/sample3.nii.gz"] | |
] | |
# Global variables | |
coronal_slices = [] | |
last_probabilities = [] | |
prob_df = pd.DataFrame() | |
vol = None | |
slice_idx = None | |
# Target cell types | |
cell_types = [ | |
"ABC.NN", "Astro.TE.NN", "CLA.EPd.CTX.Car3.Glut", "Endo.NN", "L2.3.IT.CTX.Glut", | |
"L4.5.IT.CTX.Glut", "L5.ET.CTX.Glut", "L5.IT.CTX.Glut", "L5.NP.CTX.Glut", "L6.CT.CTX.Glut", | |
"L6.IT.CTX.Glut", "L6b.CTX.Glut", "Lamp5.Gaba", "Lamp5.Lhx6.Gaba", "Lymphoid.NN", "Microglia.NN", | |
"OPC.NN", "Oligo.NN", "Peri.NN", "Pvalb.Gaba", "Pvalb.chandelier.Gaba", "SMC.NN", "Sncg.Gaba", | |
"Sst.Chodl.Gaba", "Sst.Gaba", "VLMC.NN", "Vip.Gaba" | |
] | |
actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509] | |
gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67] | |
# gallery_ids.reverse() | |
allen_atlas_ccf, header = nrrd.read('./registration/annotation_25.nrrd') | |
allen_template_ccf, _ = nrrd.read("./registration/average_template_25.nrrd") | |
# colored_atlas,_ = nrrd.read('./registration/colored_atlas_turbo.nrrd') | |
gallery_selected_data = None | |
def load_nifti_or_png(file, sample_type, data_type): | |
global coronal_slices, vol, slice_idx, gallery_selected_data | |
if file.name.endswith(".nii") or file.name.endswith(".nii.gz"): | |
img = nib.load(file.name) | |
vol = img.get_fdata() | |
coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])] | |
if data_type == "2D": | |
mid_index = vol.shape[0] // 2 | |
slice_img = Image.fromarray((coronal_slices[mid_index] / np.max(coronal_slices[mid_index]) * 255).astype(np.uint8)) | |
gallery_images = load_gallery_images() | |
return ( | |
slice_img, | |
gr.update(visible=False), | |
gallery_images, | |
gr.update(visible=True), | |
gr.update(visible=True), | |
gr.update(visible=(sample_type == "Custom Sample")) | |
) | |
else: # 3D with actual_ids only | |
coronal_slices = [vol[i, :, :] for i in actual_ids] | |
idx = len(actual_ids) // 2 # Mid of actual_ids | |
slice_img = Image.fromarray((coronal_slices[idx] / np.max(coronal_slices[idx]) * 255).astype(np.uint8)) | |
gallery_images = load_gallery_images() | |
gallery_images = natsorted(gallery_images) | |
return ( | |
slice_img, | |
gr.update(visible=True, minimum=0, maximum=len(coronal_slices)-1, value=idx), | |
gallery_images, | |
gr.update(visible=True), | |
gr.update(visible=True), | |
gr.update(visible=(sample_type == "Custom Sample")) | |
) | |
else: | |
img = Image.open(file.name).convert("L") | |
vol = np.array(img) | |
coronal_slices = [np.array(img)] | |
gallery_images = natsorted(load_gallery_images()) | |
idx = atlas_slice_prediction(np.array(img)) | |
slice_idx = idx | |
closest_actual_idx = min(actual_ids, key=lambda x: abs(x - idx)) | |
gallery_index = actual_ids.index(closest_actual_idx) | |
print(gallery_index, len(actual_ids) -(gallery_index)) | |
gallery_selected_data = len(actual_ids) -(gallery_index) | |
return ( | |
img, | |
gr.update(visible=False), | |
gr.update(selected_index=len(actual_ids) -(gallery_index) if gallery_index < len(gallery_ids) else 0, visible = True), | |
# gr.update(value=gallery_images, selected_index=len(actual_ids) -(gallery_index)), # gallery | |
gr.update(visible=True), | |
gr.update(visible=True), | |
gr.update(visible=(sample_type == "Custom Sample")) | |
) | |
def update_slice(index): | |
if not coronal_slices: | |
return None, None, None | |
slice_img = Image.fromarray((coronal_slices[index] / np.max(coronal_slices[index]) * 255).astype(np.uint8)) | |
gallery_selection = gr.update(selected_index=len(gallery_ids) - index if index < len(gallery_ids) else 0) | |
if last_probabilities: | |
noise = np.random.normal(0, 0.01, size=len(last_probabilities)) | |
new_probs = np.clip(np.array(last_probabilities) + noise, 0, None) | |
new_probs /= new_probs.sum() | |
else: | |
new_probs = [] | |
return slice_img, plot_probabilities(new_probs), gallery_selection | |
def load_gallery_images(): | |
folder = "Overlapped_updated" | |
images = [] | |
if os.path.exists(folder): | |
for fname in sorted(os.listdir(folder)): | |
if fname.lower().endswith(('.png', '.jpg', '.jpeg')): | |
images.append(os.path.join(folder, fname)) | |
return images | |
def generate_random_probabilities(): | |
probs = np.random.rand(len(cell_types)) | |
low_indices = np.random.choice(len(probs), size=5, replace=False) | |
for idx in low_indices: | |
probs[idx] = np.random.rand() * 0.01 | |
probs /= probs.sum() | |
return probs.tolist() | |
def plot_probabilities(probabilities): | |
if len(probabilities) < 1: | |
return None | |
prob_df = pd.DataFrame({"labels": cell_types, "values": probabilities}) | |
os.makedirs("outputs", exist_ok=True) | |
prob_df.to_csv('outputs/Cell_types_predictions.csv', index=False) | |
return prob_df | |
def run_mapping(): | |
global last_probabilities | |
last_probabilities = generate_random_probabilities() | |
return plot_probabilities(last_probabilities), gr.update(visible=True), gr.update(value = 'outputs/Cell_types_predictions.csv', visible = True), gr.update(visible=True) | |
def run_registration(data_type, selected_idx): | |
global vol, slice_idx | |
print("Running registration logic here..., Vol shape::", vol.shape) | |
if data_type == "3D": | |
gallery_images = run_3D_registration(vol) | |
else: | |
gallery_images = run_2D_registration(vol, slice_idx) | |
return gallery_images | |
return "Registration complete!" | |
def download_csv(): | |
return 'outputs/Cell_types_predictions.csv' | |
def handle_data_type_change(dt): | |
if dt == "2D": | |
return gr.update(visible=False) | |
else: | |
return gr.update(visible=True, minimum=0, maximum=len(actual_ids)-1, value=len(actual_ids)//2) | |
def on_select(evt: gr.SelectData): | |
print("Selected index:", evt) | |
print("Selected value:", evt.value) | |
print("Selected coordinates:", evt.selected) | |
gallery_selected_data = evt.index | |
gallery_images = natsorted(load_gallery_images()) | |
with gr.Blocks() as demo: | |
gr.Markdown("# Map My Sections\n### This GUI is part of the submission to the Allen Institute's Map My Sections tool by Tibbling Technologies.") | |
with gr.Row(): | |
gr.Markdown("### Step 1: Upload your sample, currently only .nii.gz (3D) and .png (2D) supported") | |
gr.Markdown("### Step 2: Select your sample and data type.") | |
with gr.Row(): | |
nifti_file = gr.File(label="File Upload") | |
with gr.Column(): | |
sample_type = gr.Dropdown(choices=["CCF registered Sample", "Custom Sample"], value="CCF registered Sample", label="Sample Type") | |
data_type = gr.Radio(choices=["2D", "3D"], value="3D", label="Data Type") | |
gr.Examples(examples=example_files, inputs=[nifti_file, sample_type, data_type], label="Try one of our example samples") | |
with gr.Row(visible=False) as slice_row: | |
with gr.Column(scale=1): | |
gr.Markdown("### Step 3: Visualizing your uploaded sample") | |
image_display = gr.Image(height=450) | |
slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Slices", visible=False) | |
with gr.Column(scale=1): | |
gr.Markdown("### Step 4: Visualizing Allen Brain Cell Types Atlas") | |
gallery = gr.Gallery(label="ABC Atlas", value = gallery_images,columns = 5, height = 450) | |
gr.Markdown("### Step 5: Run cell type mapping and/or registeration. ") | |
with gr.Row(): | |
run_button = gr.Button("Map My Sections") | |
reg_button = gr.Button("Run Registration (Optional)", visible=False) | |
with gr.Column(visible=False) as plot_row: | |
gr.Markdown("### Step 6: Quantitative results of the mapping model.") | |
prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90, height=400) | |
download_step = gr.Markdown("### Step 7: Download Results.", visible = False) | |
download_button = gr.DownloadButton(label="Download Results", visible = False) | |
nifti_file.change( | |
load_nifti_or_png, | |
inputs=[nifti_file, sample_type, data_type], | |
outputs=[image_display, slice_slider, gallery, slice_row, plot_row, reg_button] | |
) | |
sample_type.change( | |
lambda s: (gr.update(visible=True), gr.update(visible=(s == "Custom Sample"))), | |
inputs=sample_type, | |
outputs=[slice_row, reg_button] | |
) | |
data_type.change( | |
handle_data_type_change, | |
inputs=data_type, | |
outputs=slice_slider | |
) | |
gallery.select(on_select, inputs=None, outputs=None) | |
slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery]) | |
run_button.click(run_mapping, outputs=[prob_plot, plot_row, download_button, download_step]) | |
reg_button.click(run_registration,inputs = [data_type], outputs=[gallery]) | |
demo.launch() |