Spaces:
Sleeping
Sleeping
import sahi.utils.cv
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from streamlit_image_comparison import image_comparison
from ultralyticsplus.hf_utils import download_from_hub

from utils import sahi_yolov8m_inference
# Remote copies of the example diagrams and of the precomputed prediction
# image, keyed by the filename the app uses to refer to each of them.
IMAGE_TO_URL = dict(
    (
        ('factory_pid.png', 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png'),
        ('plant_pid.png', 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png'),
        ('processing_pid.png', 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png'),
        ('prediction_visual.png', 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png'),
    )
)
# Page configuration followed by the header shown at the top of the app.
_PAGE_TITLE = "P&ID Object Detection"
st.set_page_config(
    page_title=_PAGE_TITLE,
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title(_PAGE_TITLE)
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
st.caption('Developed by Deep Drawings Co.')
@st.cache_resource(show_spinner=False)
def get_model():
    """Build the SAHI detection model for the P&ID YOLOv8 checkpoint.

    Weights are fetched from the ``DanielCerda/pid_yolov8`` hub repository and
    loaded as a ``yolov8`` model on CPU with a fixed confidence threshold of 0.75.

    The result is cached with ``st.cache_resource``, so the download and model
    load happen once per server process instead of on every click of the
    prediction button. Caching's own spinner is disabled because the caller
    already wraps this call in ``st.spinner``.

    NOTE(review): the cached model object is shared by all sessions. If
    ``sahi_yolov8m_inference`` mutates it (for example its image size), that
    mutation is shared too. Confirm this is acceptable.

    Returns:
        The detection model produced by ``AutoDetectionModel.from_pretrained``.
    """
    yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
    return AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=yolov8_model_path,
        confidence_threshold=0.75,
        device="cpu",
    )
def download_comparison_images():
    """Download the example diagram and its precomputed prediction image.

    The files are written to the relative paths ``plant_pid.png`` and
    ``prediction_visual.png``. The app later opens those same paths to seed the
    comparison view. URLs are looked up in ``IMAGE_TO_URL`` instead of being
    repeated here, so the two cannot drift apart.
    """
    for filename in ('plant_pid.png', 'prediction_visual.png'):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[filename], filename)
download_comparison_images()
# Seed the before/after comparison with the bundled example. These entries are
# overwritten once the user runs a prediction.
for _state_key, _image_path in (
    ("output_1", 'plant_pid.png'),
    ("output_2", 'prediction_visual.png'),
):
    if _state_key in st.session_state:
        continue
    st.session_state[_state_key] = Image.open(_image_path)
col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    # Short usage guide. The step lines previously ended in mis-encoded emoji
    # that rendered as stray glyphs, so they are now plain text.
    with st.expander('How to use it'):
        st.markdown(
            '''
            1) Select any example diagram
            2) Set confidence threshold
            3) Press to perform inference
            4) Visualize model predictions
            '''
        )
st.write('##')
col1, col2, col3 = st.columns(3, gap='large')
with col1:
    st.markdown('##### Input Data')
    # Example diagrams offered to the user, each shown under a one-letter label.
    _EXAMPLE_LABELS = {
        'factory_pid.png': 'A',
        'plant_pid.png': 'B',
        'processing_pid.png': 'C',
    }

    def radio_func(option):
        """Return the short label displayed for the example filename *option*."""
        return _EXAMPLE_LABELS[option]

    radio = st.radio(
        'Select from the following examples',
        options=list(_EXAMPLE_LABELS),
        format_func=radio_func,
    )
with col2:
    st.markdown('##### Preview')

    @st.cache_data(show_spinner=False)
    def _load_example_image(url):
        """Fetch an example diagram from *url* as a PIL image.

        Cached so the remote image is downloaded once per URL. Without the
        cache it was fetched again on every script rerun, including every
        slider change. ``st.cache_data`` hands each caller its own copy.
        """
        return sahi.utils.cv.read_image_as_pil(url)

    image = _load_example_image(IMAGE_TO_URL[radio])
    with st.container(border=True):
        st.image(image, use_column_width=True)
with col3:
    st.markdown('##### Set model parameters')
    # Tile edge in pixels; used for both slice height and slice width at inference.
    slice_size = st.slider('Select Slice Size', min_value=256, max_value=1024, value=768, step=256)
    # Fractional overlap between neighbouring tiles, applied to both axes.
    overlap_ratio = st.slider('Select Overlap Ratio', min_value=0.0, max_value=0.5, value=0.1, step=0.1)
    # NOTE(review): this slider is labelled "Confidence Threshold", but its value
    # is passed to sahi_yolov8m_inference as postprocess_match_threshold. The
    # detector's own confidence threshold is fixed at 0.75 in get_model.
    # Confirm which setting is intended.
    postprocess_match_threshold = st.slider(
        'Select Confidence Threshold', min_value=0.0, max_value=1.0, value=0.75, step=0.25
    )
st.write('##')
col1, col2, col3 = st.columns([3, 1, 3])
with col2:
    submit = st.button("Perform Prediction")
if submit:
    # Load (or reuse) the detector, then run sliced inference on the previewed
    # example with the slider settings chosen above.
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model()
    # Forwarded as image_size to the inference helper. Presumably this is the
    # model input resolution; confirm in utils.sahi_yolov8m_inference.
    image_size = 1024
    with st.spinner(text="Performing prediction ... "):
        output = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            postprocess_match_threshold=postprocess_match_threshold,
        )
    # Replace the comparison images so the result section shows this run.
    st.session_state["output_1"] = image
    st.session_state["output_2"] = output
st.write('##')
col1, col2, col3 = st.columns([1, 4, 1])
with col2:
    st.markdown("#### Object Detection Result")
    # Before/after comparison: img1 is the input diagram and img2 is the model's
    # output (or the bundled example until a prediction has been run).
    with st.container(border=True):
        image_comparison(
            img1=st.session_state["output_1"],
            img2=st.session_state["output_2"],
            label1='Uploaded Diagram',
            label2='Model Inference',
            width=768,
            starting_position=50,
            show_labels=True,
            make_responsive=True,
            in_memory=True,
        )