File size: 5,305 Bytes
cfd23c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d20a3a3
cfd23c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea7c5b4
cfd23c0
ea7c5b4
cfd23c0
 
 
ea7c5b4
cfd23c0
 
 
 
 
ea7c5b4
cfd23c0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import nibabel as nib
import numpy as np
import os
from PIL import Image
import pandas as pd

# Example inputs for gr.Examples below: each inner list is one example row
# supplying the single `nifti_file` input.
example_files = [
    ["./resampled_green_25.nii.gz"],
#     ["examples/sample2.nii.gz"],
#     ["examples/sample3.nii.gz"]
]

# Global variables
# Module-level state shared by the Gradio callbacks:
#   coronal_slices     - slices vol[i, :, :] of the last loaded volume (set in load_nifti).
#   last_probabilities - probability list from the last "Run Mapping" click (set in run_mapping).
#   prob_df            - initial (empty) data for the BarPlot; plot_probabilities builds its
#                        own local DataFrame and does not update this global.
coronal_slices = []
last_probabilities = []
prob_df = pd.DataFrame()

# Target cell types
# Bar labels for the probability plot / the "labels" column of the CSV; one
# probability is generated per entry.
cell_types = [
    "ABC.NN", "Astro.TE.NN", "CLA.EPd.CTX.Car3.Glut", "Endo.NN", "L2.3.IT.CTX.Glut",
    "L4.5.IT.CTX.Glut", "L5.ET.CTX.Glut", "L5.IT.CTX.Glut", "L5.NP.CTX.Glut", "L6.CT.CTX.Glut",
    "L6.IT.CTX.Glut", "L6b.CTX.Glut", "Lamp5.Gaba", "Lamp5.Lhx6.Gaba", "Lymphoid.NN", "Microglia.NN",
    "OPC.NN", "Oligo.NN", "Peri.NN", "Pvalb.Gaba", "Pvalb.chandelier.Gaba", "SMC.NN", "Sncg.Gaba",
    "Sst.Chodl.Gaba", "Sst.Gaba", "VLMC.NN", "Vip.Gaba"
]

# Slice indices compared against the slider position in update_slice; the
# position of the nearest id is used as the gallery's selected_index.
# NOTE(review): presumably ordered to match the sorted gallery images — confirm.
actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509]
# NOTE(review): not referenced in the visible code, and it holds 53 entries vs
# 52 in actual_ids — confirm whether the two lists are meant to pair up.
gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67]
gallery_ids.reverse()  # in-place: gallery_ids is now in descending order
def load_nifti(file):
    """Load an uploaded NIfTI volume and display its middle slice.

    The volume is split along its first axis into the module-level
    ``coronal_slices`` list so that ``update_slice`` can browse it later.

    Args:
        file: Value delivered by the ``gr.File`` component: either a path
            string or an object exposing the path as ``.name``. ``None`` when
            the upload is cleared.

    Returns:
        A 5-tuple for ``[image_display, slice_slider, gallery, slice_row,
        plot_row]``: the middle slice as a PIL image, a slider update sized to
        the slice count, the gallery image paths, an update showing the slice
        row, and an update hiding the results column. When *file* is ``None``
        the viewer is hidden and emptied instead.
    """
    global coronal_slices
    if file is None:
        # Clearing the upload also fires `.change`; reset instead of crashing.
        coronal_slices = []
        return (
            None,
            gr.update(visible=False),
            [],
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Accept a plain path as well as an object carrying the path in `.name`.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    vol = nib.load(path).get_fdata()
    coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
    mid_index = vol.shape[0] // 2

    # Scale to 0-255 by the slice maximum. Guard an all-zero or non-finite
    # slice (division by zero -> NaN pixels) and clip negative voxels, which
    # would otherwise wrap around in the uint8 cast.
    data = coronal_slices[mid_index]
    peak = np.max(data)
    if np.isfinite(peak) and peak > 0:
        pixels = np.clip(data / peak * 255, 0, 255).astype(np.uint8)
    else:
        pixels = np.zeros(data.shape, dtype=np.uint8)
    slice_img = Image.fromarray(pixels)

    gallery_images = load_gallery_images()
    return (
        slice_img,
        gr.update(visible=True, maximum=len(coronal_slices) - 1, value=mid_index),
        gallery_images,
        gr.update(visible=True),
        gr.update(visible=False),
    )

def update_slice(index):
    """Render slice *index* and refresh the linked gallery and probability table.

    Args:
        index: Slider position. It is converted to ``int`` and clamped to the
            range of loaded slices before indexing.

    Returns:
        ``(image, table, gallery_update)`` for ``[image_display, prob_plot,
        gallery]``, or ``(None, None, None)`` when no volume is loaded. The
        table comes from ``plot_probabilities``, which also rewrites the
        predictions CSV.
    """
    if not coronal_slices:
        return None, None, None

    # A slider may deliver a float (e.g. 64.0), which cannot index a list;
    # clamping guards against a value left over from a larger volume.
    index = min(max(int(index), 0), len(coronal_slices) - 1)

    # Scale to 0-255 by the slice maximum. Guard an all-zero or non-finite
    # slice (division by zero -> NaN pixels) and clip negative voxels, which
    # would otherwise wrap around in the uint8 cast.
    data = coronal_slices[index]
    peak = np.max(data)
    if np.isfinite(peak) and peak > 0:
        pixels = np.clip(data / peak * 255, 0, 255).astype(np.uint8)
    else:
        pixels = np.zeros(data.shape, dtype=np.uint8)
    slice_img = Image.fromarray(pixels)

    # Select the gallery entry whose id in `actual_ids` is nearest this slice.
    closest_idx = min(range(len(actual_ids)), key=lambda i: abs(actual_ids[i] - index))
    gallery_selection = gr.update(selected_index=closest_idx)

    # Jitter the last mapping result slightly so the plot reacts to browsing;
    # without a prior run, draw a fresh random distribution instead.
    if last_probabilities:
        noise = np.random.normal(0, 0.01, size=len(last_probabilities))
        new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
        new_probs /= new_probs.sum()
    else:
        new_probs = generate_random_probabilities()

    return slice_img, plot_probabilities(new_probs), gallery_selection

def load_gallery_images(folder="Overlapped_updated"):
    """Return the image files found directly in *folder*, sorted by file name.

    Args:
        folder: Directory to scan. Defaults to ``"Overlapped_updated"``.

    Returns:
        ``os.path.join(folder, name)`` for every entry whose name ends in
        ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive), in lexicographic
        order of name. Empty when *folder* is missing or is not a directory.
    """
    # isdir rather than exists: a plain file with this name would otherwise
    # make os.listdir raise NotADirectoryError.
    if not os.path.isdir(folder):
        return []
    # NOTE(review): the sort is lexicographic ("10.png" < "5.png"); confirm it
    # matches the ordering assumed by `actual_ids` in update_slice.
    return [
        os.path.join(folder, fname)
        for fname in sorted(os.listdir(folder))
        if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]

def generate_random_probabilities():
    """Draw a random probability vector with one entry per cell type.

    Five distinct positions are chosen at random and their weight is replaced
    by a value below 0.01 (before normalisation), so a few bars come out
    near zero.

    Returns:
        A list of ``len(cell_types)`` floats summing to 1 (up to rounding).
    """
    weights = np.random.rand(len(cell_types))
    suppressed = np.random.choice(len(weights), size=5, replace=False)
    for position in suppressed:
        weights[position] = 0.01 * np.random.rand()
    return (weights / weights.sum()).tolist()

def plot_probabilities(probabilities):
    """Build the bar-plot table for *probabilities* and save it as CSV.

    Args:
        probabilities: Sequence of values, one per entry of ``cell_types``.

    Returns:
        A DataFrame with ``labels`` (cell type names) and ``values`` columns.
        The same table is written to ``Cell_types_predictions.csv`` without an
        index. Returns ``None``, and writes nothing, when *probabilities* is
        empty.
    """
    if len(probabilities) > 0:
        # Local table only; the module-level `prob_df` is left untouched.
        table = pd.DataFrame({"labels": cell_types, "values": probabilities})
        table.to_csv('Cell_types_predictions.csv', index=False)
        return table
    return None

def run_mapping():
    """Handle "Run Mapping": draw new probabilities and reveal the results.

    Stores the draw in the module-level ``last_probabilities`` so that
    ``update_slice`` can perturb it later.

    Returns:
        ``(table, update)`` for ``[prob_plot, plot_row]``. The table comes from
        ``plot_probabilities``, and the update makes the results column visible.
    """
    global last_probabilities
    last_probabilities = generate_random_probabilities()
    table = plot_probabilities(last_probabilities)
    return table, gr.update(visible=True)

def download_csv():
    """Return the path of the predictions CSV that plot_probabilities writes."""
    csv_path = 'Cell_types_predictions.csv'
    return csv_path


with gr.Blocks() as demo:
    gr.Markdown("# Map My Sections")

    gr.Markdown("### Step 1: Upload your CCF registered data")
    nifti_file = gr.File(label="File Upload")
    gr.Examples(
        examples=example_files,
        inputs=nifti_file,
        label="Try one of our example samples"
    )

    with gr.Row(visible=False) as slice_row:
        with gr.Column(scale=1):
            gr.Markdown("### Step 2: Visualizing your uploaded sample")
            image_display = gr.Image(height = 400)
            slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Browse Slices", visible=False)
        with gr.Column(scale=1):
            gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
            gallery = gr.Gallery(label="ABC Atlas", height = 400)
            gr.Markdown("**Step 4: Run cell type mapping**")
            run_button = gr.Button("Run Mapping")

    with gr.Column(visible=False) as plot_row:
        gr.Markdown("### Step 5: Quantitative results of the mapping model.")
        prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90, height = 600)
        gr.Markdown("### Step 6: Download Results.")
        # Only preload the CSV if it exists yet. On a fresh start it is created
        # by the first mapping run and wired in by the `.then` calls below.
        results_csv = './Cell_types_predictions.csv'
        download_button = gr.DownloadButton(
            label="Download Results",
            value=results_csv if os.path.exists(results_csv) else None,
        )

    nifti_file.change(load_nifti, inputs=nifti_file, outputs=[image_display, slice_slider, gallery, slice_row, plot_row])
    # Both handlers rewrite the CSV via plot_probabilities. Re-point the
    # download button afterwards so it offers the fresh file.
    # NOTE(review): this assumes Gradio serves a copy of the file taken when
    # the component value is set; confirm for the pinned Gradio version.
    slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery]).then(
        download_csv, outputs=download_button
    )
    run_button.click(run_mapping, outputs=[prob_plot, plot_row]).then(
        download_csv, outputs=download_button
    )

# Launch once the Blocks context is closed, and only when run as a script, so
# the module can also be imported (e.g. by Gradio's reload mode) without
# starting a server.
if __name__ == "__main__":
    demo.launch()