|
from io import BytesIO |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import plotly.express as px |
|
import streamlit as st |
|
from sklearn.cluster import KMeans |
|
|
|
|
|
def _capped_labels(coords: np.ndarray, max_sites: int) -> np.ndarray:
    """Return contiguous labels 0..k-1 for *coords*, each used by at most *max_sites* rows.

    KMeans is run with ``ceil(n / max_sites)`` clusters. Because KMeans does
    not balance cluster sizes, any resulting cluster that is still larger than
    *max_sites* is split again recursively. Labels are renumbered so they are
    always contiguous, even when KMeans leaves some requested clusters empty
    (which can happen with duplicate coordinates).

    Args:
        coords: ``(n, 2)`` array of (lat, lon) values for one group of rows.
        max_sites: Maximum number of rows allowed to share one label (>= 1).

    Returns:
        Integer array of length ``n`` with labels starting at 0.
    """
    n = len(coords)
    if n <= max_sites:
        # Fits in a single cluster; skip KMeans (a 1-cluster fit gives all zeros anyway).
        return np.zeros(n, dtype=int)

    n_clusters = int(np.ceil(n / max_sites))
    # NOTE: distance is Euclidean on raw degrees, not great-circle distance.
    raw = KMeans(n_clusters=n_clusters, random_state=42, n_init=10).fit_predict(coords)

    distinct = np.unique(raw)
    if len(distinct) < 2:
        # KMeans could not separate the points (e.g. all coordinates identical);
        # fall back to fixed-size chunks in row order so the cap still holds
        # and the recursion is guaranteed to terminate.
        return (np.arange(n) // max_sites).astype(int)

    labels = np.empty(n, dtype=int)
    next_label = 0
    for raw_label in distinct:
        mask = raw == raw_label
        # Each part is non-empty and strictly smaller than n, so recursion ends.
        sub = _capped_labels(coords[mask], max_sites)
        labels[mask] = sub + next_label
        next_label += int(sub.max()) + 1
    return labels


def cluster_sites(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    region_col: str,
    max_sites: int = 25,
    mix_regions: bool = False,
) -> pd.DataFrame:
    """Assign every row of *df* to a geographic cluster labelled "C0", "C1", ...

    Rows are grouped by *region_col* (or treated as one group when
    *mix_regions* is True). Each group is split with KMeans on its
    (lat, lon) values so that no cluster contains more than *max_sites* rows.
    Label numbers run consecutively across groups, so two groups never share
    a label.

    Args:
        df: Input sites.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        region_col: Name of the region column; rows with a missing region
            form their own group instead of being dropped.
        max_sites: Maximum number of rows per cluster (must be >= 1).
        mix_regions: If True, ignore regions and cluster all rows together.

    Returns:
        A copy of the rows with an added "Cluster" column, ordered group by
        group (sorted regions, missing region last). An empty *df* yields an
        empty frame that still has the "Cluster" column.

    Raises:
        ValueError: If *max_sites* is less than 1. KMeans may also raise if a
            group that needs splitting has non-numeric or missing coordinates.
    """
    if max_sites < 1:
        raise ValueError(f"max_sites must be at least 1, got {max_sites!r}")

    if mix_regions:
        grouped = [("All", df)]
    else:
        # dropna=False keeps sites whose region is missing; the default
        # (dropna=True) would silently remove them from the output.
        grouped = df.groupby(region_col, dropna=False)

    clusters = []
    cluster_id = 0
    for _region, group in grouped:
        coords = group[[lat_col, lon_col]].to_numpy()
        labels = _capped_labels(coords, max_sites)

        group = group.copy()
        group["Cluster"] = [f"C{cluster_id + label}" for label in labels]
        clusters.append(group)

        # Labels are contiguous from 0, so max + 1 is the number used here.
        if len(labels):
            cluster_id += int(labels.max()) + 1

    if not clusters:
        # pd.concat([]) raises; return an empty frame with the expected column.
        empty = df.copy()
        empty["Cluster"] = pd.Series(dtype=object)
        return empty
    return pd.concat(clusters)
|
|
|
|
|
def to_excel(df: pd.DataFrame) -> bytes:
    """Serialise *df* to an in-memory .xlsx workbook and return its bytes.

    The data is written without the index to a sheet named "Clusters".
    ``xlsxwriter`` is preferred. If it is not installed, the ``openpyxl``
    engine is used instead; ``pd.read_excel`` already needs it for .xlsx files.

    Raises:
        ImportError: If neither xlsx writer engine is available.
    """
    output = BytesIO()
    try:
        writer = pd.ExcelWriter(output, engine="xlsxwriter")
    except ImportError:
        # xlsxwriter is an optional dependency; fall back to openpyxl.
        writer = pd.ExcelWriter(output, engine="openpyxl")
    with writer:
        df.to_excel(writer, index=False, sheet_name="Clusters")
    # The workbook is only flushed into the buffer when the writer closes,
    # so read the bytes after leaving the `with` block.
    return output.getvalue()
|
|
|
|
|
# --- Streamlit page: upload sites, choose columns, cluster, map, export. ---

st.title("Automatic Site Clustering App")

st.write(
    """This app allows you to cluster sites based on their latitude and longitude.

**Please choose a file containing the latitude, longitude, region, and site code columns.**
"""
)

uploaded_file = st.file_uploader("Upload your Excel file", type=["xlsx"])

if uploaded_file:
    df = pd.read_excel(uploaded_file)
    if df.empty:
        st.error("The uploaded sheet contains no data rows.")
        st.stop()

    st.write("Sample of uploaded data:", df.head())

    columns = df.columns.tolist()

    # A form batches the widget changes so clustering runs only on submit.
    with st.form("clustering_form"):
        lat_col = st.selectbox("Select Latitude column", columns)
        lon_col = st.selectbox("Select Longitude column", columns)
        region_col = st.selectbox("Select Region column", columns)
        code_col = st.selectbox("Select Site Code column", columns)
        max_sites = st.number_input(
            "Max sites per cluster", min_value=5, max_value=100, value=25
        )
        mix_regions = st.checkbox(
            "Allow mixing different regions in clusters", value=False
        )
        submitted = st.form_submit_button("Run Clustering")

    if submitted:
        # KMeans needs numeric, complete coordinates. Report bad column
        # choices here instead of letting the clustering step crash.
        non_numeric = [
            str(col)
            for col in (lat_col, lon_col)
            if not pd.api.types.is_numeric_dtype(df[col])
        ]
        if non_numeric:
            st.error(
                "Latitude/longitude columns must be numeric: "
                + ", ".join(non_numeric)
            )
            st.stop()

        n_missing = int(df[[lat_col, lon_col]].isna().any(axis=1).sum())
        if n_missing:
            st.error(
                f"{n_missing} row(s) have missing latitude/longitude values; "
                "please fill or remove them and upload the file again."
            )
            st.stop()

        clustered_df = cluster_sites(
            df, lat_col, lon_col, region_col, max_sites, mix_regions
        )
        st.success("Clustering completed!")
        st.write(clustered_df.head())

        fig = px.scatter_map(
            clustered_df,
            lat=lat_col,
            lon=lon_col,
            color="Cluster",
            hover_name=code_col,
            hover_data=[region_col],
            zoom=5,
            height=600,
        )
        # scatter_map renders on the MapLibre "map" subplot, so the tile style
        # must be set through map_style. mapbox_style targets the separate
        # Mapbox subplot and has no effect on this figure.
        fig.update_layout(map_style="open-street-map")
        st.plotly_chart(fig)

        st.download_button(
            label="Download clustered Excel file",
            data=to_excel(clustered_df),
            file_name="clustered_sites.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            # "ignore" avoids a rerun on click, so the results stay on screen.
            on_click="ignore",
            type="primary",
        )
|
|