from io import BytesIO

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from sklearn.cluster import KMeans


def _split_coords(coords: np.ndarray, max_sites: int) -> np.ndarray:
    """Partition points into spatial groups of at most ``max_sites`` members.

    Args:
        coords: Array with one ``[lat, lon]`` row per site.
        max_sites: Upper bound on the size of every returned group (>= 1).

    Returns:
        Contiguous integer labels ``0..k-1``, one per row of ``coords``.
    """
    n_points = len(coords)
    if n_points <= max_sites:
        return np.zeros(n_points, dtype=int)

    n_clusters = int(np.ceil(n_points / max_sites))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    raw = kmeans.fit_predict(coords)
    # KMeans may leave some requested labels unused (e.g. with duplicate
    # points), so renumber them to 0..k-1. When every label is used this
    # mapping is the identity.
    _, raw = np.unique(raw, return_inverse=True)

    labels = np.empty(n_points, dtype=int)
    next_label = 0
    for k in range(int(raw.max()) + 1):
        mask = raw == k
        if mask.all():
            # KMeans could not separate these points (e.g. all identical), so
            # recursing would never shrink the problem: split by row order.
            sub_labels = np.arange(n_points) // max_sites
        else:
            # KMeans does not balance cluster sizes, so any cluster that is
            # still larger than max_sites is split again.
            sub_labels = _split_coords(coords[mask], max_sites)
        labels[mask] = sub_labels + next_label
        next_label += int(sub_labels.max()) + 1
    return labels


def cluster_sites(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    region_col: str,
    max_sites: int = 25,
    mix_regions: bool = False,
) -> pd.DataFrame:
    """Assign every site to a geographic cluster of at most ``max_sites`` sites.

    Sites are clustered with KMeans on the raw ``(lat, lon)`` values,
    separately per region unless ``mix_regions`` is True. Clusters that come
    out larger than ``max_sites`` are split again, so the limit always holds.

    Args:
        df: Site table. The coordinate columns must be numeric and free of
            NaNs for any region large enough to need KMeans.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        region_col: Column whose values keep regions apart. Rows with a
            missing region form their own group instead of being dropped.
        max_sites: Maximum number of sites per cluster (must be >= 1).
        mix_regions: If True, ignore ``region_col`` and cluster all sites
            together.

    Returns:
        A copy of ``df`` with an added ``"Cluster"`` column holding labels
        ``"C0"``, ``"C1"``, ... that are unique across regions. Rows are
        reordered so each region's sites are contiguous.

    Raises:
        ValueError: If ``max_sites`` is less than 1.
    """
    if max_sites < 1:
        raise ValueError(f"max_sites must be at least 1, got {max_sites}")
    if df.empty:
        empty = df.copy()
        empty["Cluster"] = pd.Series(index=empty.index, dtype=object)
        return empty

    if mix_regions:
        groups = [("All", df)]
    else:
        # dropna=False: by default groupby silently drops rows whose region
        # is missing, which would make those sites vanish from the output.
        groups = df.groupby(region_col, dropna=False)

    clustered = []
    next_id = 0
    for _, group in groups:
        labels = _split_coords(group[[lat_col, lon_col]].to_numpy(), max_sites)
        group = group.copy()
        group["Cluster"] = [f"C{next_id + label}" for label in labels]
        clustered.append(group)
        # Labels are contiguous 0..k-1, so advancing by max + 1 keeps IDs
        # unique across regions.
        next_id += int(labels.max()) + 1
    return pd.concat(clustered)


def to_excel(df: pd.DataFrame) -> bytes:
    """Serialize ``df`` to in-memory .xlsx bytes with a single "Clusters" sheet.

    No writer engine is forced. pandas uses xlsxwriter when it is installed
    and otherwise falls back to openpyxl, which ``pd.read_excel`` already
    needs for .xlsx input. The download therefore works without the optional
    xlsxwriter package.
    """
    buffer = BytesIO()
    # The workbook bytes are only complete once the writer is closed at the
    # end of the with-block.
    with pd.ExcelWriter(buffer) as writer:
        df.to_excel(writer, index=False, sheet_name="Clusters")
    return buffer.getvalue()


def _valid_coordinate_rows(
    df: pd.DataFrame, lat_col: str, lon_col: str
) -> pd.DataFrame:
    """Return a copy of the rows of ``df`` whose coordinates are usable.

    Latitude and longitude are coerced to numbers; unparseable values become
    NaN. Rows with missing values, or with values outside [-90, 90] and
    [-180, 180] respectively, are dropped, because KMeans cannot handle NaNs
    or text.
    """
    out = df.copy()
    out[lat_col] = pd.to_numeric(out[lat_col], errors="coerce")
    out[lon_col] = pd.to_numeric(out[lon_col], errors="coerce")
    # Series.between is False for NaN, so this also removes missing values.
    in_range = out[lat_col].between(-90, 90) & out[lon_col].between(-180, 180)
    return out[in_range]


def main() -> None:
    """Render the app: upload, column selection, clustering, map, download."""
    st.title("Automatic Site Clustering App")
    st.write(
        """This app allows you to cluster sites based on their latitude and longitude.
**Please choose a file containing the latitude, longitude, region and site code columns.**
"""
    )

    uploaded_file = st.file_uploader("Upload your Excel file", type=["xlsx"])
    if not uploaded_file:
        return

    df = pd.read_excel(uploaded_file)
    st.write("Sample of uploaded data:", df.head())
    columns = df.columns.tolist()

    with st.form("clustering_form"):
        lat_col = st.selectbox("Select Latitude column", columns)
        lon_col = st.selectbox("Select Longitude column", columns)
        region_col = st.selectbox("Select Region column", columns)
        code_col = st.selectbox("Select Site Code column", columns)
        max_sites = st.number_input(
            "Max sites per cluster", min_value=5, max_value=100, value=25
        )
        mix_regions = st.checkbox(
            "Allow mixing different regions in clusters", value=False
        )
        submitted = st.form_submit_button("Run Clustering")

    if not submitted:
        return

    if lat_col == lon_col:
        st.error("Latitude and longitude must be different columns.")
        return

    valid_df = _valid_coordinate_rows(df, lat_col, lon_col)
    if valid_df.empty:
        st.error(
            "No rows have valid numeric coordinates in the selected "
            "latitude/longitude columns."
        )
        return
    n_skipped = len(df) - len(valid_df)
    if n_skipped:
        st.warning(
            f"Skipped {n_skipped} row(s) with missing or out-of-range coordinates."
        )

    clustered_df = cluster_sites(
        valid_df,
        lat_col,
        lon_col,
        region_col,
        max_sites=int(max_sites),
        mix_regions=mix_regions,
    )
    st.success("Clustering completed!")
    st.write(clustered_df.head())

    # scatter_map draws a MapLibre "map" subplot, so the basemap has to be set
    # via map_style; layout.mapbox_style has no effect on this figure type.
    fig = px.scatter_map(
        clustered_df,
        lat=lat_col,
        lon=lon_col,
        color="Cluster",
        hover_name=code_col,
        hover_data=[region_col],
        zoom=5,
        height=600,
        map_style="open-street-map",
    )
    st.plotly_chart(fig)

    # on_click="ignore" keeps the results on screen: clicking the download
    # button does not rerun the script.
    st.download_button(
        label="Download clustered Excel file",
        data=to_excel(clustered_df),
        file_name="clustered_sites.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        on_click="ignore",
        type="primary",
    )


# `streamlit run` executes the script as __main__, so the UI still renders,
# while importing this module (e.g. to reuse cluster_sites) has no side effects.
if __name__ == "__main__":
    main()