Spaces:
Running
Running
Taha Razzaq
commited on
Commit
·
cfd23c0
1
Parent(s):
bf71a90
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import nibabel as nib
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from PIL import Image
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
example_files = [
|
9 |
+
["./resampled_green_25.nii.gz"],
|
10 |
+
# ["examples/sample2.nii.gz"],
|
11 |
+
# ["examples/sample3.nii.gz"]
|
12 |
+
]
|
13 |
+
|
14 |
+
# Global variables
|
15 |
+
coronal_slices = []
|
16 |
+
last_probabilities = []
|
17 |
+
prob_df = pd.DataFrame()
|
18 |
+
|
19 |
+
# Target cell types
|
20 |
+
cell_types = [
|
21 |
+
"ABC.NN", "Astro.TE.NN", "CLA.EPd.CTX.Car3.Glut", "Endo.NN", "L2.3.IT.CTX.Glut",
|
22 |
+
"L4.5.IT.CTX.Glut", "L5.ET.CTX.Glut", "L5.IT.CTX.Glut", "L5.NP.CTX.Glut", "L6.CT.CTX.Glut",
|
23 |
+
"L6.IT.CTX.Glut", "L6b.CTX.Glut", "Lamp5.Gaba", "Lamp5.Lhx6.Gaba", "Lymphoid.NN", "Microglia.NN",
|
24 |
+
"OPC.NN", "Oligo.NN", "Peri.NN", "Pvalb.Gaba", "Pvalb.chandelier.Gaba", "SMC.NN", "Sncg.Gaba",
|
25 |
+
"Sst.Chodl.Gaba", "Sst.Gaba", "VLMC.NN", "Vip.Gaba"
|
26 |
+
]
|
27 |
+
|
28 |
+
actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509]
|
29 |
+
gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67]
|
30 |
+
|
31 |
+
def load_nifti(file):
|
32 |
+
global coronal_slices
|
33 |
+
img = nib.load(file.name)
|
34 |
+
vol = img.get_fdata()
|
35 |
+
coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
|
36 |
+
mid_index = vol.shape[0] // 2
|
37 |
+
slice_img = Image.fromarray((coronal_slices[mid_index] / np.max(coronal_slices[mid_index]) * 255).astype(np.uint8))
|
38 |
+
gallery_images = load_gallery_images()
|
39 |
+
return slice_img, gr.update(visible=True, maximum=len(coronal_slices)-1, value=mid_index), gallery_images, gr.update(visible=True), gr.update(visible=False)
|
40 |
+
|
41 |
+
def update_slice(index):
|
42 |
+
if not coronal_slices:
|
43 |
+
return None, None, None
|
44 |
+
slice_img = Image.fromarray((coronal_slices[index] / np.max(coronal_slices[index]) * 255).astype(np.uint8))
|
45 |
+
|
46 |
+
# Find closest gallery index
|
47 |
+
closest_idx = min(range(len(actual_ids)), key=lambda i: abs(actual_ids[i] - index))
|
48 |
+
gallery_selection = gr.update(selected_index=closest_idx)
|
49 |
+
|
50 |
+
# Slight variation to probabilities
|
51 |
+
if last_probabilities:
|
52 |
+
noise = np.random.normal(0, 0.01, size=len(last_probabilities))
|
53 |
+
new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
|
54 |
+
new_probs /= new_probs.sum()
|
55 |
+
else:
|
56 |
+
new_probs = generate_random_probabilities()
|
57 |
+
|
58 |
+
return slice_img, plot_probabilities(new_probs), gallery_selection
|
59 |
+
|
60 |
+
def load_gallery_images():
|
61 |
+
images = []
|
62 |
+
folder = "Overlapped_updated"
|
63 |
+
if os.path.exists(folder):
|
64 |
+
for fname in sorted(os.listdir(folder)):
|
65 |
+
if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
|
66 |
+
images.append(os.path.join(folder, fname))
|
67 |
+
return images
|
68 |
+
|
69 |
+
def generate_random_probabilities():
|
70 |
+
probs = np.random.rand(len(cell_types))
|
71 |
+
low_indices = np.random.choice(len(probs), size=5, replace=False)
|
72 |
+
for idx in low_indices:
|
73 |
+
probs[idx] = np.random.rand() * 0.01
|
74 |
+
probs /= probs.sum()
|
75 |
+
return probs.tolist()
|
76 |
+
|
77 |
+
def plot_probabilities(probabilities):
|
78 |
+
if len(probabilities) < 1:
|
79 |
+
return None
|
80 |
+
prob_df = pd.DataFrame({"labels": cell_types, "values": probabilities})
|
81 |
+
prob_df.to_csv('Cell_types_predictions.csv', index=False)
|
82 |
+
return prob_df
|
83 |
+
|
84 |
+
def run_mapping():
|
85 |
+
global last_probabilities
|
86 |
+
last_probabilities = generate_random_probabilities()
|
87 |
+
return plot_probabilities(last_probabilities), gr.update(visible=True)
|
88 |
+
|
89 |
+
def download_csv():
|
90 |
+
# prob_df.to_csv('Cell_types_predictions.csv', index=False)
|
91 |
+
return 'Cell_types_predictions.csv'
|
92 |
+
|
93 |
+
|
94 |
+
with gr.Blocks() as demo:
|
95 |
+
gr.Markdown("# Map My Sections")
|
96 |
+
|
97 |
+
gr.Markdown("### Step 1: Upload your CCF registered data")
|
98 |
+
nifti_file = gr.File(label="File Upload")
|
99 |
+
gr.Examples(
|
100 |
+
examples=example_files,
|
101 |
+
inputs=nifti_file,
|
102 |
+
label="Try one of our example samples"
|
103 |
+
)
|
104 |
+
|
105 |
+
with gr.Row(visible=False) as slice_row:
|
106 |
+
with gr.Column(scale=2):
|
107 |
+
gr.Markdown("### Step 2: Visualizing your uploaded sample")
|
108 |
+
image_display = gr.Image()
|
109 |
+
slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Browse Slices", visible=False)
|
110 |
+
with gr.Column(scale=1):
|
111 |
+
gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
|
112 |
+
gallery = gr.Gallery(label="ABC Atlas")
|
113 |
+
gr.Markdown("**Step 4: Run cell type mapping**")
|
114 |
+
run_button = gr.Button("Run Mapping")
|
115 |
+
|
116 |
+
with gr.Column(visible=False) as plot_row:
|
117 |
+
gr.Markdown("### Step 5: Quantitative results of the mapping model.")
|
118 |
+
prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90)
|
119 |
+
gr.Markdown("### Step 6: Download Results.")
|
120 |
+
download_button = gr.DownloadButton(label="Download Results", value='./Cell_types_predictions.csv')
|
121 |
+
|
122 |
+
nifti_file.change(load_nifti, inputs=nifti_file, outputs=[image_display, slice_slider, gallery, slice_row, plot_row])
|
123 |
+
slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery])
|
124 |
+
run_button.click(run_mapping, outputs=[prob_plot, plot_row])
|
125 |
+
|
126 |
+
demo.launch()
|