JASON123454321 commited on
Commit
3feb6b1
·
verified ·
1 Parent(s): a0be510

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +12 -24
src/streamlit_app.py CHANGED
@@ -48,6 +48,8 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
48
  return img, ratio, (dw, dh)
49
 
50
  def detect_modify(img0, model, device, conf=0.4, imgsz=640, conf_thres=0.25, iou_thres=0.45):
 
 
51
  stride = int(model.stride.max())
52
  imgsz = check_img_size(imgsz, s=stride)
53
 
@@ -79,11 +81,6 @@ def detect_modify(img0, model, device, conf=0.4, imgsz=640, conf_thres=0.25, iou
79
 
80
  img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
81
  st.image(img0, caption="Prediction Result", use_column_width=True)
82
-
83
- # Streamlit 介面
84
- st.title("YOLOv7 Mask Detection")
85
- st.write("Detect whether a person is wearing a face mask or not.")
86
-
87
  # 取得目前檔案所在目錄
88
  current_dir = os.path.dirname(os.path.abspath(__file__))
89
 
@@ -98,38 +95,29 @@ iou_thres = 0.45
98
  device = torch.device("cpu")
99
 
100
  # 載入模型
101
- @st.cache_resource # 使用 cache 避免每次重新載入
102
- def load_model():
103
- ckpt = torch.load(weight_path, map_location=device, weights_only=False)
104
- model = ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()
105
- return model
106
 
107
- model = load_model()
 
 
108
 
109
  option = st.radio("Select Input Method", ["Upload Image", "Image URL"])
110
 
111
  if option == "Upload Image":
112
  uploaded_file = st.file_uploader("Please upload an image.", type=["jpg", "jpeg", "png"])
113
-
114
  if uploaded_file is not None:
115
- # 直接處理上傳的檔案,不需要額外按鈕
116
  img = PILImage.create(uploaded_file)
117
- st.image(img, caption="Uploaded image", use_column_width=True)
118
-
119
- # 增加一個進度提示
120
- with st.spinner('Running detection...'):
121
- detect_modify(img, model, device, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
122
 
123
  elif option == "Image URL":
124
  url = st.text_input("Please input an image URL.")
125
  if url:
126
  try:
127
- with st.spinner('Downloading and processing image...'):
128
- response = requests.get(url)
129
- response.raise_for_status() # 檢查 http status
130
- pil_img = PILImage.create(BytesIO(response.content))
131
- st.image(pil_img, caption="Downloaded image", use_column_width=True)
132
- detect_modify(pil_img, model, device, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
133
  except Exception as e:
134
  st.error(f"Problem reading image from URL: {url}")
135
  st.error(str(e))
 
48
  return img, ratio, (dw, dh)
49
 
50
  def detect_modify(img0, model, device, conf=0.4, imgsz=640, conf_thres=0.25, iou_thres=0.45):
51
+ st.image(img0, caption="Your image", use_column_width=True)
52
+
53
  stride = int(model.stride.max())
54
  imgsz = check_img_size(imgsz, s=stride)
55
 
 
81
 
82
  img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
83
  st.image(img0, caption="Prediction Result", use_column_width=True)
 
 
 
 
 
84
  # 取得目前檔案所在目錄
85
  current_dir = os.path.dirname(os.path.abspath(__file__))
86
 
 
95
  device = torch.device("cpu")
96
 
97
  # 載入模型
98
+ ckpt = torch.load(weight_path, map_location=device, weights_only=False)
99
+ model = ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()
 
 
 
100
 
101
+ # Streamlit 介面
102
+ st.title("YOLOv7 Mask Detection")
103
+ st.write("Detect whether a person is wearing a face mask or not.")
104
 
105
  option = st.radio("Select Input Method", ["Upload Image", "Image URL"])
106
 
107
  if option == "Upload Image":
108
  uploaded_file = st.file_uploader("Please upload an image.", type=["jpg", "jpeg", "png"])
 
109
  if uploaded_file is not None:
 
110
  img = PILImage.create(uploaded_file)
111
+ detect_modify(img, model, device, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
 
 
 
 
112
 
113
  elif option == "Image URL":
114
  url = st.text_input("Please input an image URL.")
115
  if url:
116
  try:
117
+ response = requests.get(url)
118
+ response.raise_for_status() # 檢查 http status
119
+ pil_img = PILImage.create(BytesIO(response.content))
120
+ detect_modify(pil_img, model, device, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
 
 
121
  except Exception as e:
122
  st.error(f"Problem reading image from URL: {url}")
123
  st.error(str(e))