# sahi-yolox / app.py
# Last commit: 14787cc by fcakyon — "randomly direct sahi related urls during spinner"
# (captured from the Hugging Face Spaces file view: raw | history | blame | 9.69 kB)
import random

import numpy
import sahi.model
import sahi.predict
import sahi.utils.cv
import sahi.utils.file
import sahi.utils.mmdet
import streamlit as st
from PIL import Image
# Checkpoint URLs on the OpenMMLab download server, one per selectable model;
# get_mmdet_model() downloads the matching one when a model is requested.
MMDET_YOLACT_MODEL_URL = (
    "https://download.openmmlab.com/mmdetection/v2.0/"
    "yolact/yolact_r50_1x8_coco/"
    "yolact_r50_1x8_coco_20200908-f38d58df.pth"
)
MMDET_YOLOX_MODEL_URL = (
    "https://download.openmmlab.com/mmdetection/v2.0/"
    "yolox/yolox_tiny_8x8_300e_coco/"
    "yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth"
)
MMDET_FASTERRCNN_MODEL_URL = (
    "https://download.openmmlab.com/mmdetection/v2.0/"
    "faster_rcnn/faster_rcnn_r50_fpn_2x_coco/"
    "faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth"
)
# Example images: each (url, local file name) pair is fetched into the working
# directory under the name the example-image slider opens. This runs on every
# script execution; whether an existing file is skipped is up to sahi's
# download_from_url (NOTE(review): confirm it checks for an existing file).
for _example_url, _example_file in (
    (
        "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
        "apple_tree.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
        "highway.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
        "highway2.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
        "highway3.jpg",
    ),
):
    sahi.utils.file.download_from_url(_example_url, _example_file)
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_mmdet_model(
    model_name: str, confidence_threshold: float = 0.4, device: str = "cpu"
):
    """Download weights/config if needed and build a SAHI MMDetection model.

    Wrapped in ``st.cache`` (output mutation allowed, spinner hidden) so that
    Streamlit reruns with the same arguments reuse the built model.

    Args:
        model_name: One of ``"yolact"``, ``"yolox"`` or ``"faster_rcnn"``.
        confidence_threshold: Forwarded to ``sahi.model.MmdetDetectionModel``.
        device: Forwarded to ``sahi.model.MmdetDetectionModel``.

    Returns:
        A ``sahi.model.MmdetDetectionModel``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    # model_name -> (checkpoint URL, MMDetection config file name). The key is
    # also used as the mmdet config directory name and the local weights prefix.
    model_specs = {
        "yolact": (MMDET_YOLACT_MODEL_URL, "yolact_r50_1x8_coco.py"),
        "yolox": (MMDET_YOLOX_MODEL_URL, "yolox_tiny_8x8_300e_coco.py"),
        "faster_rcnn": (
            MMDET_FASTERRCNN_MODEL_URL,
            "faster_rcnn_r50_fpn_2x_coco.py",
        ),
    }
    try:
        model_url, config_file_name = model_specs[model_name]
    except KeyError:
        # Without this, an unknown name would surface as an UnboundLocalError.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            f"expected one of {sorted(model_specs)}"
        ) from None

    model_path = f"{model_name}.pt"
    sahi.utils.file.download_from_url(model_url, model_path)
    config_path = sahi.utils.mmdet.download_mmdet_config(
        model_name=model_name, config_file_name=config_file_name
    )
    return sahi.model.MmdetDetectionModel(
        model_path=model_path,
        config_path=config_path,
        confidence_threshold=confidence_threshold,
        device=device,
    )
def sahi_mmdet_inference(
    image,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    image_size=640,
    postprocess_type="UNIONMERGE",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Predict on ``image`` twice (whole image, then sliced) and draw both.

    The slice/overlap arguments set the tiling used by
    ``sahi.predict.get_sliced_prediction``; the ``postprocess_*`` arguments
    are forwarded to it unchanged.

    Returns:
        ``(standard, sliced)``: two PIL images with the corresponding
        predictions drawn on them.
    """

    def render(prediction_result):
        # Build a fresh array for every render, as the original did for each
        # visualization (whether the visualizer draws in place is not visible
        # here, so the two renders never share a buffer).
        canvas = numpy.array(image)
        drawn = sahi.utils.cv.visualize_object_predictions(
            image=canvas,
            object_prediction_list=prediction_result.object_prediction_list,
        )
        return Image.fromarray(drawn["image"])

    # Whole-image prediction at the requested input size.
    standard = render(
        sahi.predict.get_prediction(
            image=image, detection_model=detection_model, image_size=image_size
        )
    )

    # Sliced (tiled) prediction over the same image.
    sliced = render(
        sahi.predict.get_sliced_prediction(
            image=image,
            detection_model=detection_model,
            image_size=image_size,
            slice_height=slice_height,
            slice_width=slice_width,
            overlap_height_ratio=overlap_height_ratio,
            overlap_width_ratio=overlap_width_ratio,
            postprocess_type=postprocess_type,
            postprocess_match_metric=postprocess_match_metric,
            postprocess_match_threshold=postprocess_match_threshold,
            postprocess_class_agnostic=postprocess_class_agnostic,
        )
    )
    return standard, sliced
st.set_page_config(
    page_title="SAHI + MMDetection Demo",
    # NOTE(review): empty page_icon; an emoji may have been lost here — confirm.
    page_icon="",
    layout="centered",
    initial_sidebar_state="auto",
)
# Centered headings are emitted as raw HTML; each closing tag must match its
# opening heading level (previously <h2>/<h3> were closed with </h1>).
st.markdown(
    "<h2 style='text-align: center'> SAHI + MMDetection Demo </h2>",
    unsafe_allow_html=True,
)
st.markdown(
    "<p style='text-align: center'>"
    "SAHI is a lightweight vision library for performing large scale object "
    "detection/instance segmentation. "
    "<a href='https://github.com/obss/sahi'>SAHI Github</a> | "
    "<a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a> | "
    "<a href='https://huggingface.co/spaces/fcakyon/sahi-yolov5'>SAHI+YOLOv5 Demo</a> "
    "</p>",
    unsafe_allow_html=True,
)
st.markdown(
    "<h3 style='text-align: center'> Parameters: </h3>",
    unsafe_allow_html=True,
)
col1, col2, col3 = st.columns([6, 1, 6])
with col1:
    image_file = st.file_uploader(
        "Upload an image to test:", type=["jpg", "jpeg", "png"]
    )

    # Example images available in the working directory. The slider labels
    # each one by its 1-based position, derived from this single list so the
    # labels cannot fall out of sync with the options.
    example_images = ["apple_tree.jpg", "highway.jpg", "highway2.jpg", "highway3.jpg"]

    def slider_func(option):
        """Return the slider label for ``option``: its 1-based position as a string."""
        return str(example_images.index(option) + 1)

    slider = st.select_slider(
        "Or select from example images:",
        options=example_images,
        format_func=slider_func,
    )
    # Preview the selected example; the context manager releases the file
    # handle that Image.open keeps open for lazy loading.
    with Image.open(slider) as image:
        st.image(image, caption=slider, width=300)
with col3:
    model_name = st.selectbox(
        "Select MMDetection model:", ("faster_rcnn", "yolact", "yolox"), index=2
    )
    slice_size = st.number_input("slice_size", min_value=256, value=512, step=256)
    overlap_ratio = st.number_input(
        "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
    )
    postprocess_type = st.selectbox(
        "postprocess_type", options=["NMS", "UNIONMERGE"], index=1
    )
    postprocess_match_metric = st.selectbox(
        "postprocess_match_metric", options=["IOU", "IOS"], index=1
    )
    # IOU and IOS are overlap ratios in [0, 1], so bound the threshold to that
    # range instead of accepting arbitrary values.
    postprocess_match_threshold = st.number_input(
        "postprocess_match_threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
    )
    postprocess_class_agnostic = st.checkbox("postprocess_class_agnostic", value=True)
col1, col2, col3 = st.columns([6, 1, 6])
with col2:
    submit = st.button("Submit")

# An uploaded file takes precedence over the selected example image.
# convert("RGB") guarantees a 3-channel image, so RGBA/palette PNGs and
# grayscale inputs become HxWx3 arrays downstream; it also loads the pixel
# data eagerly.
image_source = image_file if image_file is not None else slider
image = Image.open(image_source).convert("RGB")
class SpinnerTexts:
    """Rotating "meanwhile check out ..." messages shown under spinners.

    ``get()`` picks a random message while avoiding the most recently shown
    ones, so the same link is not repeated across consecutive spinners.
    """

    def __init__(self, history_size: int = 6):
        # Number of most recent picks excluded from the next pick.
        self.history_size = history_size
        # Indices of recent picks, oldest first, capped at history_size.
        self.ind_history_list = []
        self.text_list = [
            "Meanwhile check out [MMDetection Colab notebook of SAHI](https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_mmdetection.ipynb)!",
            "Meanwhile check out [aerial object detection with SAHI](https://blog.ml6.eu/how-to-detect-small-objects-in-very-large-images-70234bab0f98?gi=b434299595d4)!",
            "Meanwhile check out [COCO Utilities of SAHI](https://github.com/obss/sahi/blob/main/docs/COCO.md)!",
            "Meanwhile check out [FiftyOne utilities of SAHI](https://github.com/obss/sahi#fiftyone-utilities)!",
            "Meanwhile check out [easy installation of SAHI](https://github.com/obss/sahi#getting-started)!",
            "Meanwhile check out [give a Github star to SAHI](https://github.com/obss/sahi/stargazers)!",
            "Meanwhile check out [Medium blogpost of SAHI](https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80)!",
            "Meanwhile check out [YOLOv5 HF Spaces demo of SAHI](https://huggingface.co/spaces/fcakyon/sahi-yolov5)!",
        ]

    def _store(self, ind):
        """Record ``ind`` as the newest pick, dropping the oldest beyond history_size."""
        self.ind_history_list.append(ind)
        overflow = len(self.ind_history_list) - self.history_size
        if overflow > 0:
            del self.ind_history_list[:overflow]

    def get(self):
        """Return a random message not among the recently returned ones.

        Raises:
            IndexError: If ``text_list`` is empty.
        """
        n_texts = len(self.text_list)
        # Exclude at most n_texts - 1 recent picks so there is always at least
        # one candidate; this prevents an endless search for short lists.
        window = min(self.history_size, n_texts - 1)
        recent = set(self.ind_history_list[-window:]) if window > 0 else set()
        candidates = [i for i in range(n_texts) if i not in recent]
        # Choose uniformly among the allowed indices. The old loop started at
        # index 0 and returned it whenever 0 was not in the history, so the
        # first message was never random.
        ind = random.choice(candidates)
        self._store(ind)
        return self.text_list[ind]
# Keep one SpinnerTexts per browser session so its no-repeat history
# survives Streamlit reruns.
if "last_spinner_texts" not in st.session_state:
    st.session_state["last_spinner_texts"] = SpinnerTexts()

if submit:
    # perform prediction
    st.markdown(
        "<h3 style='text-align: center'> Results: </h3>",
        unsafe_allow_html=True,
    )
    spinner_texts = st.session_state["last_spinner_texts"]

    with st.spinner(text="Downloading model weight.. " + spinner_texts.get()):
        detection_model = get_mmdet_model(model_name)

    # The app runs yolox at 416 px input and the other models at 640 px.
    image_size = 416 if model_name == "yolox" else 640

    with st.spinner(text="Performing prediction.. " + spinner_texts.get()):
        output_1, output_2 = sahi_mmdet_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            postprocess_type=postprocess_type,
            postprocess_match_metric=postprocess_match_metric,
            postprocess_match_threshold=postprocess_match_threshold,
            postprocess_class_agnostic=postprocess_class_agnostic,
        )

    st.markdown(f"##### Standard {model_name} Prediction:")
    st.image(output_1, width=700)
    st.markdown(f"##### Sliced {model_name} Prediction:")
    st.image(output_2, width=700)