dropbop committed on
Commit
c501c19
·
verified ·
1 Parent(s): 7c425bb

Update app.py

Browse files

# Key Changes and Explanations:

## get_images(): This new function takes a batch_size argument (plus the UI state) and fetches multiple images and their metadata, similar to the Dataset Viewer. It calls get_next_sample() to get images one by one until the batch size is reached or the dataset is exhausted.

## save_labeled_data(): This function now takes the label and the UI state instead of an image and its metadata. It works similarly to get_images(), but processes only one image at a time: it saves the label for the current sample held in the state, fetches the next sample with get_next_sample(), and then updates the UI with the new image and metadata.

## labeling_ui():

Uses gr.Gallery to display images, aligning with the Dataset Viewer's approach.

Uses gr.DataFrame for metadata, which is better suited for structured data.

The initialize_labeling_ui() function is similar to before but uses the new get_images() to fetch the first image and its metadata.

Button click events are updated to work with the new image retrieval logic.

## display_ui(): Remains mostly the same, as it was already functional.

## State Management: Uses gr.State to store the current sample, which is updated each time a new sample is fetched for labeling.

Files changed (1) hide show
  1. app.py +74 -31
app.py CHANGED
@@ -5,6 +5,8 @@ import numpy as np
5
  import random
6
  import os
7
  import json
 
 
8
 
9
  # --- Configuration ---
10
  DATASET_SUBSET = "satellogic"
@@ -12,6 +14,7 @@ NUM_SAMPLES_TO_LABEL = 100 # You can adjust this
12
  LABELED_DATA_FILE = "labeled_data.json"
13
  DISPLAY_N_COOL = 5 # How many "cool" examples to display alongside each new image.
14
  SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
 
15
 
16
  # --- Load Dataset ---
17
  dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
@@ -33,64 +36,104 @@ def get_next_sample():
33
  try:
34
  sample = next(data_iter)
35
  sample = ev.item_to_images(DATASET_SUBSET, sample)
36
- image = sample["rgb"][0]
37
- metadata = sample["metadata"]
38
- return image, metadata, len(labeled_data)
39
  except StopIteration:
40
  print("No more samples in the dataset.")
41
- return None, None, len(labeled_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # --- Save Labeled Data ---
44
- def save_labeled_data(image, metadata, label):
45
  global labeled_data
46
- image_bytes = image.convert("RGB").tobytes() if image else None
 
 
 
 
 
 
47
  labeled_data.append({
48
  "image": image_bytes,
49
- "metadata": metadata,
50
  "label": label
51
  })
 
52
  with open(LABELED_DATA_FILE, "w") as f:
53
  json.dump(labeled_data, f)
54
- new_image, new_metadata, count = get_next_sample()
55
- if new_image is None:
56
- return "Dataset exhausted.", None, "", f"Labeled {count} samples."
57
- return "", new_image, json.dumps(new_metadata["bounds"]), f"Labeled {count} samples."
 
 
 
 
 
 
 
 
 
58
 
59
  # --- Gradio Interface ---
60
  # --- Labeling UI ---
61
  def labeling_ui():
 
 
 
62
  with gr.Row():
63
  with gr.Column():
64
- image_component = gr.Image(label="Satellite Image", type="pil")
65
- metadata_text = gr.Textbox(label="Metadata (Bounds)")
66
- label_count_text = gr.Textbox(label="Label Count")
67
  with gr.Row():
68
  cool_button = gr.Button("Cool")
69
  not_cool_button = gr.Button("Not Cool")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  cool_button.click(
71
- fn=save_labeled_data,
72
- inputs=[image_component, metadata_text],
73
- outputs=[gr.Textbox(label="Debug"), image_component, metadata_text, label_count_text]
74
  )
75
  not_cool_button.click(
76
- fn=save_labeled_data,
77
- inputs=[image_component, metadata_text],
78
- outputs=[gr.Textbox(label="Debug"), image_component, metadata_text, label_count_text]
79
  )
80
 
81
- def initialize_labeling_ui():
82
- image, metadata, count = get_next_sample()
83
- if image:
84
- return image, json.dumps(metadata["bounds"]), f"Labeled {count} samples."
85
- return None, "", "No samples loaded."
86
-
87
- initial_image, initial_metadata, initial_count = initialize_labeling_ui()
88
- image_component.value = initial_image
89
- metadata_text.value = initial_metadata
90
- label_count_text.value = initial_count
91
-
92
  # --- Display UI ---
93
  def display_ui():
 
94
  def get_random_cool_images(n):
95
  cool_samples = [d for d in labeled_data if d["label"] == "cool"]
96
  return [Image.frombytes("RGB", (384, 384), s["image"]) for s in cool_samples] if len(cool_samples) >= n else []
 
5
  import random
6
  import os
7
  import json
8
+ import utils
9
+ from pandas import DataFrame
10
 
11
  # --- Configuration ---
12
  DATASET_SUBSET = "satellogic"
 
14
  LABELED_DATA_FILE = "labeled_data.json"
15
  DISPLAY_N_COOL = 5 # How many "cool" examples to display alongside each new image.
16
  SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
17
+ BATCH_SIZE = 10
18
 
19
  # --- Load Dataset ---
20
  dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
 
36
  try:
37
  sample = next(data_iter)
38
  sample = ev.item_to_images(DATASET_SUBSET, sample)
39
+ return sample
 
 
40
  except StopIteration:
41
  print("No more samples in the dataset.")
42
+ return None
43
+
44
def get_images(batch_size, state):
    """Collect up to ``batch_size`` images together with their metadata.

    Samples are pulled one at a time through ``get_next_sample()``; the loop
    stops early as soon as that call returns ``None``.  Every sample's
    metadata dict is annotated in place with a ``"map"`` entry containing an
    HTML anchor whose target comes from ``utils.get_google_map_link``.

    Args:
        batch_size: Maximum number of samples to fetch.
        state: Mapping that provides the dataset subset under ``"subset"``.

    Returns:
        A ``(images, table)`` tuple: the list of each sample's first
        ``"rgb"`` image, and a ``DataFrame`` built from the collected
        metadata dicts (one per image).
    """
    dataset_subset = state["subset"]
    gallery_images, metadata_rows = [], []

    for _ in range(batch_size):
        item = get_next_sample()
        if item is None:
            # Nothing left to read: return whatever was gathered so far.
            break

        picture = item["rgb"][0]
        row = item["metadata"]
        map_url = utils.get_google_map_link(item, dataset_subset)
        row["map"] = f'<a href="{map_url}" target="_blank">🧭</a>'

        gallery_images.append(picture)
        metadata_rows.append(row)

    return gallery_images, DataFrame(metadata_rows)
61
 
62
  # --- Save Labeled Data ---
63
def _bytes_to_json(value):
    """``json`` fallback encoder: represent raw bytes as base64 ASCII text."""
    import base64

    if isinstance(value, (bytes, bytearray)):
        return base64.b64encode(value).decode("ascii")
    # Mirror the stock encoder's error for anything else we cannot handle.
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


def save_labeled_data(label, state):
    """Record ``label`` for the current sample and advance to the next one.

    The current sample is read from ``state["sample"]``.  Its first ``"rgb"``
    image is converted to raw RGB bytes and appended, with the sample's
    metadata and the label, to the module-level ``labeled_data`` list, which
    is then written to ``LABELED_DATA_FILE``.  Afterwards the next sample is
    fetched with ``get_next_sample()`` and stored back into ``state``.

    The in-memory entries keep the image as raw ``bytes``.  In the JSON file
    the image is stored as a base64 string, because ``json`` cannot serialize
    ``bytes``; readers of the file must base64-decode the ``"image"`` field.

    Args:
        label: Label to attach to the current sample (e.g. ``"cool"``).
        state: Mutable mapping holding the current sample under ``"sample"``.

    Returns:
        A ``(message, image, table)`` tuple for the UI: a status message
        (empty on success), the next image or ``None``, and a ``DataFrame``
        with the next sample's metadata (empty when there is none).
    """
    global labeled_data
    sample = state["sample"]
    if sample is None:
        return "No image to label", None, DataFrame()

    image = sample["rgb"][0]  # Get the PIL Image object

    labeled_data.append({
        "image": image.convert("RGB").tobytes(),
        "metadata": sample["metadata"],
        "label": label,
    })

    # Serialize first: if encoding fails, the existing file is left untouched
    # instead of being truncated by open(..., "w").
    serialized = json.dumps(labeled_data, default=_bytes_to_json)
    with open(LABELED_DATA_FILE, "w") as f:
        f.write(serialized)

    new_sample = get_next_sample()
    state["sample"] = new_sample
    if new_sample is None:
        return "Dataset exhausted.", None, DataFrame()

    new_image = new_sample["rgb"][0]
    new_metadata = new_sample["metadata"]
    map_url = utils.get_google_map_link(new_sample, DATASET_SUBSET)
    new_metadata["map"] = f'<a href="{map_url}" target="_blank">🧭</a>'

    return "", new_image, DataFrame([new_metadata])
93
 
94
  # --- Gradio Interface ---
95
  # --- Labeling UI ---
def labeling_ui():
    """Build the labeling view: one image, its metadata, and two label buttons.

    A ``gr.State`` holds the sample currently on screen under ``"sample"``
    (plus the dataset subset under ``"subset"``).  Clicking "Cool" or
    "Not Cool" passes the matching label and the state to
    ``save_labeled_data``, whose three return values update the debug
    textbox, the gallery and the metadata table.
    """
    state = gr.State({"sample": None, "subset": DATASET_SUBSET})

    # NOTE(review): the widget nesting below is reconstructed; confirm it
    # against the intended layout.
    with gr.Row():
        with gr.Column():
            gallery = gr.Gallery(label="Satellite Image", interactive=False, columns=1, object_fit="scale-down")

            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")

            table = gr.DataFrame(datatype="html")

            def initialize_labeling_ui():
                """Fetch the first sample and derive its gallery/table values.

                Exactly one sample is consumed, so the image on screen is the
                same sample that is stored in the state and labeled by the
                first button click.
                """
                sample = get_next_sample()
                if sample is None:
                    return None, [], DataFrame()

                metadata = sample["metadata"]
                map_url = utils.get_google_map_link(sample, DATASET_SUBSET)
                metadata["map"] = f'<a href="{map_url}" target="_blank">🧭</a>'
                return sample, [sample["rgb"][0]], DataFrame([metadata])

            initial_sample, initial_image, initial_metadata = initialize_labeling_ui()
            gallery.value = initial_image
            table.value = initial_metadata
            state.value["sample"] = initial_sample

            # Handle button clicks: the hidden textbox carries the label value.
            cool_button.click(
                fn=lambda label, state: save_labeled_data(label, state),
                inputs=[gr.Textbox(visible=False, value="cool"), state],
                outputs=[gr.Textbox(label="Debug"), gallery, table]
            )
            not_cool_button.click(
                fn=lambda label, state: save_labeled_data(label, state),
                inputs=[gr.Textbox(visible=False, value="not cool"), state],
                outputs=[gr.Textbox(label="Debug"), gallery, table]
            )
133
 
 
 
 
 
 
 
 
 
 
 
 
134
  # --- Display UI ---
135
  def display_ui():
136
+
137
  def get_random_cool_images(n):
138
  cool_samples = [d for d in labeled_data if d["label"] == "cool"]
139
  return [Image.frombytes("RGB", (384, 384), s["image"]) for s in cool_samples] if len(cool_samples) >= n else []