LanglAdr commited on
Commit
4e07dc1
·
verified ·
1 Parent(s): dba8b2f

Upload app.py

Browse files

updated app.py

Files changed (1) hide show
  1. app.py +79 -79
app.py CHANGED
@@ -1,80 +1,80 @@
1
- from qdrant_client import QdrantClient
2
- from io import BytesIO
3
- import streamlit as st
4
- import base64
5
-
6
- # 1. Define the Qdrant collection name that we used to store all of our metadata and vectors
7
- collection_name = "animal_images"
8
-
9
- #2. Set up a state variable that we'll reuse throughout the rest of the app
10
- if 'selected_record' not in st.session_state:
11
- st.session_state.selected_record = None
12
-
13
- @st.cache_resource
14
- def set_selected_record(_new_record):
15
- #3. Create a function that allows us to easily set the selected record value.
16
- st.session_state.selected_record = _new_record
17
-
18
- def get_client():
19
- #4. create the Qrant client> these must be set up in the .steamlit/secrets.toml file
20
- return QdrantClient(
21
- url=st.secrets.get("qdrant_db_url"),
22
- api_key=st.secrets.get("qdrant_api_key")
23
- )
24
-
25
- def get_initial_records():
26
- #5. when the app first starts, let's show a small sample of images to the user.
27
- client = get_client()
28
-
29
- records, _ = client.scroll(
30
- collection_name = collection_name,
31
- with_vectors=False,
32
- limit=144
33
- )
34
- return records
35
-
36
- def get_similar_records():
37
- # if user has selected a record then they want to see similar images
38
- client = get_client()
39
-
40
- if st.session_state.selected_record is not None:
41
- return client.recommend(
42
- collection_name=collection_name,
43
- positive=[st.session_state.selected_record.id],
44
- limit=144
45
- )
46
- return records
47
-
48
- def get_bytes_from_base64(base64_string):
49
- return BytesIO(base64.b64decode(base64_string))
50
-
51
- records = get_similar_records(
52
- ) if st.session_state.selected_record is not None else get_initial_records()
53
-
54
- # 9 if we have a selected record then show that image at the top of the screen.
55
- if st.session_state.selected_record:
56
- image_bytes = get_bytes_from_base64(
57
- st.session_state.selected_record.payload["base64"])
58
- st.header("Images similar to:")
59
- st.image(
60
- image=image_bytes
61
- )
62
- st.divider()
63
-
64
- #10 Setup the grid that we will use to render out images
65
- column = st.columns(3)
66
-
67
- #11. Iternate over all the fetch records form the DB and render to a preview of each image using the base64 string
68
- for idx, record in enumerate(records):
69
- col_idx = idx % 3
70
- image_bytes = get_bytes_from_base64(record.payload["base64"])
71
- with column[col_idx]:
72
- st.image(
73
- image=image_bytes
74
- )
75
- st.button(
76
- label="Find similar images",
77
- key=record.id,
78
- on_click=set_selected_record,
79
- args=[record]
80
  )
 
1
+ from qdrant_client import QdrantClient
2
+ from io import BytesIO
3
+ import streamlit as st
4
+ import base64
5
+
6
+ # 1. Define the Qdrant collection name that we used to store all of our metadata and vectors
7
+ collection_name = "animal_images"
8
+
9
+ #2. Set up a state variable that we'll reuse throughout the rest of the app
10
+ if 'selected_record' not in st.session_state:
11
+ st.session_state.selected_record = None
12
+
13
+ @st.cache_resource
14
+ def set_selected_record(_new_record):
15
+ #3. Create a function that allows us to easily set the selected record value.
16
+ st.session_state.selected_record = _new_record
17
+
18
+ def get_client():
19
+ #4. create the Qrant client> these must be set up in the .steamlit/secrets.toml file
20
+ return QdrantClient(
21
+ url=st.secrets.get("qdrant_db_url"),
22
+ api_key=st.secrets.get("qdrant_api_key")
23
+ )
24
+
25
+ def get_initial_records():
26
+ #5. when the app first starts, let's show a small sample of images to the user.
27
+ client = get_client()
28
+
29
+ records, _ = client.scroll(
30
+ collection_name = collection_name,
31
+ with_vectors=False,
32
+ limit=12
33
+ )
34
+ return records
35
+
36
+ def get_similar_records():
37
+ # if user has selected a record then they want to see similar images
38
+ client = get_client()
39
+
40
+ if st.session_state.selected_record is not None:
41
+ return client.recommend(
42
+ collection_name=collection_name,
43
+ positive=[st.session_state.selected_record.id],
44
+ limit=12
45
+ )
46
+ return records
47
+
48
+ def get_bytes_from_base64(base64_string):
49
+ return BytesIO(base64.b64decode(base64_string))
50
+
51
+ records = get_similar_records(
52
+ ) if st.session_state.selected_record is not None else get_initial_records()
53
+
54
+ # 9 if we have a selected record then show that image at the top of the screen.
55
+ if st.session_state.selected_record:
56
+ image_bytes = get_bytes_from_base64(
57
+ st.session_state.selected_record.payload["base64"])
58
+ st.header("Images similar to:")
59
+ st.image(
60
+ image=image_bytes
61
+ )
62
+ st.divider()
63
+
64
+ #10 Setup the grid that we will use to render out images
65
+ column = st.columns(3)
66
+
67
+ #11. Iternate over all the fetch records form the DB and render to a preview of each image using the base64 string
68
+ for idx, record in enumerate(records):
69
+ col_idx = idx % 3
70
+ image_bytes = get_bytes_from_base64(record.payload["base64"])
71
+ with column[col_idx]:
72
+ st.image(
73
+ image=image_bytes
74
+ )
75
+ st.button(
76
+ label="Find similar images",
77
+ key=record.id,
78
+ on_click=set_selected_record,
79
+ args=[record]
80
  )