db_query / apps /clustering.py
DavMelchi's picture
adding clustering app description
a5af9bb
raw
history blame
3.29 kB
from io import BytesIO
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from sklearn.cluster import KMeans
def cluster_sites(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    region_col: str,
    max_sites: int = 25,
    mix_regions: bool = False,
) -> pd.DataFrame:
    """Assign every site to a geographic cluster and return the labelled rows.

    Sites are split per region (or treated as one pool when *mix_regions* is
    true).  Each pool is partitioned with KMeans on its latitude/longitude
    into ``ceil(len(pool) / max_sites)`` clusters.  Cluster names are
    ``"C<n>"`` with ``n`` numbered consecutively across all pools, so names
    are unique over the whole result.

    Args:
        df: Input table; it is not modified.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        region_col: Name of the region column (ignored when *mix_regions*).
        max_sites: Target number of sites per cluster; it only determines how
            many clusters are requested.  NOTE(review): KMeans does not
            balance cluster sizes, so a cluster may still exceed this value.
        mix_regions: When true, cluster all sites together regardless of
            region.

    Returns:
        A new DataFrame with the input columns plus a ``"Cluster"`` column,
        rows ordered pool by pool.  An input that yields no pools gives an
        empty frame that still has the ``"Cluster"`` column.

    Raises:
        ValueError: If *max_sites* is smaller than 1.
    """
    if max_sites < 1:
        raise ValueError("max_sites must be at least 1")

    # NOTE(review): with pandas' default groupby settings, rows whose region
    # is missing are not part of any group and therefore drop out of the
    # result -- confirm this is the intended behaviour.
    grouped = [("All", df)] if mix_regions else df.groupby(region_col)

    parts = []
    next_id = 0
    for _region, group in grouped:
        coords = group[[lat_col, lon_col]].to_numpy()
        n_clusters = max(1, int(np.ceil(len(group) / max_sites)))

        if n_clusters == 1:
            # Everything fits in one cluster: no need to fit a model.
            labels = np.zeros(len(group), dtype=int)
        else:
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            labels = kmeans.fit_predict(coords)

        # KMeans may use fewer distinct labels than requested (e.g. many
        # duplicate coordinates).  Remap to 0..k-1 so the running offset
        # below can never make two pools share a cluster name.
        present, dense = np.unique(labels, return_inverse=True)

        group = group.copy()
        group["Cluster"] = [f"C{next_id + int(label)}" for label in dense]
        parts.append(group)
        next_id += len(present)

    if not parts:
        # pd.concat([]) raises; return an empty, correctly-shaped frame.
        empty = df.iloc[0:0].copy()
        empty["Cluster"] = pd.Series(dtype="object")
        return empty

    return pd.concat(parts)
def to_excel(df: pd.DataFrame) -> bytes:
    """Render *df* as an in-memory .xlsx workbook and return its raw bytes.

    The data goes to a single sheet named "Clusters" without the index
    column, written through the xlsxwriter engine.
    """
    buffer = BytesIO()
    writer = pd.ExcelWriter(buffer, engine="xlsxwriter")
    # Leaving the context closes the writer, which flushes the workbook
    # into the buffer before it is read back.
    with writer:
        df.to_excel(writer, sheet_name="Clusters", index=False)
    return buffer.getvalue()
st.title("Automatic Site Clustering App")

# Short description shown under the title.
st.write(
    """This app allows you to cluster sites based on their latitude and longitude.
**Please choose a file containing the latitude and longitude region and site code columns.**
"""
)

uploaded_file = st.file_uploader("Upload your Excel file", type=["xlsx"])

if uploaded_file:
    df = pd.read_excel(uploaded_file)
    st.write("Sample of uploaded data:", df.head())
    columns = df.columns.tolist()

    # Collect all clustering options in one form so they are submitted
    # together by the "Run Clustering" button.
    with st.form("clustering_form"):
        lat_col = st.selectbox("Select Latitude column", columns)
        lon_col = st.selectbox("Select Longitude column", columns)
        region_col = st.selectbox("Select Region column", columns)
        code_col = st.selectbox("Select Site Code column", columns)
        max_sites = st.number_input(
            "Max sites per cluster", min_value=5, max_value=100, value=25
        )
        mix_regions = st.checkbox(
            "Allow mixing different regions in clusters", value=False
        )
        submitted = st.form_submit_button("Run Clustering")

    if submitted:
        clustered_df = cluster_sites(
            df, lat_col, lon_col, region_col, max_sites, mix_regions
        )
        st.success("Clustering completed!")
        st.write(clustered_df.head())

        # Plot the sites on a map, one colour per cluster.
        fig = px.scatter_map(
            clustered_df,
            lat=lat_col,
            lon=lon_col,
            color="Cluster",
            hover_name=code_col,
            hover_data=[region_col],
            zoom=5,
            height=600,
        )
        # scatter_map renders on the MapLibre-based `map` subplot, whose
        # basemap is set through `map_style`; `mapbox_style` only affects
        # the legacy `mapbox` subplot and would be silently ignored here.
        fig.update_layout(map_style="open-street-map")
        st.plotly_chart(fig)

        # Offer the labelled table as an Excel download; on_click="ignore"
        # keeps the click from re-running the script and losing the results.
        st.download_button(
            label="Download clustered Excel file",
            data=to_excel(clustered_df),
            file_name="clustered_sites.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            on_click="ignore",
            type="primary",
        )