dilithjay committed on
Commit
e87025c
·
1 Parent(s): 6777b94

Initial Commit

Browse files
Files changed (5) hide show
  1. .gitignore +4 -0
  2. README.md +2 -3
  3. app.py +397 -0
  4. requirements.txt +4 -0
  5. utils.py +22 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ data/
2
+ flagged/
3
+ **/__pycache__/
4
+ venv/
README.md CHANGED
@@ -4,10 +4,9 @@ emoji: 🚀
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.12.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.21.0
8
+ python_version: 3.8
9
  app_file: app.py
10
  pinned: false
11
  license: mit
12
  ---
 
 
app.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import tempfile
4
+ import gradio as gr
5
+ import plotly.graph_objects as go
6
+
7
+ import pandas as pd
8
+ from time import time
9
+ from utils import (
10
+ create_file_structure,
11
+ init_info_csv,
12
+ add_to_info_csv,
13
+ )
14
+
15
+ from satseg.dataset import create_datasets, create_inference_dataset
16
+ from satseg.model import train_model, save_model, run_inference, load_model
17
+ from satseg.seg_result import combine_seg_maps, get_combined_map_contours
18
+ from satseg.geo_tools import (
19
+ shapefile_to_latlong,
20
+ shapefile_to_grid_indices,
21
+ points_to_shapefile,
22
+ contours_to_shapefile,
23
+ get_tif_n_channels,
24
+ )
25
+
26
# Folder layout for everything the app persists between runs.
DATA_DIR = "data"
MODEL_DIR = os.path.join(DATA_DIR, "models")
TIF_DIR = os.path.join(DATA_DIR, "tifs")
MASK_DIR = os.path.join(DATA_DIR, "masks")
INFO_DIR = os.path.join(DATA_DIR, "info")

# CSV registries describing the saved models, TIFs and masks.
MODEL_INFO_PATH = os.path.join(INFO_DIR, "model_data.csv")
DATASET_TIF_INFO_PATH = os.path.join(INFO_DIR, "dataset_tif_data.csv")
DATASET_MASK_INFO_PATH = os.path.join(INFO_DIR, "dataset_mask_data.csv")

# MODEL_DIR is included so that the folder trained models are saved into
# exists before the first training run.
create_file_structure(
    [DATA_DIR, MODEL_DIR, TIF_DIR, MASK_DIR, INFO_DIR],
    [MODEL_INFO_PATH, DATASET_TIF_INFO_PATH, DATASET_MASK_INFO_PATH],
)

# Write the header row into each registry that is still empty.
init_info_csv(
    MODEL_INFO_PATH,
    [
        "Name",
        "Architecture",
        "# of channels",
        "Train TIF",
        "Train Mask",
        "Expression",
        "Path",
    ],
)
init_info_csv(DATASET_TIF_INFO_PATH, ["Name", "# of channels", "Path"])
init_info_csv(DATASET_MASK_INFO_PATH, ["Name", "Class", "Path"])
54
+
55
+
56
def gr_train_model(
    tif_names, mask_names, model_name, expression, progress=gr.Progress()
):
    """Train a UNet on the selected TIFs/masks and register the saved model.

    Args:
        tif_names: Names of TIF files located in ``TIF_DIR``.
        mask_names: Names of mask files located in ``MASK_DIR``.
        model_name: Free-text model name; whitespace runs become ``_`` and
            ``.pt`` is appended to form the file name.
        expression: Whitespace-separated remote-sensing index expression,
            e.g. ``( c6 - c0 ) / ( c6 + c0 ) =``. May be empty.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        A tuple ``(status_text, model_df, dropdown_update)`` where
        ``model_df`` is the refreshed model registry and ``dropdown_update``
        carries the refreshed list of model names.

    Raises:
        gr.Error: If no TIF/mask is selected or the model name is blank.
    """
    if not tif_names or not mask_names:
        raise gr.Error("Please select at least one TIF file and one mask file.")

    # Keep only the final path component so the model cannot be written
    # outside MODEL_DIR through a name such as "../x".
    model_stem = os.path.basename("_".join((model_name or "").split()))
    if not model_stem:
        raise gr.Error("Please give the model a name.")
    model_name = model_stem + ".pt"

    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]
    mask_paths = [os.path.join(MASK_DIR, name) for name in mask_names]
    expression = (expression or "").strip().split()

    # NOTE: the architecture is currently fixed to "unet".
    progress(0, desc="Creating Dataset...")
    with tempfile.TemporaryDirectory() as tempdir:
        train_set, val_set = create_datasets(
            tif_paths, mask_paths, tempdir, expression=expression
        )
        progress(0.05, desc="Training Model...")
        model, _ = train_model(train_set, val_set, "unet")

    progress(0.95, desc="Model Trained! Saving...")
    # The models folder is not guaranteed to exist yet; create it on demand.
    os.makedirs(MODEL_DIR, exist_ok=True)
    model_path = os.path.join(MODEL_DIR, model_name)
    save_model(model, model_path)
    add_to_info_csv(
        MODEL_INFO_PATH,
        [
            model_name,
            "UNet",
            val_set.n_channels,
            ";".join(tif_names),
            ";".join(mask_names),
            " ".join(expression),
            model_path,
        ],
    )
    progress(1.0, desc="Done!")
    model_df = pd.read_csv(MODEL_INFO_PATH)

    return "Done!", model_df, gr.Dropdown.update(choices=model_df["Name"].to_list())
94
+
95
+
96
def gr_run_inference(tif_names, model_name, progress=gr.Progress()):
    """Run a registered model on the selected TIFs and export shapefiles.

    Args:
        tif_names: Names of TIF files located in ``TIF_DIR``.
        model_name: Value of the ``Name`` column of the model registry.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        List of the paths returned by ``contours_to_shapefile``, one per
        result entry.

    Raises:
        gr.Error: If nothing is selected or the model is not registered.
    """
    if not tif_names:
        raise gr.Error("Please select at least one TIF file.")
    if not model_name:
        raise gr.Error("Please select a model.")

    t = time()
    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]
    model_df = pd.read_csv(
        MODEL_INFO_PATH, index_col="Name", dtype={"Expression": str}
    )
    if model_name not in model_df.index:
        raise gr.Error(f"Unknown model: {model_name}")

    # Selecting with a list always yields a DataFrame, even when the same
    # name was registered more than once; the last row is the newest entry.
    model_row = model_df.loc[[model_name]].iloc[-1]
    model_path = model_row["Path"]
    # An empty expression is read back from the CSV as NaN, not as "".
    expression = model_row["Expression"]
    expression = expression.split() if isinstance(expression, str) else []

    with tempfile.TemporaryDirectory() as tempdir:
        progress(0, desc="Creating Dataset...")
        dataset = create_inference_dataset(
            tif_paths,
            tempdir,
            256,  # presumably the tile size in pixels; confirm against satseg
            expression=expression,
        )
        progress(0.1, desc="Loading Model...")
        model = load_model(model_path)

        result_dir = os.path.join(tempdir, "infer")
        comb_result_dir = os.path.join(tempdir, "comb")
        os.makedirs(result_dir)
        os.makedirs(comb_result_dir)
        progress(0.2, desc="Running Inference...")
        run_inference(dataset, model, result_dir)
        progress(0.8, desc="Preparing output...")
        combine_seg_maps(result_dir, comb_result_dir)
        results = get_combined_map_contours(comb_result_dir)

        # Start from an empty output folder so results of earlier runs do
        # not linger next to the new ones.
        file_paths = []
        out_dir = os.path.join(MASK_DIR, "output")
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)
        for tif_name, (contours, hierarchy) in results.items():
            tif_path = os.path.join(TIF_DIR, f"{tif_name}.tif")
            mask_path = os.path.join(out_dir, f"{tif_name}_mask.shp")
            zip_path = contours_to_shapefile(contours, hierarchy, tif_path, mask_path)
            file_paths.append(zip_path)

    print(time() - t, "seconds")
    return file_paths
135
+
136
+
137
def gr_save_mask_file(file_objs, filenames, obj_class):
    """Move uploaded mask files into ``MASK_DIR`` and register the ``.shp`` ones.

    Args:
        file_objs: Uploaded file objects; each exposes the temporary path as
            ``.name``. May be ``None`` when nothing was uploaded.
        filenames: ``;``-separated target file names, matched to
            ``file_objs`` in order. Blank entries are ignored.
        obj_class: Class label stored alongside each registered mask.

    Returns:
        A tuple ``(dataset_df, update, update)``: the refreshed mask registry
        followed by the same dropdown update twice. The tuple shape is kept
        for compatibility; the click handler in this file lists two outputs.
    """
    print("Saving file(s)...")
    # Strip whitespace and keep only the base name so an entry such as
    # "../x.shp" cannot place a file outside MASK_DIR.
    names = [os.path.basename(part.strip()) for part in (filenames or "").split(";")]
    names = [name for name in names if name]

    # zip() stops at the shorter sequence, so surplus names are ignored
    # instead of raising IndexError.
    for filename, obj in zip(names, file_objs or []):
        filepath = os.path.join(MASK_DIR, filename)
        shutil.move(obj.name, filepath)
        # Only the .shp file is registered; the other uploaded files are
        # just stored next to it.
        if filename.endswith(".shp"):
            add_to_info_csv(DATASET_MASK_INFO_PATH, [filename, obj_class, filepath])
    print("Done!")

    # Build the choices from the freshly read registry so the newly saved
    # masks show up in the dropdowns.
    dataset_df = pd.read_csv(DATASET_MASK_INFO_PATH)
    update = gr.Dropdown.update(choices=dataset_df["Name"].to_list())

    return dataset_df, update, update
158
+
159
+
160
def gr_save_tif_file(file_objs, filenames):
    """Copy uploaded TIF files into ``TIF_DIR`` and register them.

    Args:
        file_objs: Uploaded file objects; each exposes the temporary path as
            ``.name``. May be ``None`` when nothing was uploaded.
        filenames: ``;``-separated target file names, matched to
            ``file_objs`` in order. Blank entries are ignored.

    Returns:
        A tuple ``(dataset_df, update, update)``: the refreshed TIF registry
        followed by the same dropdown update twice.
    """
    print("Saving file(s)...")
    # Strip whitespace and keep only the base name so an entry such as
    # "../x.tif" cannot place a file outside TIF_DIR.
    names = [os.path.basename(part.strip()) for part in (filenames or "").split(";")]
    names = [name for name in names if name]

    # zip() stops at the shorter sequence, so surplus names are ignored
    # instead of raising IndexError.
    for filename, obj in zip(names, file_objs or []):
        filepath = os.path.join(TIF_DIR, filename)
        shutil.copy2(obj.name, filepath)
        n = get_tif_n_channels(filepath)
        add_to_info_csv(DATASET_TIF_INFO_PATH, [filename, n, filepath])
    print("Done!")

    # The TIF dropdowns must list TIF names, so the choices are taken from
    # the TIF registry that was just re-read (not from the mask table).
    dataset_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    update = gr.Dropdown.update(choices=dataset_df["Name"].to_list())

    return dataset_df, update, update
181
+
182
+
183
def gr_generate_map(mask_name: str, token: str = "", show_grid=True, show_mask=False):
    """Build a Plotly map figure for a stored mask shapefile.

    Args:
        mask_name: File name of the mask inside ``MASK_DIR``.
        token: Mapbox access token. When empty, the OpenStreetMap style is
            used instead of the Mapbox satellite style.
        show_grid: Plot the points returned by ``shapefile_to_grid_indices``
            and also write them to ``<mask>-grid.shp`` next to the mask.
        show_mask: Plot every contour returned by ``shapefile_to_latlong``.

    Returns:
        The ``plotly.graph_objects.Figure`` holding the requested traces.
    """
    mask_path = os.path.join(MASK_DIR, mask_name)
    center = (7.753769, 80.691730)  # fixed initial map centre as (lat, lon)

    scattermaps = []
    if show_grid:
        indices = shapefile_to_grid_indices(mask_path)
        grid_path = os.path.splitext(mask_path)[0] + "-grid.shp"
        points_to_shapefile(indices, grid_path)
        # Column 0 is used as longitude and column 1 as latitude.
        scattermaps.append(
            go.Scattermapbox(
                lat=indices[:, 1],
                lon=indices[:, 0],
                mode="markers",
                marker=go.scattermapbox.Marker(size=6),
            )
        )
    if show_mask:
        # Draw all contours of the mask, one filled trace per contour.
        for contour in shapefile_to_latlong(mask_path):
            scattermaps.append(
                go.Scattermapbox(
                    fill="toself",
                    lat=contour[:, 1],
                    lon=contour[:, 0],
                    mode="markers",
                    marker=go.scattermapbox.Marker(size=6),
                )
            )

    fig = go.Figure(scattermaps)
    map_center = go.layout.mapbox.Center(lat=center[0], lon=center[1])

    if token:
        fig.update_layout(
            mapbox=dict(
                style="satellite-streets",
                accesstoken=token,
                center=map_center,
                pitch=0,
                zoom=7,
            ),
            mapbox_layers=[
                {
                    "sourcetype": "raster",
                    "sourceattribution": "United States Geological Survey",
                    "source": [
                        "https://basemap.nationalmap.gov/arcgis/rest/services/USGSImageryOnly/MapServer/tile/{z}/{y}/{x}"
                    ],
                }
            ],
        )
    else:
        fig.update_layout(
            mapbox_style="open-street-map",
            hovermode="closest",
            mapbox=dict(
                bearing=0,
                center=map_center,
                pitch=0,
                zoom=7,
            ),
        )

    return fig
250
+
251
+
252
def _join_uploaded_names(uploaded_files):
    """Return the original base names of the uploaded files joined by ';'."""
    return ";".join(os.path.basename(f.orig_name) for f in uploaded_files)


# UI definition: one tab each for training, inference, dataset management and
# the model registry, followed by the event wiring between the components.
with gr.Blocks() as demo:
    gr.Markdown(
        """# SatSeg
    Train models and run inference for segmentation of multispectral satellite images."""
    )

    # Registries are read once at start-up to populate the initial choices
    # and tables; the event handlers return refreshed values afterwards.
    model_df = pd.read_csv(MODEL_INFO_PATH)
    dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH)

    with gr.Tab("Train"):
        train_tif_names = gr.Dropdown(
            label="TIF Files",
            choices=dataset_tif_df["Name"].to_list(),
            multiselect=True,
        )
        train_mask_names = gr.Dropdown(
            label="Mask files",
            choices=dataset_mask_df["Name"].to_list(),
            multiselect=True,
        )
        train_rs_index = gr.Textbox(
            label="Remote Sensing Index", placeholder="( c0 + c1 ) / ( c0 - c1 ) ="
        )
        # train_arch = gr.Dropdown(
        #     label="Model Architecture", choices=["Best", "UNet", "DCAMA"], value="Best"
        # )
        train_model_name = gr.Textbox(
            label="Model Name", placeholder="Give the model a name"
        )
        train_button = gr.Button("Train")

        train_completion = gr.Text(label="Training Status", value="Not Started")

    with gr.Tab("Infer"):
        infer_tif_names = gr.Dropdown(
            label="TIF Files",
            choices=dataset_tif_df["Name"].to_list(),
            multiselect=True,
        )
        infer_model_name = gr.Dropdown(
            label="Model Name",
            choices=model_df["Name"].to_list(),
        )
        infer_button = gr.Button("Infer")

        infer_mask = gr.Files(label="Output Shapefile", interactive=False)

    # TODO: the "Sampling" tab is disabled for now.
    # with gr.Tab("Sampling"):
    #     grid_mask_name = gr.Dropdown(
    #         label="Mask",
    #         choices=dataset_mask_df["Name"].to_list(),
    #     )

    #     grid_token = gr.Textbox(
    #         value="", label="Mapbox Token (https://account.mapbox.com/)"
    #     )
    #     grid_side_len = gr.Textbox(value="100", label="Sampling Gap (m)")

    #     grid_show_grid = gr.Checkbox(True, label="Show Grid")
    #     grid_show_mask = gr.Checkbox(False, label="Show Mask")

    #     grid_button = gr.Button("Generate Grid")

    #     grid_map = gr.Plot(label="Plot")

    with gr.Tab("Datasets"):
        datasets_upload_tif = gr.File(label="Images (.tif)", file_count="multiple")
        datasets_upload_tif_name = gr.Textbox(
            label="TIF name", placeholder="tif_file_1.tif;tif_file_2.tif"
        )
        datasets_save_uploaded_tif = gr.Button("Save")

        datasets_upload_mask = gr.File(
            label="Masks (Please upload all extensions (.shp, .shx, etc.))",
            file_count="multiple",
        )
        datasets_upload_mask_name = gr.Textbox(
            label="Mask name", placeholder="mask_1.shp;mask_1.shx"
        )
        datasets_mask_class_name = gr.Textbox(
            label="Class (The name of the object you want to segment)"
        )
        datasets_save_uploaded_mask = gr.Button("Save")

        datasets_tif_table = gr.Dataframe(dataset_tif_df, label="TIFs")
        datasets_mask_table = gr.Dataframe(dataset_mask_df, label="Masks")

    with gr.Tab("Models"):
        models_table = gr.Dataframe(model_df)

    train_button.click(
        gr_train_model,
        inputs=[
            train_tif_names,
            train_mask_names,
            # train_arch,
            train_model_name,
            train_rs_index,
        ],
        outputs=[train_completion, models_table, infer_model_name],
    )

    infer_button.click(
        gr_run_inference,
        inputs=[infer_tif_names, infer_model_name],
        outputs=[infer_mask],
    )

    # Pre-fill the name boxes with the original names of the uploaded files.
    datasets_upload_tif.upload(
        _join_uploaded_names,
        inputs=datasets_upload_tif,
        outputs=datasets_upload_tif_name,
    )

    datasets_upload_mask.upload(
        _join_uploaded_names,
        inputs=datasets_upload_mask,
        outputs=datasets_upload_mask_name,
    )

    # grid_button.click(
    #     gr_generate_map,
    #     inputs=[grid_mask_name, grid_token, grid_show_grid, grid_show_mask],
    #     outputs=grid_map,
    # )

    datasets_save_uploaded_tif.click(
        gr_save_tif_file,
        inputs=[datasets_upload_tif, datasets_upload_tif_name],
        outputs=[datasets_tif_table, train_tif_names, infer_tif_names],
    )
    datasets_save_uploaded_mask.click(
        gr_save_mask_file,
        inputs=[
            datasets_upload_mask,
            datasets_upload_mask_name,
            datasets_mask_class_name,
        ],
        outputs=[datasets_mask_table, train_mask_names],
    )

demo.queue(concurrency_count=10).launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==3.21.0
2
+ pandas==2.0.0
3
+ plotly==5.13.1
4
+ satseg==0.1.1
utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+
5
def init_info_csv(data_info_path: str, header: List[str]):
    """Write *header* as the first row of the CSV if the file has no content.

    A file containing only whitespace counts as empty. The file must already
    exist; opening a missing path raises ``FileNotFoundError``.
    """
    with open(data_info_path, "r") as handle:
        already_populated = bool(handle.read().strip())

    if already_populated:
        return
    add_to_info_csv(data_info_path, header)
9
+
10
+
11
def add_to_info_csv(data_info_path: str, info: List[str]):
    """Append *info* as one row to the CSV file at *data_info_path*.

    Every item is converted with ``str()``. Values containing commas, quotes
    or newlines are quoted by the ``csv`` module so the row can be read back
    intact; plain values are written exactly as ``a,b,c``.

    Args:
        data_info_path: Path of the CSV file; it is created if missing.
        info: Values of the row, in column order.
    """
    import csv  # local import: keeps this function self-contained

    # newline="" lets the csv writer control the line ending itself.
    with open(data_info_path, "a", newline="", encoding="utf-8") as fp:
        writer = csv.writer(fp, lineterminator="\n")
        writer.writerow([str(item) for item in info])
14
+
15
+
16
def create_file_structure(dirs: List[str], files: List[str]):
    """Ensure every directory in *dirs* and every file in *files* exists.

    Directories are created first (existing ones are left alone). Files are
    opened in append mode, so missing files are created empty and existing
    files keep their content.
    """
    for directory in dirs:
        os.makedirs(directory, exist_ok=True)

    for target in files:
        open(target, "a").close()