streamlitTrial / app.py
LanglAdr's picture
Upload app.py
4e07dc1 verified
from qdrant_client import QdrantClient
from io import BytesIO
import streamlit as st
import base64
# 1. Qdrant collection that holds the image metadata (payloads) and vectors.
collection_name = "animal_images"
# 2. Session-state slot for the record the user picked; it starts empty and is
#    reused by the rest of the app on every rerun.
if "selected_record" not in st.session_state:
    st.session_state["selected_record"] = None
def set_selected_record(_new_record):
    """3. Make ``_new_record`` the current selection in session state.

    Used as the ``on_click`` callback of every "Find similar images" button.

    This function must NOT be wrapped in ``st.cache_resource``. That decorator
    excludes underscore-prefixed arguments from the cache key, so every call
    would hit the same cache entry. The state update would then run only once
    per server process, and later clicks would have no effect.

    Args:
        _new_record: The record whose image the user clicked.
    """
    st.session_state.selected_record = _new_record
@st.cache_resource
def get_client():
    """4. Return the Qdrant client used for every query in this app.

    Connection settings come from ``qdrant_db_url`` and ``qdrant_api_key``,
    which must be defined in ``.streamlit/secrets.toml``.

    ``st.cache_resource`` makes Streamlit build the client once and reuse it
    across reruns, instead of constructing a new client on every call.
    """
    return QdrantClient(
        url=st.secrets.get("qdrant_db_url"),
        api_key=st.secrets.get("qdrant_api_key"),
    )
def get_initial_records():
    """5. Fetch a small sample of records to show when the app first starts.

    Returns:
        Up to 12 records from ``collection_name``. Vectors are not requested.
        The pagination offset returned by ``scroll`` is discarded.
    """
    sample, _next_offset = get_client().scroll(
        collection_name=collection_name,
        limit=12,
        with_vectors=False,
    )
    return sample
def get_similar_records():
    """Return records similar to the one currently selected by the user.

    Asks Qdrant's ``recommend`` for up to 12 records, using the selected
    record's id as the positive example.

    If nothing is selected, fall back to the initial sample. The previous
    fallback returned the module-level ``records`` global. That global is
    itself assigned from this function's result, so it was unbound (NameError)
    at that point.
    """
    selected = st.session_state.selected_record
    if selected is None:
        return get_initial_records()
    # Only create/look up the client once we know a query is needed.
    client = get_client()
    return client.recommend(
        collection_name=collection_name,
        positive=[selected.id],
        limit=12,
    )
def get_bytes_from_base64(base64_string):
    """Decode a base64 string into an in-memory binary stream.

    Args:
        base64_string: Base64-encoded data (``str`` or ``bytes``).

    Returns:
        A ``BytesIO`` positioned at the start of the decoded bytes. It is
        suitable for passing to ``st.image``.
    """
    decoded = base64.b64decode(base64_string)
    return BytesIO(decoded)
# Pick what to display:
#   - records similar to the user's selection, when there is one;
#   - otherwise, the initial sample.
if st.session_state.selected_record is None:
    records = get_initial_records()
else:
    records = get_similar_records()
# 9. When a record is selected, show its image at the top of the page.
#    A divider follows to separate it from the grid of results.
current_selection = st.session_state.selected_record
if current_selection:
    image_bytes = get_bytes_from_base64(current_selection.payload["base64"])
    st.header("Images similar to:")
    st.image(image=image_bytes)
    st.divider()
# 10. Three-column grid that the image previews are laid out in.
column = st.columns(3)
# 11. Render every fetched record into the grid, filling the columns
#     left to right and wrapping. Each cell shows a preview decoded from the
#     record's base64 payload, plus a button that selects the record.
for idx, record in enumerate(records):
    col_idx = idx % len(column)
    image_bytes = get_bytes_from_base64(record.payload["base64"])
    with column[col_idx]:
        st.image(image=image_bytes)
        # Clicking the button triggers a rerun with this record selected.
        st.button(
            label="Find similar images",
            key=record.id,
            on_click=set_selected_record,
            args=[record],
        )