JASON123454321 commited on
Commit
a0be510
·
verified ·
1 Parent(s): b78402f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +24 -15
src/streamlit_app.py CHANGED
@@ -48,8 +48,6 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
48
  return img, ratio, (dw, dh)
49
 
50
  def detect_modify(img0, model, device, conf=0.4, imgsz=640, conf_thres=0.25, iou_thres=0.45):
51
- st.image(img0, caption="Your image", use_column_width=True)
52
-
53
  stride = int(model.stride.max())
54
  imgsz = check_img_size(imgsz, s=stride)
55
 
@@ -81,6 +79,11 @@ def detect_modify(img0, model, device, conf=0.4, imgsz=640, conf_thres=0.25, iou
81
 
82
  img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
83
  st.image(img0, caption="Prediction Result", use_column_width=True)
 
 
 
 
 
84
  # 取得目前檔案所在目錄
85
  current_dir = os.path.dirname(os.path.abspath(__file__))
86
 
@@ -95,12 +98,13 @@ iou_thres = 0.45
95
  device = torch.device("cpu")
96
 
97
  # 載入模型
98
- ckpt = torch.load(weight_path, map_location=device, weights_only=False)
99
- model = ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()
 
 
 
100
 
101
- # Streamlit 介面
102
- st.title("YOLOv7 Mask Detection")
103
- st.write("Detect whether a person is wearing a face mask or not.")
104
 
105
  option = st.radio("Select Input Method", ["Upload Image", "Image URL"])
106
 
@@ -108,21 +112,26 @@ if option == "Upload Image":
108
  uploaded_file = st.file_uploader("Please upload an image.", type=["jpg", "jpeg", "png"])
109
 
110
  if uploaded_file is not None:
111
- st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
112
- if st.button("Start Detection"):
113
- img = PILImage.create(uploaded_file)
 
 
 
114
  detect_modify(img, model, device, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
115
 
116
-
117
  elif option == "Image URL":
118
  url = st.text_input("Please input an image URL.")
119
  if url:
120
  try:
121
- response = requests.get(url)
122
- response.raise_for_status() # 檢查 http status
123
- pil_img = PILImage.create(BytesIO(response.content))
124
- detect_modify(pil_img, model, device, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
 
 
125
  except Exception as e:
126
  st.error(f"Problem reading image from URL: {url}")
127
  st.error(str(e))
128
 
 
 
48
  return img, ratio, (dw, dh)
49
 
50
  def detect_modify(img0, model, device, conf=0.4, imgsz=640, conf_thres=0.25, iou_thres=0.45):
 
 
51
  stride = int(model.stride.max())
52
  imgsz = check_img_size(imgsz, s=stride)
53
 
 
79
 
80
  img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
81
  st.image(img0, caption="Prediction Result", use_column_width=True)
82
+
83
+ # Streamlit 介面
84
+ st.title("YOLOv7 Mask Detection")
85
+ st.write("Detect whether a person is wearing a face mask or not.")
86
+
87
  # 取得目前檔案所在目錄
88
  current_dir = os.path.dirname(os.path.abspath(__file__))
89
 
 
98
  device = torch.device("cpu")
99
 
100
  # 載入模型
101
+ @st.cache_resource # 使用 cache 避免每次重新載入
102
+ def load_model():
103
+ ckpt = torch.load(weight_path, map_location=device, weights_only=False)
104
+ model = ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()
105
+ return model
106
 
107
+ model = load_model()
 
 
108
 
109
  option = st.radio("Select Input Method", ["Upload Image", "Image URL"])
110
 
 
112
  uploaded_file = st.file_uploader("Please upload an image.", type=["jpg", "jpeg", "png"])
113
 
114
  if uploaded_file is not None:
115
+ # 直接處理上傳的檔案,不需要額外按鈕
116
+ img = PILImage.create(uploaded_file)
117
+ st.image(img, caption="Uploaded image", use_column_width=True)
118
+
119
+ # 增加一個進度提示
120
+ with st.spinner('Running detection...'):
121
  detect_modify(img, model, device, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
122
 
 
123
  elif option == "Image URL":
124
  url = st.text_input("Please input an image URL.")
125
  if url:
126
  try:
127
+ with st.spinner('Downloading and processing image...'):
128
+ response = requests.get(url)
129
+ response.raise_for_status() # 檢查 http status
130
+ pil_img = PILImage.create(BytesIO(response.content))
131
+ st.image(pil_img, caption="Downloaded image", use_column_width=True)
132
+ detect_modify(pil_img, model, device, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
133
  except Exception as e:
134
  st.error(f"Problem reading image from URL: {url}")
135
  st.error(str(e))
136
 
137
+