Gradio app fixed
Browse files- Segformer_best_state_dict.ckpt +0 -3
- app.py +11 -26
- requirements.txt +1 -2
Segformer_best_state_dict.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:800bb5ba3fff6c5539542cc6d9548da73dbc1a35c0dc686f0bade3b3c6c5746c
|
| 3 |
-
size 256373829
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -18,7 +18,7 @@ class Configs:
|
|
| 18 |
IMAGE_SIZE: tuple[int, int] = (288, 288) # W, H
|
| 19 |
MEAN: tuple = (0.485, 0.456, 0.406)
|
| 20 |
STD: tuple = (0.229, 0.224, 0.225)
|
| 21 |
-
MODEL_PATH: str =
|
| 22 |
|
| 23 |
|
| 24 |
def get_model(*, model_path, num_classes):
|
|
@@ -47,11 +47,8 @@ if __name__ == "__main__":
|
|
| 47 |
class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
|
| 48 |
|
| 49 |
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
| 50 |
-
CKPT_PATH = os.path.join(os.getcwd(), "Segformer_best_state_dict.ckpt")
|
| 51 |
|
| 52 |
model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
|
| 53 |
-
model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
|
| 54 |
-
|
| 55 |
model.to(DEVICE)
|
| 56 |
model.eval()
|
| 57 |
_ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
|
|
@@ -64,29 +61,17 @@ if __name__ == "__main__":
|
|
| 64 |
]
|
| 65 |
)
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
outputs=gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor),
|
| 73 |
-
examples=examples,
|
| 74 |
-
cache_examples=False,
|
| 75 |
-
allow_flagging="never",
|
| 76 |
-
title="Medical Image Segmentation with UW-Madison GI Tract Dataset",
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
# with gr.Blocks(title="Medical Image Segmentation") as demo:
|
| 80 |
-
# gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
|
| 81 |
-
# with gr.Row():
|
| 82 |
-
# img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
|
| 83 |
-
# img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
|
| 92 |
demo.launch()
|
|
|
|
| 18 |
IMAGE_SIZE: tuple[int, int] = (288, 288) # W, H
|
| 19 |
MEAN: tuple = (0.485, 0.456, 0.406)
|
| 20 |
STD: tuple = (0.229, 0.224, 0.225)
|
| 21 |
+
MODEL_PATH: str = os.path.join(os.getcwd(), "segformer_trained_weights")
|
| 22 |
|
| 23 |
|
| 24 |
def get_model(*, model_path, num_classes):
|
|
|
|
| 47 |
class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
|
| 48 |
|
| 49 |
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
| 50 |
|
| 51 |
model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
|
|
|
|
|
|
|
| 52 |
model.to(DEVICE)
|
| 53 |
model.eval()
|
| 54 |
_ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
|
|
|
|
| 61 |
]
|
| 62 |
)
|
| 63 |
|
| 64 |
+
with gr.Blocks(title="Medical Image Segmentation") as demo:
|
| 65 |
+
gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
|
| 66 |
+
with gr.Row():
|
| 67 |
+
img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
|
| 68 |
+
img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
section_btn = gr.Button("Generate Predictions")
|
| 71 |
+
section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
|
| 72 |
|
| 73 |
+
images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
|
| 74 |
+
examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
|
| 75 |
+
gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
|
| 76 |
|
| 77 |
demo.launch()
|
requirements.txt
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
--find-links https://download.pytorch.org/whl/torch_stable.html
|
| 2 |
torch==2.0.0+cpu
|
| 3 |
torchvision==0.15.0
|
| 4 |
-
transformers==4.30.2
|
| 5 |
-
gradio
|
|
|
|
| 1 |
--find-links https://download.pytorch.org/whl/torch_stable.html
|
| 2 |
torch==2.0.0+cpu
|
| 3 |
torchvision==0.15.0
|
| 4 |
+
transformers==4.30.2
|
|
|