Taha Razzaq commited on
Commit
cfd23c0
·
1 Parent(s): bf71a90

Add application file

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import nibabel as nib
3
+ import numpy as np
4
+ import os
5
+ from PIL import Image
6
+ import pandas as pd
7
+
8
+ example_files = [
9
+ ["./resampled_green_25.nii.gz"],
10
+ # ["examples/sample2.nii.gz"],
11
+ # ["examples/sample3.nii.gz"]
12
+ ]
13
+
14
+ # Global variables
15
+ coronal_slices = []
16
+ last_probabilities = []
17
+ prob_df = pd.DataFrame()
18
+
19
+ # Target cell types
20
+ cell_types = [
21
+ "ABC.NN", "Astro.TE.NN", "CLA.EPd.CTX.Car3.Glut", "Endo.NN", "L2.3.IT.CTX.Glut",
22
+ "L4.5.IT.CTX.Glut", "L5.ET.CTX.Glut", "L5.IT.CTX.Glut", "L5.NP.CTX.Glut", "L6.CT.CTX.Glut",
23
+ "L6.IT.CTX.Glut", "L6b.CTX.Glut", "Lamp5.Gaba", "Lamp5.Lhx6.Gaba", "Lymphoid.NN", "Microglia.NN",
24
+ "OPC.NN", "Oligo.NN", "Peri.NN", "Pvalb.Gaba", "Pvalb.chandelier.Gaba", "SMC.NN", "Sncg.Gaba",
25
+ "Sst.Chodl.Gaba", "Sst.Gaba", "VLMC.NN", "Vip.Gaba"
26
+ ]
27
+
28
+ actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509]
29
+ gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67]
30
+
31
+ def load_nifti(file):
32
+ global coronal_slices
33
+ img = nib.load(file.name)
34
+ vol = img.get_fdata()
35
+ coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
36
+ mid_index = vol.shape[0] // 2
37
+ slice_img = Image.fromarray((coronal_slices[mid_index] / np.max(coronal_slices[mid_index]) * 255).astype(np.uint8))
38
+ gallery_images = load_gallery_images()
39
+ return slice_img, gr.update(visible=True, maximum=len(coronal_slices)-1, value=mid_index), gallery_images, gr.update(visible=True), gr.update(visible=False)
40
+
41
+ def update_slice(index):
42
+ if not coronal_slices:
43
+ return None, None, None
44
+ slice_img = Image.fromarray((coronal_slices[index] / np.max(coronal_slices[index]) * 255).astype(np.uint8))
45
+
46
+ # Find closest gallery index
47
+ closest_idx = min(range(len(actual_ids)), key=lambda i: abs(actual_ids[i] - index))
48
+ gallery_selection = gr.update(selected_index=closest_idx)
49
+
50
+ # Slight variation to probabilities
51
+ if last_probabilities:
52
+ noise = np.random.normal(0, 0.01, size=len(last_probabilities))
53
+ new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
54
+ new_probs /= new_probs.sum()
55
+ else:
56
+ new_probs = generate_random_probabilities()
57
+
58
+ return slice_img, plot_probabilities(new_probs), gallery_selection
59
+
60
+ def load_gallery_images():
61
+ images = []
62
+ folder = "Overlapped_updated"
63
+ if os.path.exists(folder):
64
+ for fname in sorted(os.listdir(folder)):
65
+ if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
66
+ images.append(os.path.join(folder, fname))
67
+ return images
68
+
69
+ def generate_random_probabilities():
70
+ probs = np.random.rand(len(cell_types))
71
+ low_indices = np.random.choice(len(probs), size=5, replace=False)
72
+ for idx in low_indices:
73
+ probs[idx] = np.random.rand() * 0.01
74
+ probs /= probs.sum()
75
+ return probs.tolist()
76
+
77
+ def plot_probabilities(probabilities):
78
+ if len(probabilities) < 1:
79
+ return None
80
+ prob_df = pd.DataFrame({"labels": cell_types, "values": probabilities})
81
+ prob_df.to_csv('Cell_types_predictions.csv', index=False)
82
+ return prob_df
83
+
84
+ def run_mapping():
85
+ global last_probabilities
86
+ last_probabilities = generate_random_probabilities()
87
+ return plot_probabilities(last_probabilities), gr.update(visible=True)
88
+
89
+ def download_csv():
90
+ # prob_df.to_csv('Cell_types_predictions.csv', index=False)
91
+ return 'Cell_types_predictions.csv'
92
+
93
+
94
+ with gr.Blocks() as demo:
95
+ gr.Markdown("# Map My Sections")
96
+
97
+ gr.Markdown("### Step 1: Upload your CCF registered data")
98
+ nifti_file = gr.File(label="File Upload")
99
+ gr.Examples(
100
+ examples=example_files,
101
+ inputs=nifti_file,
102
+ label="Try one of our example samples"
103
+ )
104
+
105
+ with gr.Row(visible=False) as slice_row:
106
+ with gr.Column(scale=2):
107
+ gr.Markdown("### Step 2: Visualizing your uploaded sample")
108
+ image_display = gr.Image()
109
+ slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Browse Slices", visible=False)
110
+ with gr.Column(scale=1):
111
+ gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
112
+ gallery = gr.Gallery(label="ABC Atlas")
113
+ gr.Markdown("**Step 4: Run cell type mapping**")
114
+ run_button = gr.Button("Run Mapping")
115
+
116
+ with gr.Column(visible=False) as plot_row:
117
+ gr.Markdown("### Step 5: Quantitative results of the mapping model.")
118
+ prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90)
119
+ gr.Markdown("### Step 6: Download Results.")
120
+ download_button = gr.DownloadButton(label="Download Results", value='./Cell_types_predictions.csv')
121
+
122
+ nifti_file.change(load_nifti, inputs=nifti_file, outputs=[image_display, slice_slider, gallery, slice_row, plot_row])
123
+ slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery])
124
+ run_button.click(run_mapping, outputs=[prob_plot, plot_row])
125
+
126
+ demo.launch()