# Spaces: Runtime error
# File size: 1,987 Bytes
# 50e082d
import gradio as gr
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import apply_mask, to_tensor, center_crop
from pytorch_msssim import ssim
# st.title('FastMRI Kspace Reconstruction Masks')
# st.write('This app allows you to visualize the masks and their effects on the kspace data.')
def _load_kspace_slice(path: str):
    """Read a single 2-D complex k-space plane from an HDF5 file."""
    # Imported lazily: fastMRI's own data loaders depend on h5py, so it is
    # available wherever the fastmri imports at the top of this file succeed.
    import h5py

    with h5py.File(path, "r") as hf:
        # Assumes the standard fastMRI layout with a "kspace" dataset whose last
        # two axes are (rows, cols) -- TODO confirm for the prostate files.
        volume = hf["kspace"]
        # Pick the middle index of every leading axis (slice, coil, ...) so
        # that only one plane is read from disk instead of the whole volume.
        centre = tuple(dim // 2 for dim in volume.shape[:-2])
        return volume[centre]


def _kspace_to_image(kspace):
    """Turn a (rows, cols, 2) real/imag tensor into a float image in [0, 1]."""
    magnitude = kspace.pow(2).sum(dim=-1).sqrt()
    # Log scale so the bright k-space centre does not hide everything else;
    # the small offset keeps fully masked (zero) samples finite.
    image = (magnitude + 1e-9).log()
    low, high = image.min(), image.max()
    span = high - low
    if span > 0:
        image = (image - low) / span
    else:
        # Constant input: avoid dividing by zero and return a black image.
        image = image - low
    return image.numpy()


def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: int,
    seed: int,
    input_image: str,
):
    """Undersample one k-space slice with a fastMRI mask (Gradio callback).

    Args:
        mask_name: Mask type passed to ``create_mask_for_mask_type``
            (the UI offers ``"random"`` and ``"equispaced"``).
        mask_center_fractions: Fraction of low-frequency columns to retain.
        accelerations: Undersampling factor; cast to ``int`` because the
            Gradio number input delivers floats.
        seed: Seed for the mask generator; cast to ``int`` for the same
            reason. ``None`` leaves the mask unseeded.
        input_image: Key selecting one of the bundled example files.

    Returns:
        ``(masked_kspace, mask)``: two 2-D float arrays scaled for display.
        The first is the normalised log-magnitude of the masked k-space,
        the second is the sampling mask broadcast to the same shape.

    Raises:
        KeyError: If ``input_image`` is not one of the known example names.
    """
    file_dict = {
        "knee 1": "knee_singlecoil_train/file1000002.h5",
        "knee 2": "knee_singlecoil_train/file1000003.h5",
        "brain 1": "brain_axial_train/file1000002.h5",
        "prostate 1": "prostate_t1_tse_train/file1000002.h5",
        "prostate 2": "prostate_t2_tse_train/file1000002.h5",
    }
    input_file = file_dict[input_image]

    mask_func = create_mask_for_mask_type(
        mask_name,
        center_fractions=[float(mask_center_fractions)],
        accelerations=[int(accelerations)],
    )

    # to_tensor stacks real/imag parts on a trailing axis: (rows, cols, 2).
    kspace = to_tensor(_load_kspace_slice(input_file))

    # apply_mask returns (masked_data, mask) in older fastMRI releases and
    # (masked_data, mask, num_low_frequencies) in newer ones; keep the first two.
    masked_kspace, mask = apply_mask(
        kspace,
        mask_func,
        seed=None if seed is None else int(seed),
    )[:2]

    # The mask only varies along the column axis; broadcast it over the rows
    # so it can be shown as an image next to the k-space.
    mask_image = mask[..., 0].expand(kspace.shape[:-1]).contiguous().numpy()

    return _kspace_to_image(masked_kspace), mask_image
# Gradio UI. Inputs map positionally onto main_func's parameters:
# mask_name, mask_center_fractions, accelerations, seed, input_image.
demo = gr.Interface(
    fn=main_func,
    inputs=[
        gr.inputs.Radio(['random', 'equispaced'], label="Mask Type"),
        gr.inputs.Slider(minimum=0.04, maximum=0.4, default=0.08, label="Center Fraction"),
        gr.inputs.Number(default=4, label="Acceleration"),
        gr.inputs.Number(default=0, label="Seed"),
        gr.inputs.Radio(["knee 1", "knee 2", "brain 1", "prostate 1", "prostate 2"], label="Input Image"),
    ],
    # main_func returns (masked_kspace, mask), so exactly two components are
    # declared, in that order. `type` must name an image format Gradio
    # understands; "numpy" means the callback hands back an array.
    # TODO: add "Reconstructed Image", "Original Image" and a metrics
    # Dataframe output once main_func returns values for them.
    outputs=[
        gr.outputs.Image(type="numpy", label="Masked Kspace"),
        gr.outputs.Image(type="numpy", label="Mask"),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description="This app allows you to visualize the masks and their effects on the kspace data.",
)
demo.launch()