File size: 2,648 Bytes
4e07dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550bae6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from qdrant_client import QdrantClient
from io import BytesIO
import streamlit as st
import base64

# 1. Name of the Qdrant collection that stores the image vectors and metadata.
collection_name = "animal_images"

# 2. Guarantee that session state always has a `selected_record` entry, so the
#    rest of the script can read it; it is None until something is selected.
if "selected_record" not in st.session_state:
    st.session_state["selected_record"] = None

def set_selected_record(_new_record):
    """Store the given record as the app's currently selected record.

    Passed as the ``on_click`` callback of the "Find similar images" buttons.

    This function is deliberately NOT wrapped in ``st.cache_resource``:
    Streamlit leaves underscore-prefixed parameters out of the cache key, so
    a cached version would run its body once and then return the cached
    result for every later call, and the selection would never change again.

    Args:
        _new_record: The record to write to ``st.session_state.selected_record``.
    """
    # 3. Single place where the selected record value is written.
    st.session_state.selected_record = _new_record

@st.cache_resource
def get_client():
    """Return the Qdrant client used by the app.

    Cached with ``st.cache_resource`` so the client is constructed once and
    reused, instead of being rebuilt on every call and every script rerun.

    Returns:
        A ``QdrantClient`` built from the ``qdrant_db_url`` and
        ``qdrant_api_key`` secrets.
    """
    # 4. Create the Qdrant client. Both values are read from st.secrets
    #    (expected in the .streamlit/secrets.toml file); `.get` yields None
    #    for a missing key rather than raising.
    return QdrantClient(
        url=st.secrets.get("qdrant_db_url"),
        api_key=st.secrets.get("qdrant_api_key"),
    )

def get_initial_records():
    """Fetch the small sample of images shown when the app first starts.

    Returns:
        The records from a single ``scroll`` call on the collection, limited
        to 12 entries and requested without vectors.
    """
    # 5. `scroll` yields a pair; only the records are used here. The second
    #    element (the continuation offset, per the qdrant-client API) is
    #    discarded.
    first_page, _next_offset = get_client().scroll(
        collection_name=collection_name,
        limit=12,
        with_vectors=False,
    )
    return first_page

def get_similar_records():
    """Return records similar to the currently selected record.

    Reads ``st.session_state.selected_record``. When a record is selected,
    asks Qdrant to recommend up to 12 entries from the collection, using the
    selected record's id as the only positive example.

    Returns:
        The result of ``client.recommend`` when a record is selected;
        otherwise the default sample from ``get_initial_records()``.
    """
    selected = st.session_state.selected_record
    if selected is None:
        # Nothing to compare against. The module-level `records` name is only
        # assigned from this function's own call site, so returning it here
        # would raise NameError; fall back to the default sample instead.
        return get_initial_records()

    # The client is only needed on this path, so create it after the guard.
    return get_client().recommend(
        collection_name=collection_name,
        positive=[selected.id],
        limit=12,
    )

def get_bytes_from_base64(base64_string):
    """Decode a base64 string into an in-memory binary stream.

    Args:
        base64_string: Base64-encoded data, as accepted by
            ``base64.b64decode``.

    Returns:
        A ``BytesIO`` wrapping the decoded bytes, positioned at the start.
    """
    decoded = base64.b64decode(base64_string)
    stream = BytesIO(decoded)
    return stream

# 8. Decide what to display: records similar to the selected one when there is
#    a selection, otherwise the default sample.
if st.session_state.selected_record is not None:
    records = get_similar_records()
else:
    records = get_initial_records()

# 9. If we have a selected record, show its image at the top of the screen.
if st.session_state.selected_record:
    selected_payload = st.session_state.selected_record.payload
    image_bytes = get_bytes_from_base64(selected_payload["base64"])
    st.header("Images similar to:")
    st.image(image=image_bytes)
    st.divider()

# 10. Set up the three-column grid used to render the images.
column = st.columns(3)

# 11. Iterate over all the records fetched from the DB and render a preview of
#     each image from its base64 string, filling the grid column by column
#     and wrapping after the third.
for idx, record in enumerate(records):
    image_bytes = get_bytes_from_base64(record.payload["base64"])
    target_column = column[idx % 3]
    with target_column:
        st.image(image=image_bytes)
        # on_click hands this record to set_selected_record; the record id
        # serves as the widget key.
        st.button(
            label="Find similar images",
            key=record.id,
            on_click=set_selected_record,
            args=[record],
        )