yurapodk commited on
Commit
d3d8066
·
1 Parent(s): e2a98c2
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -16,18 +16,27 @@ if(input_image):
16
  image = Image.open(input_image)
17
  st.header("Original")
18
  st.image(input_image, use_column_width=True)
 
 
 
 
19
  if st.button("Get prediction"):
20
  with st.spinner("Loading..."):
21
- response = requests.post(os.path.join(api, "bounding-boxes/"), files = {"file": input_image.getvalue()})
 
22
  prediction = ast.literal_eval(response.text)
23
- response = requests.post(os.path.join(api, "image-w-boxes/"), files = {"file": input_image.getvalue()})
 
 
24
  image_with_boxes = response.content
25
  arr = []
26
  for b in prediction["bboxes"]:
27
  stamp = image.crop((b["xmin"], b["ymin"], b["xmin"] + b["width"], b["ymin"] + b["height"]))
28
  output = io.BytesIO()
29
  stamp.save(output, format="BMP")
30
- response = ast.literal_eval(requests.post(os.path.join(api, "embeddings-from-cropped/"), files = {"file": output.getvalue()}).text)
 
 
31
  arr.extend(response["embedding"])
32
 
33
  col1, col2, col3 = st.columns(3)
 
16
  image = Image.open(input_image)
17
  st.header("Original")
18
  st.image(input_image, use_column_width=True)
19
+
20
+ detection_model = st.selectbox("Select detection model", ("YOLO"))
21
+ embedding_model = st.selectbox("Select detection model", ("OML"))
22
+
23
  if st.button("Get prediction"):
24
  with st.spinner("Loading..."):
25
+ response = requests.post(os.path.join(api, f"bounding-boxes-{detection_model}/"),
26
+ files = {"file": input_image.getvalue(), "model_id": detection_model})
27
  prediction = ast.literal_eval(response.text)
28
+ response = requests.post(os.path.join(api, f"image-w-boxes-{detection_model}/"),
29
+ files = {"file": input_image.getvalue(), "model_id": detection_model})
30
+
31
  image_with_boxes = response.content
32
  arr = []
33
  for b in prediction["bboxes"]:
34
  stamp = image.crop((b["xmin"], b["ymin"], b["xmin"] + b["width"], b["ymin"] + b["height"]))
35
  output = io.BytesIO()
36
  stamp.save(output, format="BMP")
37
+ response = ast.literal_eval(requests.post(os.path.join(api, f"embeddings-from-cropped-{embedding_model}/"),
38
+ files = {"file": output.getvalue(), "model_id": embedding_model}).text)
39
+
40
  arr.extend(response["embedding"])
41
 
42
  col1, col2, col3 = st.columns(3)