DavMelchi commited on
Commit
b5e340d
·
1 Parent(s): 962e6f5

improve clustering app

Browse files
apps/clustering.py CHANGED
@@ -4,10 +4,11 @@ import numpy as np
4
  import pandas as pd
5
  import plotly.express as px
6
  import streamlit as st
 
7
  from sklearn.cluster import KMeans
8
 
9
 
10
- def cluster_sites(
11
  df: pd.DataFrame,
12
  lat_col: str,
13
  lon_col: str,
@@ -23,20 +24,103 @@ def cluster_sites(
23
  else:
24
  grouped = [("All", df)]
25
 
 
 
 
 
26
  for region, group in grouped:
27
- coords = group[[lat_col, lon_col]].to_numpy()
28
- n_clusters = max(1, int(np.ceil(len(group) / max_sites)))
29
 
30
- if len(group) < max_sites:
31
- labels = np.zeros(len(group), dtype=int)
32
- else:
33
- kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
34
- labels = kmeans.fit_predict(coords)
35
 
36
  group = group.copy()
37
- group["Cluster"] = [f"C{cluster_id + l}" for l in labels]
38
- clusters.append(group)
39
- cluster_id += len(set(labels))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  return pd.concat(clusters)
42
 
@@ -57,7 +141,18 @@ st.write(
57
  """
58
  )
59
 
60
- uploaded_file = st.file_uploader("Upload your Excel file", type=["xlsx"])
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  if uploaded_file:
63
  df = pd.read_excel(uploaded_file)
@@ -73,30 +168,58 @@ if uploaded_file:
73
  max_sites = st.number_input(
74
  "Max sites per cluster", min_value=5, max_value=100, value=25
75
  )
 
 
 
 
76
  mix_regions = st.checkbox(
77
  "Allow mixing different regions in clusters", value=False
78
  )
79
  submitted = st.form_submit_button("Run Clustering")
80
 
81
  if submitted:
82
- clustered_df = cluster_sites(
83
- df, lat_col, lon_col, region_col, max_sites, mix_regions
84
- )
 
 
 
 
 
85
  st.success("Clustering completed!")
86
  st.write(clustered_df.head())
87
 
88
  # Plot
 
89
  fig = px.scatter_map(
90
  clustered_df,
91
  lat=lat_col,
92
  lon=lon_col,
93
  color="Cluster",
 
94
  hover_name=code_col,
95
  hover_data=[region_col],
96
  zoom=5,
97
  height=600,
98
  )
99
  fig.update_layout(mapbox_style="open-street-map")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  st.plotly_chart(fig)
101
 
102
  # Download button
 
4
  import pandas as pd
5
  import plotly.express as px
6
  import streamlit as st
7
+ from hilbertcurve.hilbertcurve import HilbertCurve
8
  from sklearn.cluster import KMeans
9
 
10
 
11
+ def cluster_sites_hilbert_curve_same_size(
12
  df: pd.DataFrame,
13
  lat_col: str,
14
  lon_col: str,
 
24
  else:
25
  grouped = [("All", df)]
26
 
27
+ # Create Hilbert Curve (higher p = more precision)
28
+ p = 16 # Adjust based on your coordinate precision needs
29
+ hilbert_curve = HilbertCurve(p, 2) # 2D curve
30
+
31
  for region, group in grouped:
32
+ if len(group) == 0:
33
+ continue
34
 
35
+ # Normalize coordinates to [0, 2^p-1] range
36
+ lat_min, lat_max = group[lat_col].min(), group[lat_col].max()
37
+ lon_min, lon_max = group[lon_col].min(), group[lon_col].max()
 
 
38
 
39
  group = group.copy()
40
+ group["x"] = ((group[lat_col] - lat_min) / (lat_max - lat_min + 1e-10)) * (
41
+ 2**p - 1
42
+ )
43
+ group["y"] = ((group[lon_col] - lon_min) / (lon_max - lon_min + 1e-10)) * (
44
+ 2**p - 1
45
+ )
46
+
47
+ # Calculate Hilbert distance
48
+ group["hilbert"] = group.apply(
49
+ lambda row: hilbert_curve.distance_from_point(
50
+ [int(row["x"]), int(row["y"])]
51
+ ),
52
+ axis=1,
53
+ )
54
+
55
+ # Sort by Hilbert value
56
+ group = group.sort_values("hilbert")
57
+
58
+ # Create fixed-size clusters
59
+ for i in range(0, len(group), max_sites):
60
+ cluster = group.iloc[i : i + max_sites].copy()
61
+ cluster["Cluster"] = f"C{cluster_id}"
62
+ clusters.append(cluster)
63
+ cluster_id += 1
64
+
65
+ result = pd.concat(clusters)
66
+ return result.drop(columns=["x", "y", "hilbert"], errors="ignore")
67
+
68
+
69
+ def cluster_sites_kmeans_lower_to_fixed_size(
70
+ df: pd.DataFrame,
71
+ lat_col: str,
72
+ lon_col: str,
73
+ region_col: str,
74
+ max_sites: int = 25,
75
+ mix_regions: bool = False,
76
+ ):
77
+ clusters = []
78
+ cluster_id = 0
79
+
80
+ if not mix_regions:
81
+ grouped = df.groupby(region_col)
82
+ else:
83
+ grouped = [("All", df)]
84
+
85
+ for region, group in grouped:
86
+ coords = group[[lat_col, lon_col]].to_numpy()
87
+ remaining_sites = group.copy()
88
+
89
+ while len(remaining_sites) > 0:
90
+ # Calculate number of clusters needed for remaining sites
91
+ n_remaining = len(remaining_sites)
92
+ n_clusters = max(1, int(np.ceil(n_remaining / max_sites)))
93
+
94
+ if n_remaining <= max_sites:
95
+ # If remaining sites can fit in one cluster
96
+ cluster_group = remaining_sites.copy()
97
+ cluster_group["Cluster"] = f"C{cluster_id}"
98
+ clusters.append(cluster_group)
99
+ cluster_id += 1
100
+ break
101
+ else:
102
+ # Apply KMeans to remaining sites
103
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
104
+ labels = kmeans.fit_predict(
105
+ remaining_sites[[lat_col, lon_col]].to_numpy()
106
+ )
107
+
108
+ # Split into clusters and check sizes
109
+ temp_df = remaining_sites.copy()
110
+ temp_df["Cluster"] = labels
111
+ temp_df["Temp_Cluster"] = labels
112
+
113
+ for cluster_num in range(n_clusters):
114
+ cluster_group = temp_df[temp_df["Temp_Cluster"] == cluster_num]
115
+ if len(cluster_group) <= max_sites:
116
+ # If cluster is small enough, keep it
117
+ cluster_group = cluster_group.drop(columns=["Temp_Cluster"])
118
+ cluster_group["Cluster"] = f"C{cluster_id}"
119
+ clusters.append(cluster_group)
120
+ cluster_id += 1
121
+ # Remove these sites from remaining_sites
122
+ remaining_sites = remaining_sites.drop(cluster_group.index)
123
+ # Else these sites will remain for next iteration
124
 
125
  return pd.concat(clusters)
126
 
 
141
  """
142
  )
143
 
144
+ # Download Sample file
145
+ clustering_sample_file_path = "samples/Site_Clustering.xlsx"
146
+
147
+ # Create a download button
148
+ st.download_button(
149
+ label="Download Clustering Sample File",
150
+ data=open(clustering_sample_file_path, "rb").read(),
151
+ file_name="Site_Clustering.xlsx",
152
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
153
+ )
154
+
155
+ uploaded_file = st.file_uploader("Upload your Excel file ", type=["xlsx"])
156
 
157
  if uploaded_file:
158
  df = pd.read_excel(uploaded_file)
 
168
  max_sites = st.number_input(
169
  "Max sites per cluster", min_value=5, max_value=100, value=25
170
  )
171
+ cluster_method = st.selectbox(
172
+ "Select clustering method",
173
+ ["Hilbert Curve Same Size", "KMeans Lower To Fixed Size"],
174
+ )
175
  mix_regions = st.checkbox(
176
  "Allow mixing different regions in clusters", value=False
177
  )
178
  submitted = st.form_submit_button("Run Clustering")
179
 
180
  if submitted:
181
+ if cluster_method == "Hilbert Curve Same Size":
182
+ clustered_df = cluster_sites_hilbert_curve_same_size(
183
+ df, lat_col, lon_col, region_col, max_sites, mix_regions
184
+ )
185
+ elif cluster_method == "KMeans Lower To Fixed Size":
186
+ clustered_df = cluster_sites_kmeans_lower_to_fixed_size(
187
+ df, lat_col, lon_col, region_col, max_sites, mix_regions
188
+ )
189
  st.success("Clustering completed!")
190
  st.write(clustered_df.head())
191
 
192
  # Plot
193
+ clustered_df["size"] = 10
194
  fig = px.scatter_map(
195
  clustered_df,
196
  lat=lat_col,
197
  lon=lon_col,
198
  color="Cluster",
199
+ size="size",
200
  hover_name=code_col,
201
  hover_data=[region_col],
202
  zoom=5,
203
  height=600,
204
  )
205
  fig.update_layout(mapbox_style="open-street-map")
206
+ fig.update_traces(marker=dict(size=15))
207
+ st.plotly_chart(fig)
208
+
209
+ # Show cluster size per cluster plot
210
+ cluster_size = clustered_df["Cluster"].value_counts().sort_index()
211
+ fig = px.bar(cluster_size, x=cluster_size.index, y=cluster_size.values)
212
+ fig.update_layout(title="Cluster Size")
213
+ st.plotly_chart(fig)
214
+
215
+ # Show cluster size per region plot
216
+ cluster_size_per_region = (
217
+ clustered_df.groupby([region_col, "Cluster"])
218
+ .size()
219
+ .reset_index(name="count")
220
+ )
221
+ fig = px.bar(cluster_size_per_region, x="Cluster", y="count", color=region_col)
222
+ fig.update_layout(title="Cluster Size per Region")
223
  st.plotly_chart(fig)
224
 
225
  # Download button
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
samples/Site_Clustering.xlsx ADDED
Binary file (39.9 kB). View file