"""TerraNomaly: a Gradio app for labeling satellite images as "cool" or "not cool".

Samples are streamed from an EarthView dataset subset. The "Labeling" tab shows
one sample at a time and appends the user's label to LABELED_DATA_FILE; the
"Display" tab shows a fresh sample next to randomly chosen "cool" examples.
"""

import base64
import io
import json
import os
import random
import threading

import gradio as gr
import earthview as ev
import numpy as np
from pandas import DataFrame
from PIL import Image

import utils

# --- Configuration ---
DATASET_SUBSET = "satellogic"
NUM_SAMPLES_TO_LABEL = 100  # You can adjust this (not referenced below).
LABELED_DATA_FILE = "labeled_data.json"
DISPLAY_N_COOL = 5  # How many "cool" examples to display alongside each new image.
# The seed to use when sampling the dataset for the demo.
# NOTE(review): this is passed to ev.load_dataset as `shards=[SAMPLE_SEED]`,
# i.e. it selects a shard rather than seeding an RNG — confirm intent.
SAMPLE_SEED = 10
BATCH_SIZE = 10  # Not referenced below.
# Value shown in the "map" column of the metadata table.
# NOTE(review): the original used an f-string with no placeholders here and
# `utils` is imported but unused, which suggests a map link (the table uses
# datatype="html") was meant to be rendered. Confirm the intended content.
MAP_ICON = "🧭"

# --- Load Dataset ---
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
data_iter = iter(dataset)
# Gradio can run handlers of different listeners (label buttons, refresh) on
# different worker threads; a shared iterator such as a generator must not be
# advanced concurrently, so every next(data_iter) goes through this lock.
_data_iter_lock = threading.Lock()
# Serializes updates of `labeled_data` and rewrites of LABELED_DATA_FILE.
_labeled_data_lock = threading.Lock()


# --- Load Labeled Data (if it exists) ---
def load_labeled_data():
    """Return the previously saved label records, or [] if there are none.

    A file that is not valid JSON (e.g. truncated by an interrupted write) is
    moved aside to ``LABELED_DATA_FILE + ".corrupt"`` so the next save does
    not silently overwrite it, and labeling starts from an empty list.
    """
    if not os.path.exists(LABELED_DATA_FILE):
        return []
    try:
        with open(LABELED_DATA_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as err:
        backup_path = LABELED_DATA_FILE + ".corrupt"
        print(f"Could not parse {LABELED_DATA_FILE} ({err}); moving it to {backup_path}.")
        os.replace(LABELED_DATA_FILE, backup_path)
        return []


labeled_data = load_labeled_data()


# --- Helpers ---
def _encode_image(image):
    """Return *image* as a base64-encoded PNG string (JSON-serializable)."""
    buffer = io.BytesIO()
    image.convert("RGB").save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("ascii")


def _decode_image(encoded):
    """Rebuild an RGB PIL image from a string produced by `_encode_image`."""
    return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")


def _metadata_row(sample):
    """Return a copy of the sample's metadata with the "map" column added.

    A copy is made so the sample's own metadata, which is what gets saved
    alongside a label, is not modified.
    """
    return {**sample["metadata"], "map": MAP_ICON}


# --- Get Next Sample for Labeling ---
def get_next_sample():
    """Return the next dataset item converted by `ev.item_to_images`.

    Returns None (after printing a message) once the dataset is exhausted.
    """
    with _data_iter_lock:
        try:
            item = next(data_iter)
        except StopIteration:
            print("No more samples in the dataset.")
            return None
    return ev.item_to_images(DATASET_SUBSET, item)


def get_images(batch_size, state):
    """Fetch up to *batch_size* new samples.

    Returns ``(images, metadata_df)``: the first "rgb" image of each sample and
    a DataFrame with one metadata row (plus the "map" column) per image. Fewer
    than *batch_size* images are returned if the dataset runs out. *state* is
    accepted for interface compatibility and is not used.
    """
    images = []
    rows = []
    for _ in range(batch_size):
        sample = get_next_sample()
        if sample is None:
            break
        images.append(sample["rgb"][0])
        rows.append(_metadata_row(sample))
    return images, DataFrame(rows)


# --- Save Labeled Data ---
def _write_labeled_data(records):
    """Atomically replace LABELED_DATA_FILE with *records* serialized as JSON.

    The data is dumped to a temporary file and then renamed over the target,
    so a failed or interrupted dump never leaves a truncated file behind.
    """
    tmp_path = LABELED_DATA_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(records, f)
    os.replace(tmp_path, LABELED_DATA_FILE)


def save_labeled_data(label, state):
    """Record *label* for ``state["sample"]``, persist it, and load the next sample.

    Mutates ``state["sample"]`` to the next sample (None once exhausted).
    Returns ``(message, gallery_value, metadata_df)`` for the labeling UI,
    where ``gallery_value`` is a one-image list or None.
    """
    sample = state["sample"]
    if sample is None:
        return "No image to label", None, DataFrame()

    record = {
        # Raw image bytes are not JSON-serializable; store a base64 PNG instead.
        "image": _encode_image(sample["rgb"][0]),
        "metadata": sample["metadata"],
        "label": label,
    }
    with _labeled_data_lock:
        # Write before appending so that if serialization fails, the
        # in-memory list and the file stay consistent with each other.
        _write_labeled_data(labeled_data + [record])
        labeled_data.append(record)

    new_sample = get_next_sample()
    state["sample"] = new_sample
    if new_sample is None:
        return "Dataset exhausted.", None, DataFrame()
    # gr.Gallery expects a list of images, not a single image.
    return "", [new_sample["rgb"][0]], DataFrame([_metadata_row(new_sample)])


# --- Gradio Interface ---
# --- Labeling UI ---
def labeling_ui():
    """Build the "Labeling" tab's components inside the current Blocks context.

    The first sample is fetched once at build time; every session starts from
    (a copy of) that sample.
    """
    # Fetch a single sample and derive both the displayed image and the stored
    # state from it, so the image shown is the one that gets labeled.
    initial_sample = get_next_sample()
    if initial_sample is None:
        initial_gallery, initial_table = None, DataFrame()
    else:
        initial_gallery = [initial_sample["rgb"][0]]
        initial_table = DataFrame([_metadata_row(initial_sample)])

    state = gr.State({"sample": initial_sample, "subset": DATASET_SUBSET})
    with gr.Row():
        with gr.Column():
            gallery = gr.Gallery(
                value=initial_gallery,
                label="Satellite Image",
                interactive=False,
                columns=1,
                object_fit="scale-down",
            )
            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")
            table = gr.DataFrame(value=initial_table, datatype="html")
    status = gr.Textbox(label="Debug")

    def make_label_handler(label):
        """Return a click handler that saves *label* and returns the updated state."""
        def handler(current_state):
            return (*save_labeled_data(label, current_state), current_state)

        handler.__name__ = "label_" + label.replace(" ", "_")
        return handler

    outputs = [status, gallery, table, state]
    cool_button.click(fn=make_label_handler("cool"), inputs=[state], outputs=outputs)
    not_cool_button.click(fn=make_label_handler("not cool"), inputs=[state], outputs=outputs)


# --- Display UI ---
def display_ui():
    """Build the "Display" tab: a new image, its bounds, and random cool examples."""

    def get_random_cool_images(n):
        """Return up to *n* randomly chosen images that were labeled "cool"."""
        cool_samples = [d for d in labeled_data if d["label"] == "cool"]
        chosen = random.sample(cool_samples, min(n, len(cool_samples)))
        return [_decode_image(d["image"]) for d in chosen]

    def refresh_display():
        """Return ``(status, new_image, bounds_json, cool_images)`` for the outputs."""
        sample = get_next_sample()
        if sample is None:
            return "No more samples", None, "", []
        bounds = json.dumps(sample["metadata"]["bounds"])
        return "", sample["rgb"][0], bounds, get_random_cool_images(DISPLAY_N_COOL)

    initial_status, initial_image, initial_bounds, initial_gallery = refresh_display()

    with gr.Row():
        new_image_component = gr.Image(value=initial_image, label="New Image", type="pil")
        metadata_display = gr.Textbox(value=initial_bounds, label="Metadata (Bounds)")
    with gr.Row():
        cool_images_gallery = gr.Gallery(
            value=initial_gallery, label="Cool Examples", columns=DISPLAY_N_COOL
        )
    refresh_button = gr.Button("Refresh")
    status = gr.Textbox(value=initial_status, label="Debug")

    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[status, new_image_component, metadata_display, cool_images_gallery],
    )


# --- Main Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# TerraNomaly")
    with gr.Tabs():
        with gr.TabItem("Labeling"):
            labeling_ui()
        with gr.TabItem("Display"):
            display_ui()

if __name__ == "__main__":
    demo.launch(debug=True)