File size: 2,856 Bytes
d0bb937
 
 
 
 
 
fad7d2c
 
 
d0bb937
 
 
 
 
 
fad7d2c
d0bb937
 
fad7d2c
d0bb937
 
 
 
 
 
 
fad7d2c
 
d0bb937
 
 
 
 
 
 
 
fad7d2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0bb937
 
 
fad7d2c
d0bb937
 
 
 
 
fad7d2c
d0bb937
 
 
 
 
 
fad7d2c
 
 
d0bb937
 
fad7d2c
d0bb937
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import subprocess
import os
import shutil
import uuid
import zipfile
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

def _save_upload(uploaded_file, destination):
    """Write the uploaded payload to *destination*.

    Accepts the shapes a Gradio file input may hand over: raw bytes, a
    filesystem path, or a file-like object. For file-like objects the path in
    ``.name`` is preferred when it points at an existing file, because the
    handle itself is not guaranteed to be positioned at (or even backed by)
    the uploaded data; ``.read()`` is the last resort.
    """
    if isinstance(uploaded_file, (bytes, bytearray)):
        with open(destination, "wb") as f:
            f.write(uploaded_file)
        return
    if isinstance(uploaded_file, (str, os.PathLike)):
        shutil.copyfile(uploaded_file, destination)
        return
    source_path = getattr(uploaded_file, "name", None)
    if isinstance(source_path, str) and os.path.isfile(source_path):
        shutil.copyfile(source_path, destination)
        return
    with open(destination, "wb") as f:
        f.write(uploaded_file.read())


def _zip_directory(folder, zip_filename):
    """Pack every file under *folder* into *zip_filename* with folder-relative names."""
    with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
        # os.walk yields nothing for a missing folder, so the archive is
        # simply empty in that case.
        for root, _dirs, files in os.walk(folder):
            for file in files:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, folder))


def _render_preview(output_folder, image_filename):
    """Save a PNG of the middle third-axis slice of one segmentation file.

    Returns *image_filename* on success, or ``None`` when there is nothing to
    preview or rendering fails (the preview is best-effort and must never
    fail the whole job).
    """
    if not os.path.isdir(output_folder):
        return None
    # Sort so the previewed file is deterministic rather than dependent on
    # directory listing order.
    seg_names = sorted(
        name for name in os.listdir(output_folder) if name.endswith(".nii.gz")
    )
    if not seg_names:
        return None

    fig = None
    try:
        seg_data = nib.load(os.path.join(output_folder, seg_names[0])).get_fdata()
        seg_slice = seg_data[:, :, seg_data.shape[2] // 2]
        # Use explicit figure/axes objects instead of pyplot's implicit
        # "current figure" so overlapping requests cannot draw into each
        # other's figure.
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(seg_slice.T, cmap="gray", origin="lower")
        ax.axis("off")
        fig.savefig(image_filename, bbox_inches="tight")
    except Exception as e:
        print(f"Error creating preview: {e}")
        return None
    finally:
        # Always release the figure, even when plotting or saving raised.
        if fig is not None:
            plt.close(fig)
    return image_filename


def run_segmentation(uploaded_file, modality):
    """Run TotalSegmentator on an uploaded NIfTI image and package the results.

    Args:
        uploaded_file: The value supplied by the Gradio file input: a path,
            raw bytes, or a file-like object. ``None`` means nothing was
            uploaded.
        modality: ``"MR"`` adds ``--task total_mr`` to the command line; any
            other value runs the tool with its default task.

    Returns:
        ``(zip_filename, image_filename)`` on success, where ``image_filename``
        is ``None`` if no preview could be produced. On failure returns
        ``(error_message, None)``.

    Side effects:
        Creates ``segmentations_<job_id>.zip`` and, when a preview is
        rendered, ``segmentation_preview_<job_id>.png`` in the current working
        directory. The temporary input copy and the raw output folder are
        removed on every exit path.
    """
    # NOTE(review): on failure the error text is returned in the slot that is
    # wired to the file-download output; raising gr.Error may present better
    # in the UI, but the return shape is kept for backward compatibility.
    if uploaded_file is None:
        return "Error: no input file was uploaded.", None

    job_id = str(uuid.uuid4())
    input_filename = f"input_{job_id}.nii.gz"
    output_folder = f"segmentations_{job_id}"
    zip_filename = f"segmentations_{job_id}.zip"
    image_filename = f"segmentation_preview_{job_id}.png"

    try:
        _save_upload(uploaded_file, input_filename)

        command = ["TotalSegmentator", "-i", input_filename, "-o", output_folder]
        if modality == "MR":
            command.extend(["--task", "total_mr"])

        try:
            subprocess.run(command, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers the executable being missing or not runnable.
            return f"Error during segmentation: {e}", None

        _zip_directory(output_folder, zip_filename)
        preview_filename = _render_preview(output_folder, image_filename)
    finally:
        # Clean up intermediates on success and on every failure path.
        if os.path.exists(input_filename):
            os.remove(input_filename)
        shutil.rmtree(output_folder, ignore_errors=True)

    return zip_filename, preview_filename

# Assemble the web UI. Components appear on the page in the order they are
# created inside the Blocks context, so the creation order below is the layout.
with gr.Blocks() as demo:
    # Page title followed by a short usage description.
    gr.Markdown(value="# TotalSegmentator Gradio App")
    intro_text = (
        "Upload a CT or MR image (in NIfTI format) and run segmentation using TotalSegmentator. "
        "For MR images, the task flag is set accordingly. A preview of one segmentation slice will be displayed."
    )
    gr.Markdown(value=intro_text)

    # Input row: the image to segment and its modality (CT is preselected).
    with gr.Row():
        uploaded_file = gr.File(label="Upload NIfTI Image (.nii.gz)")
        modality = gr.Radio(
            label="Select Image Modality",
            choices=["CT", "MR"],
            value="CT",
        )

    # Output row: the zipped segmentations and a single-slice preview image.
    with gr.Row():
        zip_output = gr.File(label="Download Segmentation Output (zip)")
        preview_output = gr.Image(label="Segmentation Preview")

    # The button feeds both inputs to run_segmentation and routes its two
    # return values to the two output components, in order.
    run_btn = gr.Button(value="Run Segmentation")
    run_btn.click(
        fn=run_segmentation,
        inputs=[uploaded_file, modality],
        outputs=[zip_output, preview_output],
    )

# Start the Gradio server with default settings.
demo.launch()