awacke1 committed on
Commit
3f3bffe
·
1 Parent(s): 6e73e6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py CHANGED
@@ -22,3 +22,154 @@ from datasets import load_dataset
22
 
23
  geo = load_dataset('jamescalam/world-cities-geo', split='train')
24
  st.write(geo)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  geo = load_dataset('jamescalam/world-cities-geo', split='train')
24
  st.write(geo)
25
+
26
+
27
+
28
+
29
+
30
+
31
import plotly.express as px

# Discrete colours handed to Plotly for the continent categories.
palette = ['#1c17ff', '#faff00', '#8cf1ff', '#000000', '#030080', '#738fab']

# 3-D scatter of every row in `geo`, positioned by its x/y/z columns and
# coloured by continent. Country and city ride along as custom data so the
# hover label can show them.
fig = px.scatter_3d(
    x=geo['x'], y=geo['y'], z=geo['z'],
    color=geo['continent'],
    custom_data=[geo['country'], geo['city']],
    color_discrete_sequence=palette,
)
# Plotly hover labels use HTML-style markup: a "\n" is not rendered as a line
# break, so the two fields must be joined with <br> to appear on separate
# lines.
fig.update_traces(
    hovertemplate="<br>".join([
        "city: %{customdata[1]}",
        "country: %{customdata[0]}",
    ])
)

# Export an embeddable fragment (no <html> wrapper); plotly.js is loaded
# from the CDN rather than inlined.
fig.write_html("umap-earth-3d.html", include_plotlyjs="cdn", full_html=False)
50
+
51
+
52
import numpy as np

# Stack the x/y/z columns and transpose so each row is one point with three
# coordinates.
geo_arr = np.asarray([geo['x'], geo['y'], geo['z']]).T
# Scale every coordinate by the single largest value in the array.
geo_arr = geo_arr / geo_arr.max()

# Preview the first five rows. st.markdown expects Markdown text and would
# render the array's str() form; st.write displays array-like data directly.
st.write(geo_arr[:5])
58
+
59
+
60
import umap

# Hex colour used for each continent label in the matplotlib/seaborn plots.
c_map = {
    'Africa': '#8cf1ff',
    'Asia': '#1c17ff',
    'Europe': '#faff00',
    'North America': '#738fab',
    'Oceania': '#030080',
    'South America': '#000000',
}
# Build a fresh list of colours, one per row, instead of assigning into the
# object returned by geo['continent'] by index. This does not rely on that
# object supporting item assignment.
# A continent missing from c_map still raises KeyError, as before.
colors = [c_map[continent] for continent in geo['continent']]
# Bare expression kept from the original; presumably relies on Streamlit
# "magic" to display the first five colours - TODO confirm.
colors[:5]
75
+
76
+
77
+
78
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# One panel per n_neighbors value: a 3x3 grid for the nine settings below.
fig, ax = plt.subplots(3, 3, figsize=(14, 14))
nns = [2, 3, 4, 5, 10, 15, 30, 50, 100]
# ax.flat walks the grid row by row (left to right, then down), which is the
# order the manual column/row counters used to produce. tqdm is the first
# argument of zip so it is the iterator that gets exhausted and can finish
# its progress bar.
for n_neighbors, panel in zip(tqdm(nns), ax.flat):
    fit = umap.UMAP(n_neighbors=n_neighbors)
    u = fit.fit_transform(geo_arr)
    sns.scatterplot(x=u[:, 0], y=u[:, 1], c=colors, ax=panel)
    panel.set_title(f'n={n_neighbors}')
# NOTE(review): the figure is built but never passed to a display call in the
# visible code - confirm whether it should be shown in the app.
92
+
93
+
94
# Integer class id per continent, used as the supervision target for UMAP.
t_map = {
    'Africa': 0,
    'Asia': 1,
    'Europe': 2,
    'North America': 3,
    'Oceania': 4,
    'South America': 5,
}
# Build a fresh list of ids rather than assigning into the object returned
# by geo['continent'] by index. An unmapped continent still raises KeyError.
target = [t_map[continent] for continent in geo['continent']]

# Same 3x3 sweep over n_neighbors as the unsupervised grid, but each fit is
# given the continent ids through `y`.
fig, ax = plt.subplots(3, 3, figsize=(14, 14))
nns = [2, 3, 4, 5, 10, 15, 30, 50, 100]
# ax.flat yields the panels row by row, matching the original fill order.
# tqdm is the first argument of zip so it is the iterator that gets
# exhausted and can finish its progress bar.
for n_neighbors, panel in zip(tqdm(nns), ax.flat):
    fit = umap.UMAP(n_neighbors=n_neighbors)
    u = fit.fit_transform(geo_arr, y=target)
    sns.scatterplot(x=u[:, 0], y=u[:, 1], c=colors, ax=panel)
    panel.set_title(f'n={n_neighbors}')
116
+
117
import umap

# Final 2-D embedding, fitted with n_neighbors=50 and min_dist=0.5.
fit = umap.UMAP(n_neighbors=50, min_dist=0.5)
u = fit.fit_transform(geo_arr)

# Interactive scatter of the embedding, coloured by continent. Country and
# city are attached as custom data for the hover label.
fig = px.scatter(
    x=u[:, 0], y=u[:, 1],
    color=geo['continent'],
    custom_data=[geo['country'], geo['city']],
    color_discrete_sequence=palette,
)
# Plotly hover labels need <br> for a line break; "\n" is not rendered as
# one.
fig.update_traces(
    hovertemplate="<br>".join([
        "city: %{customdata[1]}",
        "country: %{customdata[0]}",
    ])
)

# Export an embeddable fragment; plotly.js is loaded from the CDN.
fig.write_html("umap-earth-2d.html", include_plotlyjs="cdn", full_html=False)
136
+
137
import pandas as pd

# Pair each 2-D embedding coordinate with the continent/country/city of the
# row it came from, then save the table as a pipe-delimited CSV without the
# index column.
umapped = pd.DataFrame(
    dict(
        x=u[:, 0],
        y=u[:, 1],
        continent=geo['continent'],
        country=geo['country'],
        city=geo['city'],
    )
)

umapped.to_csv('umapped.csv', index=False, sep='|')
148
+
149
from sklearn.decomposition import PCA

# Project the 3-column coordinate array onto its first two principal
# components.
pca = PCA(n_components=2)
p = pca.fit_transform(geo_arr)
# Interactive scatter of the PCA projection, coloured by continent. Country
# and city are attached as custom data for the hover label.
fig = px.scatter(
    x=p[:, 0], y=p[:, 1],
    color=geo['continent'],
    custom_data=[geo['country'], geo['city']],
    color_discrete_sequence=palette,
)
# Plotly hover labels need <br> for a line break; "\n" is not rendered as
# one.
fig.update_traces(
    hovertemplate="<br>".join([
        "city: %{customdata[1]}",
        "country: %{customdata[0]}",
    ])
)

# Export an embeddable fragment; plotly.js is loaded from the CDN.
fig.write_html("pca-earth-2d.html", include_plotlyjs="cdn", full_html=False)
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+
175
+