import gradio as gr import nibabel as nib import numpy as np import os from PIL import Image import pandas as pd import nrrd import ants from natsort import natsorted from scipy.ndimage import zoom, rotate import matplotlib.pyplot as plt import torch import torch.nn as nn import torchvision.models as models import torchvision.transforms as transforms from sklearn.metrics.pairwise import cosine_similarity import cv2 def square_padd(original_data, square_size=(120,152, 184), order = 1): # e.g. square_size = 256 by default # takes a raw image as input # returns a square (padded) image as output # order = [int(x-1) for x in ss.rankdata(original_data.shape)] # # print(order) # data = original_data.transpose(order) data= original_data # print(original_data.shape) # print(data.shape) if data.shape[1]>data.shape[0] and data.shape[1]>data.shape[2]: # width>height scale_percent = (square_size[1]/data.shape[1])*100 # print("dim1") elif data.shape[2]>data.shape[0] and data.shape[2]>data.shape[1]: # width>height scale_percent = (square_size[2]/data.shape[2])*100 # print("dim2") else: # widthsingle_RGB.shape[0]: # width>height scale_percent = (square_size/single_RGB.shape[1])*100 else: # width2: return square_padding_RGB(single_gray[:,:,:3]) else: # print("Single gray shape:", np.shape(single_gray)) if single_gray.shape[1]>single_gray.shape[0]: # width>height scale_percent = (square_size/single_gray.shape[1])*100 else: # width2: return cv2.cvtColor(image[:,:,:3], cv2.COLOR_RGB2GRAY) else: return image def atlas_slice_prediction(user_section, axis = 'coronal'): user_section = gray_scale(square_padding(gray_scale(user_section))) user_section = gray_scale(user_section) user_section = square_padding(user_section, 224) user_section = (user_section - np.min(user_section))/((np.max(user_section) - np.min(user_section))) print("Loading model") atlas_embeddings = np.load(f"registration/atlas_embeddings_{axis}.npy") atlas_labels = np.load(f"registration/atlas_labels_{axis}.npy") idx = 
embeddings_classifier(user_section, atlas_embeddings,atlas_labels) return idx example_files = [ ["./resampled_green_25.nii.gz", "CCF registered Sample", "3D"], ["./Brain_1.png", "Custom Sample", "2D"], # ["examples/sample3.nii.gz"] ] # Global variables coronal_slices = [] last_probabilities = [] prob_df = pd.DataFrame() vol = None slice_idx = None # Target cell types cell_types = [ "ABC.NN", "Astro.TE.NN", "CLA.EPd.CTX.Car3.Glut", "Endo.NN", "L2.3.IT.CTX.Glut", "L4.5.IT.CTX.Glut", "L5.ET.CTX.Glut", "L5.IT.CTX.Glut", "L5.NP.CTX.Glut", "L6.CT.CTX.Glut", "L6.IT.CTX.Glut", "L6b.CTX.Glut", "Lamp5.Gaba", "Lamp5.Lhx6.Gaba", "Lymphoid.NN", "Microglia.NN", "OPC.NN", "Oligo.NN", "Peri.NN", "Pvalb.Gaba", "Pvalb.chandelier.Gaba", "SMC.NN", "Sncg.Gaba", "Sst.Chodl.Gaba", "Sst.Gaba", "VLMC.NN", "Vip.Gaba" ] actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509] gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67] # gallery_ids.reverse() allen_atlas_ccf, header = nrrd.read('./registration/annotation_25.nrrd') allen_template_ccf, _ = nrrd.read("./registration/average_template_25.nrrd") # colored_atlas,_ = nrrd.read('./registration/colored_atlas_turbo.nrrd') gallery_selected_data = None def load_nifti_or_png(file, sample_type, data_type): global coronal_slices, vol, slice_idx, gallery_selected_data if file.name.endswith(".nii") or file.name.endswith(".nii.gz"): img = nib.load(file.name) vol = img.get_fdata() coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])] if data_type == "2D": mid_index = vol.shape[0] // 2 slice_img = Image.fromarray((coronal_slices[mid_index] / np.max(coronal_slices[mid_index]) * 255).astype(np.uint8)) gallery_images = load_gallery_images() return ( slice_img, 
gr.update(visible=False), gallery_images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=(sample_type == "Custom Sample")) ) else: # 3D with actual_ids only coronal_slices = [vol[i, :, :] for i in actual_ids] idx = len(actual_ids) // 2 # Mid of actual_ids slice_img = Image.fromarray((coronal_slices[idx] / np.max(coronal_slices[idx]) * 255).astype(np.uint8)) gallery_images = load_gallery_images() gallery_images = natsorted(gallery_images) return ( slice_img, gr.update(visible=True, minimum=0, maximum=len(coronal_slices)-1, value=idx), gallery_images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=(sample_type == "Custom Sample")) ) else: img = Image.open(file.name).convert("L") vol = np.array(img) coronal_slices = [np.array(img)] gallery_images = natsorted(load_gallery_images()) idx = atlas_slice_prediction(np.array(img)) slice_idx = idx closest_actual_idx = min(actual_ids, key=lambda x: abs(x - idx)) gallery_index = actual_ids.index(closest_actual_idx) print(gallery_index, len(actual_ids) -(gallery_index)) gallery_selected_data = len(actual_ids) -(gallery_index) return ( img, gr.update(visible=False), gr.update(selected_index=len(actual_ids) -(gallery_index) if gallery_index < len(gallery_ids) else 0, visible = True), # gr.update(value=gallery_images, selected_index=len(actual_ids) -(gallery_index)), # gallery gr.update(visible=True), gr.update(visible=True), gr.update(visible=(sample_type == "Custom Sample")) ) def update_slice(index): if not coronal_slices: return None, None, None slice_img = Image.fromarray((coronal_slices[index] / np.max(coronal_slices[index]) * 255).astype(np.uint8)) gallery_selection = gr.update(selected_index=len(gallery_ids) - index if index < len(gallery_ids) else 0) if last_probabilities: noise = np.random.normal(0, 0.01, size=len(last_probabilities)) new_probs = np.clip(np.array(last_probabilities) + noise, 0, None) new_probs /= new_probs.sum() else: new_probs = [] return slice_img, 
plot_probabilities(new_probs), gallery_selection def load_gallery_images(): folder = "Overlapped_updated" images = [] if os.path.exists(folder): for fname in sorted(os.listdir(folder)): if fname.lower().endswith(('.png', '.jpg', '.jpeg')): images.append(os.path.join(folder, fname)) return images def generate_random_probabilities(): probs = np.random.rand(len(cell_types)) low_indices = np.random.choice(len(probs), size=5, replace=False) for idx in low_indices: probs[idx] = np.random.rand() * 0.01 probs /= probs.sum() return probs.tolist() def plot_probabilities(probabilities): if len(probabilities) < 1: return None prob_df = pd.DataFrame({"labels": cell_types, "values": probabilities}) prob_df.to_csv('Cell_types_predictions.csv', index=False) return prob_df def run_mapping(): global last_probabilities last_probabilities = generate_random_probabilities() return plot_probabilities(last_probabilities), gr.update(visible=True) def run_registration(data_type, selected_idx): global vol, slice_idx print("Running registration logic here..., Vol shape::", vol.shape) if data_type == "3D": gallery_images = run_3D_registration(vol) else: gallery_images = run_2D_registration(vol, slice_idx) return gallery_images return "Registration complete!" 
def download_csv():
    """Return the path of the CSV written by ``plot_probabilities``."""
    return 'Cell_types_predictions.csv'


def handle_data_type_change(dt):
    """Toggle the slice slider when the data-type radio changes.

    For "2D" the slider is hidden. Otherwise it is shown with one step
    per entry of ``actual_ids`` and starts at the middle entry.
    """
    if dt == "2D":
        return gr.update(visible=False)
    return gr.update(
        visible=True,
        minimum=0,
        maximum=len(actual_ids) - 1,
        value=len(actual_ids) // 2,
    )


def on_select(evt: gr.SelectData):
    """Remember which gallery item the user clicked.

    Without the ``global`` declaration this assignment would only create a
    local variable, and the module-level ``gallery_selected_data`` (also
    written by ``load_nifti_or_png``) would never change.
    """
    global gallery_selected_data
    gallery_selected_data = evt.index


def _run_registration_for_selection(dt):
    """Call ``run_registration`` with the data type and current gallery selection.

    ``run_registration`` takes two positional arguments, but the button only
    supplies the data-type radio as an input. The second argument is
    therefore read from the module-level selection at click time.
    """
    return run_registration(dt, gallery_selected_data)


gallery_images = natsorted(load_gallery_images())

with gr.Blocks() as demo:
    gr.Markdown("# Map My Sections")

    gr.Markdown("### Step 1: Upload your sample and choose type")
    with gr.Row():
        nifti_file = gr.File(label="File Upload")
        with gr.Column():
            sample_type = gr.Dropdown(
                choices=["CCF registered Sample", "Custom Sample"],
                value="CCF registered Sample",
                label="Sample Type",
            )
            data_type = gr.Radio(choices=["2D", "3D"], value="3D", label="Data Type")
    gr.Examples(
        examples=example_files,
        inputs=[nifti_file, sample_type, data_type],
        label="Try one of our example samples",
    )

    # Hidden until a file is uploaded (see load_nifti_or_png outputs).
    with gr.Row(visible=False) as slice_row:
        with gr.Column(scale=1):
            gr.Markdown("### Step 2: Visualizing your uploaded sample")
            image_display = gr.Image(height=450)
            slice_slider = gr.Slider(
                minimum=0, maximum=0, value=0, step=1, label="Slices", visible=False
            )
        with gr.Column(scale=1):
            gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
            gallery = gr.Gallery(label="ABC Atlas", value=gallery_images)

    gr.Markdown("**Step 4: Run cell type mapping**")
    with gr.Row():
        run_button = gr.Button("Run Mapping")
        reg_button = gr.Button("Run Registration", visible=False)

    with gr.Column(visible=False) as plot_row:
        gr.Markdown("### Step 5: Quantitative results of the mapping model.")
        prob_plot = gr.BarPlot(
            prob_df,
            x="labels",
            y="values",
            title="Cell Type Probabilities",
            scroll_to_output=True,
            x_label_angle=-90,
            height=400,
        )
        gr.Markdown("### Step 6: Download Results.")
        download_button = gr.DownloadButton(label="Download Results", value=download_csv())

    nifti_file.change(
        load_nifti_or_png,
        inputs=[nifti_file, sample_type, data_type],
        outputs=[image_display, slice_slider, gallery, slice_row, plot_row, reg_button],
    )
    sample_type.change(
        lambda s: (gr.update(visible=True), gr.update(visible=(s == "Custom Sample"))),
        inputs=sample_type,
        outputs=[slice_row, reg_button],
    )
    data_type.change(handle_data_type_change, inputs=data_type, outputs=slice_slider)
    gallery.select(on_select, None, None)
    slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery])
    # run_mapping rewrites the CSV on every run, so re-point the download
    # button at the fresh file instead of the copy captured at build time.
    run_button.click(run_mapping, outputs=[prob_plot, plot_row]).then(
        download_csv, outputs=download_button
    )
    reg_button.click(_run_registration_for_selection, inputs=[data_type], outputs=[gallery])

demo.launch()