import base64
from io import BytesIO

import streamlit as st
from qdrant_client import QdrantClient

# 1. Name of the Qdrant collection that stores all of the image metadata and vectors.
collection_name = "animal_images"

# Maximum number of records fetched per query, and the number of grid columns
# used to display them.
RESULT_LIMIT = 12
NUM_COLUMNS = 3

# 2. Per-session state variable reused throughout the rest of the app: the record
# the user last asked to find similar images for (None until a button is clicked).
if "selected_record" not in st.session_state:
    st.session_state.selected_record = None


def set_selected_record(_new_record):
    """3. Store ``_new_record`` as the currently selected record.

    Used as the ``on_click`` callback of the "Find similar images" buttons.

    NOTE: this must NOT be wrapped in ``st.cache_resource``. Streamlit does not
    hash underscore-prefixed arguments, so all calls would share a single cache
    entry and the body would run only once, leaving later clicks with no effect.
    """
    st.session_state.selected_record = _new_record


@st.cache_resource
def get_client():
    """4. Return a Qdrant client configured from Streamlit secrets.

    ``qdrant_db_url`` and ``qdrant_api_key`` must be set in the
    ``.streamlit/secrets.toml`` file. The client is cached with
    ``st.cache_resource`` so it is created once and reused across reruns
    instead of being rebuilt on every call.
    """
    return QdrantClient(
        url=st.secrets.get("qdrant_db_url"),
        api_key=st.secrets.get("qdrant_api_key"),
    )


def get_initial_records():
    """5. Return a small sample of records to show when the app first starts.

    Uses ``scroll`` to fetch up to ``RESULT_LIMIT`` points from the collection,
    without their vectors.
    """
    client = get_client()
    records, _ = client.scroll(
        collection_name=collection_name,
        with_vectors=False,
        limit=RESULT_LIMIT,
    )
    return records


def get_similar_records():
    """6. Return records similar to the currently selected record.

    Asks Qdrant to recommend up to ``RESULT_LIMIT`` points, using the selected
    record's id as the only positive example. When nothing is selected, falls
    back to the initial sample rather than reading the module-level ``records``,
    which is not yet assigned while this function is computing it.
    """
    selected = st.session_state.selected_record
    if selected is None:
        return get_initial_records()
    client = get_client()
    return client.recommend(
        collection_name=collection_name,
        positive=[selected.id],
        limit=RESULT_LIMIT,
    )


def get_bytes_from_base64(base64_string):
    """7. Decode a base64 string into an in-memory binary stream for ``st.image``."""
    return BytesIO(base64.b64decode(base64_string))


# 8. Show recommendations for the selected record, or an initial sample otherwise.
records = (
    get_similar_records()
    if st.session_state.selected_record is not None
    else get_initial_records()
)

# 9. If we have a selected record, show that image at the top of the screen.
if st.session_state.selected_record is not None:
    image_bytes = get_bytes_from_base64(
        st.session_state.selected_record.payload["base64"]
    )
    st.header("Images similar to:")
    st.image(image=image_bytes)
    st.divider()

# 10. Set up the grid that we will use to render our images.
column = st.columns(NUM_COLUMNS)

# 11. Iterate over all the fetched records from the DB and render a preview of
# each image from its base64 payload, plus a button that selects it for a
# similarity search.
for idx, record in enumerate(records):
    col_idx = idx % NUM_COLUMNS
    image_bytes = get_bytes_from_base64(record.payload["base64"])
    with column[col_idx]:
        st.image(image=image_bytes)
        st.button(
            label="Find similar images",
            key=record.id,
            on_click=set_selected_record,
            args=[record],
        )