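# NOTE(review): the lines below appear to be page chrome copied from a web
# file viewer; they are kept as comments so the module can be parsed.
# MapMySections / app.py
# TibbtechUser's picture
# Upload app.py
# d20a3a3 verified
# raw
# history blame
# 5.31 kB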
import gradio as gr
import nibabel as nib
import numpy as np
import os
from PIL import Image
import pandas as pd
# Rows offered by the "Try one of our example samples" picker. Each inner
# list holds the value for the single file-upload input of that example.
example_files = [
    ["./resampled_green_25.nii.gz"],
    # ["examples/sample2.nii.gz"],
    # ["examples/sample3.nii.gz"]
]
# Mutable module-level state shared by the Gradio callbacks.
# 2-D slices of the most recently loaded volume; replaced by load_nifti.
coronal_slices: list = []
# Probabilities produced by the latest run_mapping call; empty until then.
last_probabilities: list = []
# Empty frame used as the bar plot's initial value when the UI is built.
prob_df: pd.DataFrame = pd.DataFrame()
# Labels for the probability table: one entry per predicted cell type, in the
# order used for the "labels" column built by plot_probabilities.
cell_types = [
    "ABC.NN",
    "Astro.TE.NN",
    "CLA.EPd.CTX.Car3.Glut",
    "Endo.NN",
    "L2.3.IT.CTX.Glut",
    "L4.5.IT.CTX.Glut",
    "L5.ET.CTX.Glut",
    "L5.IT.CTX.Glut",
    "L5.NP.CTX.Glut",
    "L6.CT.CTX.Glut",
    "L6.IT.CTX.Glut",
    "L6b.CTX.Glut",
    "Lamp5.Gaba",
    "Lamp5.Lhx6.Gaba",
    "Lymphoid.NN",
    "Microglia.NN",
    "OPC.NN",
    "Oligo.NN",
    "Peri.NN",
    "Pvalb.Gaba",
    "Pvalb.chandelier.Gaba",
    "SMC.NN",
    "Sncg.Gaba",
    "Sst.Chodl.Gaba",
    "Sst.Gaba",
    "VLMC.NN",
    "Vip.Gaba",
]
# Slice indices that update_slice compares with the slider position to choose
# the nearest gallery entry (52 ascending values).
actual_ids = [
    30, 52, 71, 91, 104, 109, 118, 126, 131, 137,
    141, 164, 178, 182, 197, 208, 218, 226, 232, 242,
    244, 248, 256, 262, 270, 282, 293, 297, 308, 323,
    339, 344, 350, 355, 364, 372, 379, 389, 395, 401,
    410, 415, 418, 424, 429, 434, 440, 444, 469, 479,
    487, 509,
]

# 53 ids, written ascending and stored in descending order. No other code
# visible in this file reads this list.
gallery_ids = [
    5, 6, 8, 9, 10, 11, 12, 13, 14, 15,
    16, 17, 18, 19, 24, 25, 26, 27, 28, 29,
    30, 31, 32, 33, 35, 36, 37, 38, 39, 40,
    42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
    52, 54, 55, 56, 57, 58, 59, 60, 61, 62,
    64, 66, 67,
][::-1]
def load_nifti(file):
    """Load an uploaded NIfTI volume and prepare the viewer for browsing it.

    Args:
        file: Value delivered by the file-upload component. May be ``None``
            (for example when the upload is cleared), a path string, or an
            object that exposes the path as ``.name``.

    Returns:
        A 5-tuple in the order of the outputs wired to the upload event:
        the middle slice as a PIL image, a slider update (visibility, range
        and position), the list of gallery image paths, an update that shows
        the slice row, and an update that hides the results column.
        When ``file`` is ``None`` everything is cleared and hidden instead.
    """
    global coronal_slices

    # The upload event can fire without a file; reset the viewer rather than
    # failing on ``None.name``.
    if file is None:
        coronal_slices = []
        return (
            None,
            gr.update(visible=False),
            None,
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Accept both a plain path string and an object carrying ``.name``.
    path = getattr(file, "name", file)
    vol = nib.load(path).get_fdata()

    # One 2-D array per position along the first axis of the volume.
    coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
    mid_index = vol.shape[0] // 2
    mid_slice = coronal_slices[mid_index]

    # Scale to 0-255 by the slice maximum. A slice with no positive maximum
    # (e.g. all zeros) would otherwise produce NaNs from 0/0, so render it
    # black instead.
    peak = np.max(mid_slice)
    if peak > 0:
        scaled = mid_slice / peak * 255
    else:
        scaled = np.zeros_like(mid_slice)
    slice_img = Image.fromarray(scaled.astype(np.uint8))

    gallery_images = load_gallery_images()
    return (
        slice_img,
        gr.update(visible=True, maximum=len(coronal_slices) - 1, value=mid_index),
        gallery_images,
        gr.update(visible=True),
        gr.update(visible=False),
    )
def update_slice(index):
    """Render the slice at ``index`` and refresh the dependent widgets.

    Args:
        index: Slider position. It is coerced to ``int`` and clamped to the
            range of loaded slices before use.

    Returns:
        A 3-tuple in the order of the outputs wired to the slider event:
        the slice as a PIL image, the probability table returned by
        ``plot_probabilities``, and a gallery update selecting the entry
        whose ``actual_ids`` value is closest to ``index``.
        ``(None, None, None)`` when no volume is loaded or ``index`` is
        ``None``.
    """
    if not coronal_slices or index is None:
        return None, None, None

    # List indexing needs an int, and a stale slider position must not run
    # past the loaded volume.
    index = min(max(int(index), 0), len(coronal_slices) - 1)
    current = coronal_slices[index]

    # Scale to 0-255 by the slice maximum. A slice with no positive maximum
    # (e.g. all zeros) would otherwise produce NaNs from 0/0, so render it
    # black instead.
    peak = np.max(current)
    if peak > 0:
        scaled = current / peak * 255
    else:
        scaled = np.zeros_like(current)
    slice_img = Image.fromarray(scaled.astype(np.uint8))

    # Position (within actual_ids) of the id nearest to the slider value;
    # ties resolve to the earliest entry.
    closest_idx, _ = min(
        enumerate(actual_ids), key=lambda pair: abs(pair[1] - index)
    )
    gallery_selection = gr.update(selected_index=closest_idx)

    if last_probabilities:
        # Jitter the last mapping result slightly, then renormalise so the
        # values still sum to 1.
        noise = np.random.normal(0, 0.01, size=len(last_probabilities))
        new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
        new_probs /= new_probs.sum()
    else:
        # No mapping has been run yet: show freshly generated values.
        new_probs = generate_random_probabilities()

    return slice_img, plot_probabilities(new_probs), gallery_selection
def load_gallery_images():
    """Return paths of the image files in the ``Overlapped_updated`` folder.

    Entries are taken in sorted filename order and kept when their name ends
    with .png, .jpg or .jpeg (case-insensitive). An empty list is returned
    when the folder does not exist.
    """
    folder = "Overlapped_updated"
    if not os.path.exists(folder):
        return []

    image_suffixes = ('.png', '.jpg', '.jpeg')
    return [
        os.path.join(folder, entry)
        for entry in sorted(os.listdir(folder))
        if entry.lower().endswith(image_suffixes)
    ]
def generate_random_probabilities():
    """Return random values, one per entry of ``cell_types``, that sum to 1.

    Every entry starts as a uniform draw from [0, 1). Five distinct positions
    are then overwritten with draws scaled by 0.01, and the vector is divided
    by its sum. The result is returned as a plain list of floats.
    """
    weights = np.random.rand(len(cell_types))

    # Pick five distinct positions and make them near-zero.
    suppressed = np.random.choice(len(weights), size=5, replace=False)
    for position in suppressed:
        weights[position] = np.random.rand() * 0.01

    weights /= weights.sum()
    return weights.tolist()
def plot_probabilities(probabilities):
    """Build the label/value table for the bar plot and save it as CSV.

    Args:
        probabilities: Sequence of values, paired positionally with
            ``cell_types``.

    Returns:
        A DataFrame with columns ``labels`` and ``values``, or ``None`` when
        ``probabilities`` is empty. As a side effect the table is written to
        ``Cell_types_predictions.csv`` without the index column.
    """
    if len(probabilities) == 0:
        return None

    columns = {"labels": cell_types, "values": probabilities}
    table = pd.DataFrame(columns)
    table.to_csv('Cell_types_predictions.csv', index=False)
    return table
def run_mapping():
    """Generate a new set of probabilities and reveal the results section.

    The generated list is stored in the module-level ``last_probabilities``.

    Returns:
        A 2-tuple: the table returned by ``plot_probabilities`` and an update
        that makes the results container visible.
    """
    global last_probabilities

    fresh = generate_random_probabilities()
    last_probabilities = fresh

    table = plot_probabilities(fresh)
    reveal_results = gr.update(visible=True)
    return table, reveal_results
def download_csv():
    """Return the relative path of the predictions CSV.

    This is the same file name that ``plot_probabilities`` writes to; the
    function itself does not create or modify the file.
    """
    results_path = 'Cell_types_predictions.csv'
    return results_path
def _current_results_csv():
    """Return the predictions CSV path if the file exists, otherwise None."""
    csv_path = download_csv()
    return csv_path if os.path.exists(csv_path) else None


# NOTE(review): the source this was recovered from had no indentation, so the
# nesting of the layout containers below is reconstructed. Confirm in
# particular where the "Step 4" text and the run button belong.
with gr.Blocks() as demo:
    gr.Markdown("# Map My Sections")

    gr.Markdown("### Step 1: Upload your CCF registered data")
    nifti_file = gr.File(label="File Upload")
    gr.Examples(
        examples=example_files,
        inputs=nifti_file,
        label="Try one of our example samples",
    )

    # Hidden until load_nifti reports a successfully loaded volume.
    with gr.Row(visible=False) as slice_row:
        with gr.Column(scale=1):
            gr.Markdown("### Step 2: Visualizing your uploaded sample")
            image_display = gr.Image(height=400)
            slice_slider = gr.Slider(
                minimum=0,
                maximum=0,
                value=0,
                step=1,
                label="Browse Slices",
                visible=False,
            )
        with gr.Column(scale=1):
            gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
            gallery = gr.Gallery(label="ABC Atlas", height=400)
            gr.Markdown("**Step 4: Run cell type mapping**")
            run_button = gr.Button("Run Mapping")

    # Hidden until run_mapping makes it visible.
    with gr.Column(visible=False) as plot_row:
        gr.Markdown("### Step 5: Quantitative results of the mapping model.")
        prob_plot = gr.BarPlot(
            prob_df,
            x="labels",
            y="values",
            title="Cell Type Probabilities",
            scroll_to_output=True,
            x_label_angle=-90,
            height=600,
        )
        gr.Markdown("### Step 6: Download Results.")
        # Only hand over the CSV path when the file is already on disk; it is
        # first written by plot_probabilities.
        download_button = gr.DownloadButton(
            label="Download Results",
            value=_current_results_csv(),
        )

    nifti_file.change(
        load_nifti,
        inputs=nifti_file,
        outputs=[image_display, slice_slider, gallery, slice_row, plot_row],
    )

    # Both callbacks below rewrite the CSV through plot_probabilities. The
    # button's file is set only when a value is assigned to it, so assign the
    # path again afterwards to make the download reflect the latest table.
    slice_slider.change(
        update_slice,
        inputs=slice_slider,
        outputs=[image_display, prob_plot, gallery],
    ).then(_current_results_csv, outputs=download_button)
    run_button.click(
        run_mapping,
        outputs=[prob_plot, plot_row],
    ).then(_current_results_csv, outputs=download_button)

demo.launch()