ayyuce's picture
Update app.py
fad7d2c verified
raw
history blame
2.86 kB
import gradio as gr
import subprocess
import os
import shutil
import uuid
import zipfile
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
def _save_upload(uploaded_file, destination):
    """Persist the uploaded payload to *destination*.

    The shape of the value handed over by the upload widget depends on the
    Gradio version and the component's ``type`` setting, so three forms are
    accepted:

    * a readable binary file object (what the original code required),
    * raw ``bytes`` / ``bytearray``,
    * a filesystem path, or an object exposing the path as ``.name``.
    """
    if hasattr(uploaded_file, "read"):
        with open(destination, "wb") as out:
            # Stream instead of read(): volumes can be hundreds of MB.
            shutil.copyfileobj(uploaded_file, out)
    elif isinstance(uploaded_file, (bytes, bytearray)):
        with open(destination, "wb") as out:
            out.write(uploaded_file)
    else:
        source = getattr(uploaded_file, "name", uploaded_file)
        shutil.copyfile(source, destination)


def _zip_directory(folder, zip_path):
    """Write every file below *folder* into a deflated archive at *zip_path*.

    Archive member names are relative to *folder*.
    """
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for root, _dirs, files in os.walk(folder):
            for name in files:
                path = os.path.join(root, name)
                archive.write(path, os.path.relpath(path, folder))


def _render_preview(output_folder, job_id):
    """Save a PNG of the middle slice (third axis) of one segmentation file.

    Returns the PNG filename, or ``None`` when there is no ``.nii.gz`` file
    in *output_folder* or the preview could not be rendered. Preview
    generation is best-effort: failures are printed, never raised.
    """
    if not os.path.isdir(output_folder):
        return None

    # Sorted so the previewed file does not depend on directory order.
    seg_files = sorted(
        os.path.join(output_folder, name)
        for name in os.listdir(output_folder)
        if name.endswith('.nii.gz')
    )
    if not seg_files:
        return None

    image_filename = f"segmentation_preview_{job_id}.png"
    try:
        seg_data = nib.load(seg_files[0]).get_fdata()
        slice_idx = seg_data.shape[2] // 2
        seg_slice = seg_data[:, :, slice_idx]
        fig = plt.figure(figsize=(6, 6))
        try:
            plt.imshow(seg_slice.T, cmap="gray", origin="lower")
            plt.axis('off')
            plt.savefig(image_filename, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing/saving failed.
            plt.close(fig)
    except Exception as e:  # deliberate: a preview must never fail the job
        print(f"Error creating preview: {e}")
        return None
    return image_filename


def run_segmentation(uploaded_file, modality):
    """Run TotalSegmentator on an uploaded image and package the results.

    Args:
        uploaded_file: The uploaded image: a readable binary file object,
            raw bytes, or a filesystem path (see ``_save_upload``).
        modality: ``"MR"`` adds ``--task total_mr`` to the command line; any
            other value runs the tool with its default task.

    Returns:
        A 2-tuple. On success ``(zip_filename, image_filename)`` where
        ``image_filename`` is ``None`` if no preview could be produced. On
        failure ``(error_message, None)``.

    The zip archive and preview PNG are left in the working directory for
    the caller; the temporary input copy and the raw output folder are
    always removed, including when segmentation fails.
    """
    # NOTE(review): error messages are returned in the slot that the UI wires
    # to a file-download component; raising gr.Error may display better --
    # confirm against the Gradio version in use before changing the contract.
    if uploaded_file is None:
        return "Error: no input file was uploaded.", None

    job_id = str(uuid.uuid4())
    input_filename = f"input_{job_id}.nii.gz"
    output_folder = f"segmentations_{job_id}"
    zip_filename = f"segmentations_{job_id}.zip"

    command = ["TotalSegmentator", "-i", input_filename, "-o", output_folder]
    if modality == "MR":
        command.extend(["--task", "total_mr"])

    try:
        _save_upload(uploaded_file, input_filename)
        try:
            subprocess.run(command, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing or non-executable TotalSegmentator.
            return f"Error during segmentation: {e}", None

        _zip_directory(output_folder, zip_filename)
        image_filename = _render_preview(output_folder, job_id)
        return zip_filename, image_filename
    finally:
        # Runs on success, on the error return above, and on unexpected
        # exceptions, so per-job scratch data never accumulates.
        if os.path.exists(input_filename):
            os.remove(input_filename)
        shutil.rmtree(output_folder, ignore_errors=True)
# Text shown under the page title; kept as one constant so the layout code
# below stays compact.
_INTRO_TEXT = (
    "Upload a CT or MR image (in NIfTI format) and run segmentation using TotalSegmentator. "
    "For MR images, the task flag is set accordingly. A preview of one segmentation slice will be displayed."
)

# Assemble the page: title and description, an input row, an output row, and
# the button that triggers the segmentation.
with gr.Blocks() as demo:
    gr.Markdown("# TotalSegmentator Gradio App")
    gr.Markdown(_INTRO_TEXT)

    # Inputs: the image to segment and its modality.
    with gr.Row():
        uploaded_file = gr.File(label="Upload NIfTI Image (.nii.gz)")
        modality = gr.Radio(
            label="Select Image Modality",
            choices=["CT", "MR"],
            value="CT",
        )

    # Outputs: downloadable archive and a single-slice preview image.
    with gr.Row():
        zip_output = gr.File(label="Download Segmentation Output (zip)")
        preview_output = gr.Image(label="Segmentation Preview")

    run_btn = gr.Button("Run Segmentation")
    # The two return values of run_segmentation map, in order, onto the
    # archive download and the preview image.
    run_btn.click(
        fn=run_segmentation,
        inputs=[uploaded_file, modality],
        outputs=[zip_output, preview_output],
    )

# Start the web server for the interface defined above.
demo.launch()