dropbop commited on
Commit
e498afe
·
verified ·
1 Parent(s): 99a41a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -27
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import gradio as gr
2
  import earthview as ev
 
 
 
3
  import os
4
  import json
5
  import utils
@@ -7,8 +10,11 @@ from pandas import DataFrame
7
 
8
  # --- Configuration ---
9
  DATASET_SUBSET = "satellogic"
 
10
  LABELED_DATA_FILE = "labeled_data.json"
 
11
  SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
 
12
 
13
  # --- Load Dataset ---
14
  dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
@@ -35,14 +41,30 @@ def get_next_sample():
35
  print("No more samples in the dataset.")
36
  return None
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # --- Save Labeled Data ---
39
  def save_labeled_data(label, state):
40
  global labeled_data
41
  sample = state["sample"]
42
  if sample is None:
43
- return "No image to label", DataFrame()
44
 
45
- image_bytes = sample["rgb"][0].convert("RGB").tobytes() # Convert PIL Image to bytes
 
46
  labeled_data.append({
47
  "image": image_bytes,
48
  "metadata": sample["metadata"],
@@ -55,45 +77,91 @@ def save_labeled_data(label, state):
55
  new_sample = get_next_sample()
56
  if new_sample is None:
57
  state["sample"] = None
58
- return "Dataset exhausted.", DataFrame()
59
 
60
  state["sample"] = new_sample
 
61
  new_metadata = new_sample["metadata"]
62
  new_metadata["map"] = f'<a href="{utils.get_google_map_link(new_sample, DATASET_SUBSET)}" target="_blank">🧭</a>'
63
- return "", DataFrame([new_metadata])
64
 
65
  # --- Gradio Interface ---
 
66
  def labeling_ui():
67
  state = gr.State({"sample": None, "subset": DATASET_SUBSET})
68
 
69
  with gr.Row():
70
- cool_button = gr.Button("Cool")
71
- not_cool_button = gr.Button("Not Cool")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- table = gr.DataFrame(datatype="html")
 
 
 
 
74
 
75
- def initialize_labeling_ui():
76
- sample = get_next_sample()
77
- if sample is None:
78
- return {"sample": None, "subset": DATASET_SUBSET}, DataFrame()
79
- metadata = sample["metadata"]
80
- metadata["map"] = f'<a href="{utils.get_google_map_link(sample, DATASET_SUBSET)}" target="_blank">🧭</a>'
81
- return {"sample": sample, "subset": DATASET_SUBSET}, DataFrame([metadata])
82
 
83
- initial_state, initial_metadata = initialize_labeling_ui()
84
- table.value = initial_metadata
85
- state.value = initial_state
86
 
87
- cool_button.click(
88
- fn=lambda label, state: save_labeled_data(label, state),
89
- inputs=[gr.Textbox(visible=False, value="cool"), state],
90
- outputs=[gr.Textbox(label="Debug"), table]
91
- )
92
- not_cool_button.click(
93
- fn=lambda label, state: save_labeled_data(label, state),
94
- inputs=[gr.Textbox(visible=False, value="not cool"), state],
95
- outputs=[gr.Textbox(label="Debug"), table]
96
- )
97
 
98
  # --- Main Interface ---
99
  with gr.Blocks() as demo:
@@ -101,5 +169,7 @@ with gr.Blocks() as demo:
101
  with gr.Tabs():
102
  with gr.TabItem("Labeling"):
103
  labeling_ui()
 
 
104
 
105
  demo.launch(debug=True)
 
1
  import gradio as gr
2
  import earthview as ev
3
+ from PIL import Image
4
+ import numpy as np
5
+ import random
6
  import os
7
  import json
8
  import utils
 
10
 
11
  # --- Configuration ---
12
  DATASET_SUBSET = "satellogic"
13
+ NUM_SAMPLES_TO_LABEL = 100 # You can adjust this
14
  LABELED_DATA_FILE = "labeled_data.json"
15
+ DISPLAY_N_COOL = 5 # How many "cool" examples to display alongside each new image.
16
  SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
17
+ BATCH_SIZE = 10
18
 
19
  # --- Load Dataset ---
20
  dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
 
41
  print("No more samples in the dataset.")
42
  return None
43
 
44
+ def get_images(batch_size, state):
45
+ subset = state["subset"]
46
+ images = []
47
+ metadatas = []
48
+ for _ in range(batch_size):
49
+ sample = get_next_sample()
50
+ if sample is None:
51
+ break
52
+ image = sample["rgb"][0]
53
+ metadata = sample["metadata"]
54
+ metadata["map"] = f'<a href="{utils.get_google_map_link(sample, subset)}" target="_blank">🧭</a>'
55
+ images.append(image)
56
+ metadatas.append(metadata)
57
+ return images, DataFrame(metadatas)
58
+
59
  # --- Save Labeled Data ---
60
  def save_labeled_data(label, state):
61
  global labeled_data
62
  sample = state["sample"]
63
  if sample is None:
64
+ return "No image to label", None, DataFrame()
65
 
66
+ image = sample["rgb"][0]
67
+ image_bytes = image.convert("RGB").tobytes()
68
  labeled_data.append({
69
  "image": image_bytes,
70
  "metadata": sample["metadata"],
 
77
  new_sample = get_next_sample()
78
  if new_sample is None:
79
  state["sample"] = None
80
+ return "Dataset exhausted.", [], DataFrame()
81
 
82
  state["sample"] = new_sample
83
+ new_image = new_sample["rgb"][0]
84
  new_metadata = new_sample["metadata"]
85
  new_metadata["map"] = f'<a href="{utils.get_google_map_link(new_sample, DATASET_SUBSET)}" target="_blank">🧭</a>'
86
+ return "", [new_image], DataFrame([new_metadata]) # Return a list of image data
87
 
88
  # --- Gradio Interface ---
89
+ # --- Labeling UI ---
90
  def labeling_ui():
91
  state = gr.State({"sample": None, "subset": DATASET_SUBSET})
92
 
93
  with gr.Row():
94
+ with gr.Column():
95
+ gallery = gr.Gallery(label="Satellite Image", interactive=False, columns=1, object_fit="scale-down")
96
+ with gr.Row():
97
+ cool_button = gr.Button("Cool")
98
+ not_cool_button = gr.Button("Not Cool")
99
+ table = gr.DataFrame(datatype="html")
100
+
101
+ def initialize_labeling_ui():
102
+ sample = get_next_sample()
103
+ images, metadata = get_images(1, {"sample": None, "subset": DATASET_SUBSET})
104
+ return sample, images, metadata
105
+
106
+ initial_sample, initial_images, initial_metadata = initialize_labeling_ui()
107
+ gallery.value = initial_images # Initialize with a list of images
108
+ table.value = initial_metadata
109
+ state.value = {"sample": initial_sample, "subset": DATASET_SUBSET}
110
+
111
+ cool_button.click(
112
+ fn=lambda label, state: save_labeled_data(label, state),
113
+ inputs=[gr.Textbox(visible=False, value="cool"), state],
114
+ outputs=[gr.Textbox(label="Debug"), gallery, table]
115
+ )
116
+ not_cool_button.click(
117
+ fn=lambda label, state: save_labeled_data(label, state),
118
+ inputs=[gr.Textbox(visible=False, value="not cool"), state],
119
+ outputs=[gr.Textbox(label="Debug"), gallery, table]
120
+ )
121
+
122
+ # --- Display UI ---
123
+ def display_ui():
124
+ def get_random_cool_images(n):
125
+ cool_samples = [d for d in labeled_data if d["label"] == "cool"]
126
+ return [Image.frombytes("RGB", (384, 384), s["image"]) for s in cool_samples] if len(cool_samples) >= n else []
127
+
128
+ def get_new_unlabeled_image():
129
+ global data_iter
130
+ try:
131
+ sample = next(data_iter)
132
+ sample = ev.item_to_images(DATASET_SUBSET, sample)
133
+ return sample["rgb"][0], json.dumps(sample["metadata"]["bounds"])
134
+ except StopIteration:
135
+ print("No more samples in the dataset.")
136
+ return None, None
137
+
138
+ def refresh_display():
139
+ new_image, new_metadata = get_new_unlabeled_image()
140
+ cool_images = get_random_cool_images(DISPLAY_N_COOL)
141
+ if new_image is None:
142
+ return "No more samples", None, []
143
+ return "", new_image, cool_images
144
 
145
+ with gr.Row():
146
+ new_image_component = gr.Image(label="New Image", type="pil")
147
+ metadata_display = gr.Textbox(label="Metadata (Bounds)")
148
+ with gr.Row():
149
+ cool_images_gallery = gr.Gallery(label="Cool Examples", value=[], columns=DISPLAY_N_COOL)
150
 
151
+ refresh_button = gr.Button("Refresh")
152
+ refresh_button.click(
153
+ fn=refresh_display,
154
+ inputs=[],
155
+ outputs=[gr.Textbox(label="Debug"), new_image_component, cool_images_gallery]
156
+ )
 
157
 
158
+ def initialize_display_ui():
159
+ debug, image, gallery = refresh_display()
160
+ return debug, image, gallery
161
 
162
+ debug, initial_image, initial_gallery = initialize_display_ui()
163
+ new_image_component.value = initial_image
164
+ cool_images_gallery.value = initial_gallery
 
 
 
 
 
 
 
165
 
166
  # --- Main Interface ---
167
  with gr.Blocks() as demo:
 
169
  with gr.Tabs():
170
  with gr.TabItem("Labeling"):
171
  labeling_ui()
172
+ with gr.TabItem("Display"):
173
+ display_ui()
174
 
175
  demo.launch(debug=True)