Spaces:
Sleeping
Sleeping
| from datasets import load_dataset as _load_dataset | |
| from os import environ | |
| from PIL import Image | |
| import numpy as np | |
| import json | |
| from pyarrow.parquet import ParquetFile | |
| from pyarrow import Table as pa_Table | |
| from datasets import Dataset | |
# Hugging Face Hub repository that hosts every subset.
DATASET = "satellogic/EarthView"

# Per-subset layout description.
#   "shards": number of parquet shards in the subset (required)
#   "config": optional config name, used instead of the subset name
#   "path":   optional directory, used instead of the subset name
sets = {
    "satellogic": {"shards": 7863},
    "sentinel_1": {"shards": 1763},
    "neon": {"config": "default", "shards": 607, "path": "data"},
    "sentinel_2": {"shards": 19997},
}
def get_subsets():
    """Return a view of the known subset names (the keys of `sets`)."""
    subset_names = sets.keys()
    return subset_names
def get_nshards(subset):
    """Return the number of parquet shards registered for *subset*.

    Raises KeyError when *subset* is not a key of `sets`.
    """
    entry = sets[subset]
    return entry["shards"]
def get_path(subset):
    """Return the directory holding *subset*'s files.

    Uses the subset's "path" entry when present, otherwise the subset name.
    Raises KeyError when *subset* is not a key of `sets`.
    """
    entry = sets[subset]
    if "path" in entry:
        return entry["path"]
    return subset
def get_config(subset):
    """Return the config name used to load *subset*.

    Uses the subset's "config" entry when present, otherwise the subset name.
    Raises KeyError when *subset* is not a key of `sets`.
    """
    entry = sets[subset]
    if "config" in entry:
        return entry["config"]
    return subset
def load_dataset(subset, dataset="satellogic/EarthView", split="train", shards = None, streaming=True, **kwargs):
    """
    Load one subset through `datasets.load_dataset`.

    subset: The name of the subset, a key of `sets`
    dataset: The Hub repository to load from
    split: The split to load; also used to build the shard file names
    shards: Optional iterable of shard indices. When given, only the matching
            parquet files are requested; when None, no file list is passed
    streaming: Forwarded to `datasets.load_dataset`
    kwargs: Any extra keyword arguments are forwarded unchanged
    returns whatever `datasets.load_dataset` returns
    """
    config = get_config(subset)
    nshards = get_nshards(subset)
    path = get_path(subset)

    data_files = None
    if shards is not None:
        if subset == "sentinel_2":
            # sentinel_2 files are grouped ten per "sentinel_2-<n>" directory,
            # so the shard index is split into directory number and position.
            shard_files = [
                f"{path}/sentinel_2-{shard//10}/{split}-{shard % 10:05d}-of-00010.parquet"
                for shard in shards
            ]
        else:
            shard_files = [
                f"{path}/{split}-{shard:05d}-of-{nshards:05d}.parquet"
                for shard in shards
            ]
        data_files = {split: shard_files}

    # The access token is optional: None is passed when HF_TOKEN is not set.
    hf_token = environ.get("HF_TOKEN", None)
    return _load_dataset(
        path=dataset,
        name=config,
        save_infos=True,
        split=split,
        data_files=data_files,
        streaming=streaming,
        token=hf_token,
        **kwargs)
def load_parquet(subset_or_filename, batch_size=100):
    """
    Load a local parquet file into a `datasets.Dataset`.

    subset_or_filename: Either a subset name (a key of `sets`), in which case
                        "dataset/<subset>/sample.parquet" is read, or the path
                        of the parquet file to read
    batch_size: Maximum number of rows per record batch while reading
    returns a Dataset holding every row of the file
    """
    if subset_or_filename in get_subsets():
        filename = f"dataset/{subset_or_filename}/sample.parquet"
    else:
        filename = subset_or_filename

    # The context manager closes the file handle once the table is built;
    # the batch iterator is fully consumed by from_batches inside the block.
    with ParquetFile(filename) as pqfile:
        batches = pqfile.iter_batches(batch_size=batch_size)
        # Passing the schema explicitly keeps this working for files that
        # yield no record batches at all.
        table = pa_Table.from_batches(batches, schema=pqfile.schema_arrow)

    return Dataset(table)
def _satellogic_to_images(item):
    """Convert the "rgb" and "1m" arrays of a satellogic item; return the image count."""
    # Arrays are channel-first; PIL expects channel-last.
    item["rgb"] = [Image.fromarray(rgb.transpose(1, 2, 0)) for rgb in item["rgb"]]
    # Only the first band of each "1m" array is rendered.
    item["1m"] = [Image.fromarray(image[0, :, :]) for image in item["1m"]]
    return len(item["1m"])


def _sentinel_1_to_images(item):
    """Convert the "10m" arrays of a sentinel_1 item; return the image count."""
    # Mapping of V and H to RGB. May not be correct
    # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
    i10m = item["10m"]
    # A third channel is synthesised from the ratio of the first two; the
    # +0.01 avoids a division by zero.
    # NOTE(review): ratios >= 1 exceed the uint8 range after the *256 scaling,
    # so the cast below is not well defined for them - confirm whether
    # clipping to 0..255 is what is wanted here.
    ratio = i10m[:, 0, :, :] / (i10m[:, 1, :, :] + 0.01) * 256
    i10m = np.concatenate((i10m, np.expand_dims(ratio, 1).astype("uint8")), 1)
    item["10m"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in i10m]
    return len(item["10m"])


def _sentinel_2_to_images(item):
    """Convert the image channels and JSON metadata of a sentinel_2 item; return the image count."""
    count = 0
    # "40m" is deliberately left out and stays a plain array.
    for channel in ("10m", "20m", "rgb", "scl"):
        # Move the channel axis last, as PIL expects.
        data = item[channel].transpose(0, 2, 3, 1)
        count = len(data)
        if channel == "20m":
            # Keep three of the bands so the result can be shown as RGB.
            data = data[:, :, :, [0, 2, 4]]
        mode = "L" if channel in ("10m", "scl") else "RGB"
        item[channel] = [Image.fromarray(frame.squeeze(), mode=mode) for frame in data]

    # These metadata fields hold one JSON document per entry.
    metadata = item["metadata"]
    for field in ("solarAngles", "tileGeometry", "viewIncidenceAngles"):
        metadata[field] = [json.loads(s) for s in metadata[field]]
    return count


def _neon_to_images(item):
    """Convert the "rgb", "chm" and "1m" arrays of a neon item and fix its metadata; return the image count."""
    item["rgb"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in item["rgb"]]
    item["chm"] = [Image.fromarray(image[0]) for image in item["chm"]]

    # The next is a very arbitrary conversion from the 369 hyperspectral data to RGB
    # It just averages each 1/3 of the bands and assigns it to a channel
    item["1m"] = [
        Image.fromarray(
            np.stack(
                (
                    np.average(image[:124], 0),
                    np.average(image[124:247], 0),
                    np.average(image[247:], 0),
                ),
                2,
            ).astype("uint8")
        )
        for image in item["1m"]
    ]

    metadata = item["metadata"]
    bounds = metadata["bounds"]
    # swap pairs: [a, b, c, d] becomes [b, a, d, c]
    metadata["bounds"] = [
        bounds[start + 1 - offset]
        for start in range(0, len(bounds), 2)
        for offset in range(2)
    ]
    # fix CRS
    metadata["epsg"] = "EPSG:4326"
    return len(item["rgb"])


def item_to_images(subset, item):
    """
    Converts the images within an item (arrays), as retrieved from the dataset to proper PIL.Image

    subset: The name of the Subset, one of "satellogic", "sentinel_1", "sentinel_2", "neon"
    item: The item as retrieved from the subset. Its "metadata" entry may be a
          dict or a JSON string; the item passed in is not modified
    returns a new item, with arrays converted to PIL.Image and the number of
            images stored under metadata["count"]
    raises ValueError if subset is not one of the names above
    """
    converters = {
        "satellogic": _satellogic_to_images,
        "sentinel_1": _sentinel_1_to_images,
        "sentinel_2": _sentinel_2_to_images,
        "neon": _neon_to_images,
    }
    # Fail early with a clear message instead of an unbound-variable error
    # after all the array conversion work has been done.
    if subset not in converters:
        raise ValueError(
            f"Unknown subset {subset!r}; expected one of {sorted(converters)}"
        )

    metadata = item["metadata"]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    else:
        # Shallow copy: the converters replace top-level metadata entries, and
        # the caller's dict must stay intact so the item can be converted again.
        metadata = dict(metadata)

    converted = {
        key: np.asarray(value).astype("uint8")
        for key, value in item.items()
        if key != "metadata"
    }
    converted["metadata"] = metadata

    metadata["count"] = converters[subset](converted)
    return converted