# TerraNomaly / app.py
# Author: dropbop
# Commit d56788b (verified): "Fix labeling_ui and save_labeled_data functions"
# (Hugging Face file viewer: raw · history · blame · 5.49 kB)
import base64
import json
import os
import random

import gradio as gr
import numpy as np
from PIL import Image

import earthview as ev
# --- Configuration ---
DATASET_SUBSET = "satellogic"  # EarthView subset name; passed to ev.load_dataset and ev.item_to_images.
NUM_SAMPLES_TO_LABEL = 100  # NOTE(review): not referenced by the code shown here, so labeling is not capped at this value.
LABELED_DATA_FILE = "labeled_data.json"  # Labels are read from this path at startup and rewritten in full after every label.
DISPLAY_N_COOL = 5  # How many "cool" examples the Display tab's gallery shows next to each new image.
# Passed to ev.load_dataset as `shards=[SAMPLE_SEED]`, so it presumably selects
# which shard is loaded. NOTE(review): despite the name, it does not seed
# `random` — confirm against earthview's `shards` semantics.
SAMPLE_SEED = 10
# --- Load Dataset ---
# Loaded once at import time. A single iterator is shared by the Labeling and
# Display tabs (both call next(data_iter)), so each sample is served to only one
# of them.
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
data_iter = iter(dataset)
# --- Load Labeled Data (if it exists) ---
def load_labeled_data(path=None):
    """Return the label records stored in the labels JSON file.

    Args:
        path: File to read. Defaults to ``LABELED_DATA_FILE`` (resolved at call
            time), so existing zero-argument callers are unaffected.

    Returns:
        The decoded JSON content, or ``[]`` when the file does not exist yet.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON. This is
            deliberately not swallowed: returning ``[]`` here would let the next
            save rewrite the file from an empty list and lose its contents.
    """
    if path is None:
        path = LABELED_DATA_FILE
    try:
        # EAFP: no race window between an exists() check and open().
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
# In-memory list of label records, loaded once at startup. save_labeled_data
# appends to it and rewrites the file from it; display_ui reads its "cool" entries.
labeled_data = load_labeled_data()
# --- Get Next Sample for Labeling ---
def get_next_sample():
    """Pull the next sample from the shared dataset iterator.

    Returns:
        ``(image, metadata, count)``. Here ``image`` is the first entry of the
        sample's ``"rgb"`` images, ``metadata`` is the sample's metadata
        (callers read its ``"bounds"`` key), and ``count`` is the number of
        labels recorded so far. When the iterator is exhausted the result is
        ``(None, None, count)``. The count is still an int, so callers can
        render "Labeled N samples." instead of "Labeled None samples."
    """
    # Only next() is expected to signal exhaustion; keep the try that narrow so
    # a StopIteration from elsewhere is not mistaken for "no more samples".
    try:
        raw_sample = next(data_iter)
    except StopIteration:
        return None, None, len(labeled_data)
    sample = ev.item_to_images(DATASET_SUBSET, raw_sample)
    # sample["rgb"] may hold several images; the labeling UI shows only the first.
    image = sample["rgb"][0]
    return image, sample["metadata"], len(labeled_data)
# --- Save Labeled Data ---
def _json_default(obj):
    """``json.dumps`` fallback hook: encode ``bytes`` as base64 ASCII text.

    The ``json`` module cannot serialize ``bytes`` (the raw pixel buffers stored
    per label), so without this hook every save raised ``TypeError``.
    """
    if isinstance(obj, (bytes, bytearray)):
        return base64.b64encode(obj).decode("ascii")
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_labeled_data(image, metadata, label):
    """Record ``label`` for the displayed sample, persist all labels, and advance.

    Args:
        image: The PIL image currently displayed, or ``None`` if none is shown.
        metadata: The metadata text shown next to the image (the bounds string).
        label: The label to record, e.g. ``"cool"`` or ``"not cool"``.

    Returns:
        A 4-tuple for the click outputs: ``(debug text, next image, next bounds
        text, label-count text)``.
    """
    if image is not None:
        rgb = image.convert("RGB")
        labeled_data.append({
            # Raw RGB pixels; kept as bytes in memory and written to disk as
            # base64 text by _json_default.
            "image": rgb.tobytes(),
            # [width, height] so the buffer can be decoded without assuming a size.
            "size": list(rgb.size),
            "metadata": metadata,
            "label": label,
        })

        # Serialize fully before touching the file, then swap it in with a
        # single rename. A serialization error or a crash mid-write can no
        # longer leave a truncated labels file behind.
        payload = json.dumps(labeled_data, default=_json_default)
        tmp_path = LABELED_DATA_FILE + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, LABELED_DATA_FILE)
    # With no image displayed there is nothing to attach the label to, so
    # nothing is recorded; the UI still advances to the next sample.

    next_image, next_metadata, _ = get_next_sample()
    count_text = f"Labeled {len(labeled_data)} samples."
    if next_image is None:
        # A plain None clears the Image output in every Gradio version
        # (gr.Image.update was removed in Gradio 4).
        return "No more samples", None, "", count_text
    return "", next_image, str(next_metadata["bounds"]), count_text
# --- Gradio Interface ---
# --- Labeling UI ---
def labeling_ui():
    """Build the "Labeling" tab inside the currently open ``gr.Blocks`` context.

    Shows one dataset sample (image plus bounds text) with "Cool" / "Not Cool"
    buttons. Each click passes the displayed image and bounds text to
    ``save_labeled_data``. Its 4-tuple return refreshes the debug box, the
    image, the bounds text and the label counter with the next sample.
    """

    def initialize_labeling_ui():
        """Return ``(image, bounds text, count text)`` for the first sample."""
        image, metadata, count = get_next_sample()
        if image is not None:
            return image, str(metadata["bounds"]), f"Labeled {count} samples."
        return None, "", "No samples loaded."

    # Fetch the first sample before building the components so it can be passed
    # as each component's constructor ``value=``. That is the supported way to
    # set an initial value; assigning ``component.value`` afterwards bypasses
    # the component's own value handling.
    initial_image, initial_metadata, initial_count = initialize_labeling_ui()

    with gr.Row():
        with gr.Column():
            image_component = gr.Image(label="Satellite Image", type="pil", value=initial_image)
            metadata_text = gr.Textbox(label="Metadata (Bounds)", value=initial_metadata)
            label_count_text = gr.Textbox(label="Label Count", value=initial_count)
            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")
            # One status box shared by both buttons. Creating it inside each
            # .click() call rendered two separate "Debug" boxes.
            debug_text = gr.Textbox(label="Debug")

    inputs = [image_component, metadata_text]
    outputs = [debug_text, image_component, metadata_text, label_count_text]
    cool_button.click(
        fn=lambda image, metadata: save_labeled_data(image, metadata, "cool"),
        inputs=inputs,
        outputs=outputs,
    )
    not_cool_button.click(
        fn=lambda image, metadata: save_labeled_data(image, metadata, "not cool"),
        inputs=inputs,
        outputs=outputs,
    )
# --- Display UI ---
def display_ui():
    """Build the "Display" tab inside the currently open ``gr.Blocks`` context.

    Shows the next dataset sample with its bounds, next to a gallery of up to
    ``DISPLAY_N_COOL`` randomly chosen "cool" labels. Samples come from the
    same ``data_iter`` as the Labeling tab, so a sample shown here is not
    offered for labeling.
    """

    def decode_record_image(record):
        """Rebuild a PIL RGB image from a stored label record.

        ``record["image"]`` is raw bytes for records created in this session,
        or base64 text for records reloaded from a JSON file that stored it
        encoded.
        """
        data = record["image"]
        if isinstance(data, str):
            data = base64.b64decode(data)
        # Records without a stored "size" fall back to the 384x384 this tab
        # previously assumed for every image.
        width, height = record.get("size", (384, 384))
        return Image.frombytes("RGB", (width, height), data)

    def get_random_cool_images(n):
        """Return up to ``n`` randomly chosen images labeled "cool"."""
        # Skip records saved without an image; they cannot be decoded.
        cool_samples = [
            d for d in labeled_data if d["label"] == "cool" and d.get("image")
        ]
        selected = random.sample(cool_samples, min(n, len(cool_samples)))
        return [decode_record_image(s) for s in selected]

    def get_new_unlabeled_image():
        """Return ``(image, bounds text)`` for the next sample, or ``(None, None)``."""
        try:
            raw_sample = next(data_iter)
        except StopIteration:
            return None, None
        sample = ev.item_to_images(DATASET_SUBSET, raw_sample)
        return sample["rgb"][0], str(sample["metadata"]["bounds"])

    def refresh_display():
        """Return values for ``(debug, new image, bounds text, gallery)``."""
        new_image, new_metadata = get_new_unlabeled_image()
        if new_image is None:
            # Plain values act as updates in every Gradio version
            # (gr.Image.update / gr.Gallery.update were removed in Gradio 4).
            return "No more samples", None, "", []
        return "", new_image, new_metadata, get_random_cool_images(DISPLAY_N_COOL)

    # Compute the first view before building the components and pass it as each
    # constructor's ``value=``. Calling ``component.update(...)`` after
    # construction does not change what is rendered.
    initial_debug, initial_image, initial_bounds, initial_gallery = refresh_display()

    with gr.Row():
        new_image_component = gr.Image(label="New Image", type="pil", value=initial_image)
        metadata_display = gr.Textbox(label="Metadata (Bounds)", value=initial_bounds)
    with gr.Row():
        cool_images_gallery = gr.Gallery(
            label="Cool Examples",
            value=initial_gallery,
            columns=DISPLAY_N_COOL,  # Set grid layout here
        )
    with gr.Row():
        refresh_button = gr.Button("Refresh")
    debug_text = gr.Textbox(label="Debug", value=initial_debug)

    # metadata_display is included so the bounds text follows the image on refresh.
    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[debug_text, new_image_component, metadata_display, cool_images_gallery],
    )
# --- Main Interface ---
# Assemble both tabs into a single Blocks app, then start the server.
demo = gr.Blocks()
with demo:
    gr.Markdown("# TerraNomaly")
    with gr.Tabs():
        # Each entry is (tab title, function that builds the tab's contents).
        for tab_title, build_tab in (("Labeling", labeling_ui), ("Display", display_ui)):
            with gr.TabItem(tab_title):
                build_tab()

demo.launch(debug=True)