Spaces:
Sleeping
Sleeping
from qdrant_client import QdrantClient | |
from io import BytesIO | |
import streamlit as st | |
import base64 | |
# 1. Name of the Qdrant collection that stores every image's metadata and vector.
collection_name: str = "animal_images"
# 2. Create the "selected_record" state slot once per session; Streamlit keeps
#    session_state across reruns, so it is only initialised when missing.
if 'selected_record' not in st.session_state:
    st.session_state['selected_record'] = None
def set_selected_record(_new_record):
    """Make ``_new_record`` the currently selected record.

    3. Intended as a button ``on_click`` callback: it writes the record into
    ``st.session_state`` so the next script run can read it back.
    """
    st.session_state['selected_record'] = _new_record
def get_client():
    """Build a Qdrant client from the app's Streamlit secrets.

    4. The connection settings are read from ``.streamlit/secrets.toml``
    under the keys ``qdrant_db_url`` and ``qdrant_api_key``. A missing key
    is passed through as ``None`` because ``st.secrets.get`` is used.
    """
    db_url = st.secrets.get("qdrant_db_url")
    api_key = st.secrets.get("qdrant_api_key")
    return QdrantClient(url=db_url, api_key=api_key)
def get_initial_records():
    """Return a small sample of records to show when the app first loads.

    5. Scrolls up to 12 points from ``collection_name`` without fetching
    their vectors. The pagination offset returned by ``scroll`` is discarded.
    """
    sample, _next_offset = get_client().scroll(
        collection_name=collection_name,
        limit=12,
        with_vectors=False,
    )
    return sample
def get_similar_records():
    """Return records similar to the one the user selected.

    Asks Qdrant's ``recommend`` endpoint for up to 12 points, using the
    selected record's id as the only positive example.

    Returns:
        The recommendation results when a record is selected. Otherwise,
        the default sample from ``get_initial_records()``.
    """
    selected = st.session_state.selected_record
    if selected is None:
        # Nothing selected: fall back to the default sample. The previous
        # fallback returned the module-level ``records`` global, which is not
        # yet assigned while this function runs, and raised NameError.
        return get_initial_records()
    client = get_client()
    return client.recommend(
        collection_name=collection_name,
        positive=[selected.id],
        limit=12,
    )
def get_bytes_from_base64(base64_string):
    """Decode a base64 string and wrap the raw bytes in a ``BytesIO`` stream."""
    raw_bytes = base64.b64decode(base64_string)
    return BytesIO(raw_bytes)
# Choose what to display: neighbours of the selected record when there is
# one, otherwise the initial sample.
if st.session_state.selected_record is not None:
    records = get_similar_records()
else:
    records = get_initial_records()
# 9. When a record is selected, show its image at the top of the page.
selected = st.session_state.selected_record
if selected:
    # Decode before rendering anything, so a bad payload fails before the header appears.
    selected_image = get_bytes_from_base64(selected.payload["base64"])
    st.header("Images similar to:")
    st.image(image=selected_image)
    st.divider()
# 10. Lay the results out in a three-column grid.
grid = st.columns(3)
# 11. Iterate over the fetched records. For each one, render a preview from its
#     base64 payload plus a button that makes it the selected record.
for position, record in enumerate(records):
    preview = get_bytes_from_base64(record.payload["base64"])
    with grid[position % len(grid)]:
        st.image(image=preview)
        st.button(
            label="Find similar images",
            key=record.id,
            on_click=set_selected_record,
            args=[record],
        )