Spaces:
Build error
Build error
File size: 14,137 Bytes
a560c26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 |
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input pipeline for TFDS datasets."""
import functools
import os
from typing import Dict, List, Tuple
from clu import deterministic_data
from clu import preprocess_spec
import jax
import jax.numpy as jnp
import ml_collections
import sunds
import tensorflow as tf
import tensorflow_datasets as tfds
from invariant_slot_attention.lib import preprocessing
Array = jnp.ndarray
PRNGKey = Array

# --- CLEVR with masks -------------------------------------------------------
# Each record holds up to _CLEVR_MAX_OBJECTS entities; byte features are stored
# as one string per element and decoded to uint8 by _decode_clevr_with_masks.
_CLEVR_MAX_OBJECTS = 11
_CLEVR_IMAGE_SIZE = [240, 320]
PATH_CLEVR_WITH_MASKS = "gs://multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords"
FEATURES_CLEVR_WITH_MASKS = {
    "image": tf.io.FixedLenFeature(_CLEVR_IMAGE_SIZE + [3], tf.string),
    "mask": tf.io.FixedLenFeature(
        [_CLEVR_MAX_OBJECTS] + _CLEVR_IMAGE_SIZE + [1], tf.string),
    "x": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.float32),
    "y": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.float32),
    "z": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.float32),
    "pixel_coords": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS, 3], tf.float32),
    "rotation": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.float32),
    "size": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.string),
    "material": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.string),
    "shape": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.string),
    "color": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.string),
    "visibility": tf.io.FixedLenFeature([_CLEVR_MAX_OBJECTS], tf.float32),
}

# --- Tetrominoes ------------------------------------------------------------
_TETROMINOES_MAX_OBJECTS = 4
_TETROMINOES_IMAGE_SIZE = [35, 35]
PATH_TETROMINOES = "gs://multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords"
FEATURES_TETROMINOES = {
    "image": tf.io.FixedLenFeature(_TETROMINOES_IMAGE_SIZE + [3], tf.string),
    "mask": tf.io.FixedLenFeature(
        [_TETROMINOES_MAX_OBJECTS] + _TETROMINOES_IMAGE_SIZE + [1], tf.string),
    "x": tf.io.FixedLenFeature([_TETROMINOES_MAX_OBJECTS], tf.float32),
    "y": tf.io.FixedLenFeature([_TETROMINOES_MAX_OBJECTS], tf.float32),
    "shape": tf.io.FixedLenFeature([_TETROMINOES_MAX_OBJECTS], tf.float32),
    "color": tf.io.FixedLenFeature([_TETROMINOES_MAX_OBJECTS, 3], tf.float32),
    "visibility": tf.io.FixedLenFeature([_TETROMINOES_MAX_OBJECTS], tf.float32),
}

# --- Objects Room -----------------------------------------------------------
_OBJECTS_ROOM_MAX_OBJECTS = 7
_OBJECTS_ROOM_IMAGE_SIZE = [64, 64]
PATH_OBJECTS_ROOM = "gs://multi-object-datasets/objects_room/objects_room_train.tfrecords"
FEATURES_OBJECTS_ROOM = {
    "image": tf.io.FixedLenFeature(_OBJECTS_ROOM_IMAGE_SIZE + [3], tf.string),
    "mask": tf.io.FixedLenFeature(
        [_OBJECTS_ROOM_MAX_OBJECTS] + _OBJECTS_ROOM_IMAGE_SIZE + [1],
        tf.string),
}

# --- Waymo Open -------------------------------------------------------------
# PATH_WAYMO_OPEN is a relative path (no scheme / leading slash).
_WAYMO_IMAGE_SIZE = [128, 192]
PATH_WAYMO_OPEN = "datasets/waymo_v_1_4_0_images/tfrecords"
FEATURES_WAYMO_OPEN = {
    "image": tf.io.FixedLenFeature(_WAYMO_IMAGE_SIZE + [3], tf.string),
    "segmentations": tf.io.FixedLenFeature(_WAYMO_IMAGE_SIZE, tf.string),
    "depth": tf.io.FixedLenFeature(_WAYMO_IMAGE_SIZE, tf.float32),
    "num_objects": tf.io.FixedLenFeature([1], tf.int64),
    "has_mask": tf.io.FixedLenFeature([1], tf.int64),
    "camera": tf.io.FixedLenFeature([1], tf.int64),
}
def _decode_tetrominoes(example_proto):
  """Parses one serialized Tetrominoes record.

  The byte-string features "mask" and "image" are decoded to uint8 and the
  trailing length-1 axis appended by `decode_raw` is squeezed away. All other
  features in FEATURES_TETROMINOES are returned exactly as parsed.

  Args:
    example_proto: A scalar string tensor holding a serialized tf.train.Example.

  Returns:
    The parsed feature dictionary with "mask" and "image" decoded.
  """
  parsed = tf.io.parse_single_example(example_proto, FEATURES_TETROMINOES)
  decoded = {
      key: tf.squeeze(tf.io.decode_raw(parsed[key], tf.uint8), axis=-1)
      for key in ("mask", "image")
  }
  parsed.update(decoded)
  return parsed
def _decode_objects_room(example_proto):
  """Parses one serialized Objects Room record.

  Both features of FEATURES_OBJECTS_ROOM ("mask" and "image") are byte
  strings. Each is decoded to uint8 with its trailing length-1 axis squeezed.

  Args:
    example_proto: A scalar string tensor holding a serialized tf.train.Example.

  Returns:
    Dictionary with the decoded "mask" and "image" tensors.
  """
  record = tf.io.parse_single_example(example_proto, FEATURES_OBJECTS_ROOM)
  for feature_name in ("mask", "image"):
    raw = tf.io.decode_raw(record[feature_name], tf.uint8)
    record[feature_name] = tf.squeeze(raw, axis=-1)
  return record
def _decode_clevr_with_masks(example_proto):
  """Parses one serialized CLEVR-with-masks record.

  The byte-string features ("mask", "image", "color", "material", "shape",
  "size") are decoded to uint8 with their trailing length-1 axis squeezed.
  The float features ("x", "y", "z", "pixel_coords", "rotation",
  "visibility") are returned exactly as parsed.

  Args:
    example_proto: A scalar string tensor holding a serialized tf.train.Example.

  Returns:
    The parsed feature dictionary with all byte features decoded.
  """

  def _as_uint8(raw_bytes):
    return tf.squeeze(tf.io.decode_raw(raw_bytes, tf.uint8), axis=-1)

  parsed = tf.io.parse_single_example(
      example_proto, FEATURES_CLEVR_WITH_MASKS)
  for name in ("mask", "image", "color", "material", "shape", "size"):
    parsed[name] = _as_uint8(parsed[name])
  return parsed
def _decode_waymo_open(example_proto):
  """Unserializes a serialized tf.train.Example sample.

  "image" and "segmentations" are decoded from bytes to uint8 with the
  trailing length-1 axis from `decode_raw` squeezed. "segmentations" and
  "depth" then get a trailing channel axis of size 1 appended. Per
  FEATURES_WAYMO_OPEN, both therefore end up as [128, 192, 1].

  Args:
    example_proto: A scalar string tensor holding a serialized tf.train.Example.

  Returns:
    The parsed feature dictionary with the decoded/reshaped tensors.
  """
  example = tf.io.parse_single_example(example_proto, FEATURES_WAYMO_OPEN)
  image, segmentations = (
      tf.squeeze(tf.io.decode_raw(example[key], tf.uint8), axis=-1)
      for key in ("image", "segmentations"))
  example["image"] = image
  # The squeeze above followed by expand_dims is kept deliberately. Presumably
  # it pins the last axis to a static size of 1, which decode_raw alone would
  # leave unknown (TODO confirm).
  example["segmentations"] = tf.expand_dims(segmentations, axis=-1)
  example["depth"] = tf.expand_dims(example["depth"], axis=-1)
  return example
def _preprocess_minimal(example):
  """Keeps the image and collapses per-object masks into one segmentation map.

  For every pixel, the segmentation value is the index along axis 0 of
  `example["mask"]` at which the mask value is largest, cast to uint8.

  Args:
    example: Decoded example containing at least "image" and "mask".

  Returns:
    Dictionary with exactly the keys "image" and "segmentations".
  """
  segmentations = tf.cast(tf.argmax(example["mask"], axis=0), tf.uint8)
  return {"image": example["image"], "segmentations": segmentations}
def _sunds_create_task():
  """Create a sunds task to return images and instance segmentation.

  Returns:
    A `sunds.tasks.Nerf` task in IMAGE yield mode. It requests instance
    images, camera extrinsics, the frame pose, and the example name.
  """
  camera_specs = {
      "depth_image": False,  # Not available in the dataset.
      "category_image": False,  # Not available in the dataset.
      "instance_image": True,
      "extrinsics": True,
  }
  return sunds.tasks.Nerf(
      yield_mode=sunds.tasks.YieldMode.IMAGE,
      additional_camera_specs=camera_specs,
      additional_frame_specs={"pose": True},
      add_name=True,
  )
def preprocess_example(features,
                       preprocess_strs):
  """Processes a single data example.

  The individual op strings are joined with "|" into one clu.preprocess_spec
  specification. That specification is parsed against every op registered in
  `preprocessing.all_ops()` and applied to `features`.

  Args:
    features: A dictionary containing the tensors of a single data example.
    preprocess_strs: List of strings, describing one preprocessing operation
      each, in clu.preprocess_spec format.

  Returns:
    Dictionary containing the preprocessed tensors of a single data example.
  """
  available_ops = preprocessing.all_ops()
  spec = "|".join(preprocess_strs)
  pipeline = preprocess_spec.parse(spec, available_ops)
  return pipeline(features)  # pytype: disable=bad-return-type # allow-recursive-types
def get_batch_dims(global_batch_size):
  """Gets the first two axis sizes for data batches.

  Args:
    global_batch_size: Integer, the global batch size (across all devices).

  Returns:
    List `[num_local_devices, per_device_batch_size]`.

  Raises:
    ValueError: If `global_batch_size` is not evenly divisible by the number of
      processes (hosts), or if the resulting per-host batch size is not evenly
      divisible by the number of local devices.
  """
  # `jax.process_count` is the supported name for the deprecated
  # `jax.host_count` alias; query it once instead of on every use.
  num_hosts = jax.process_count()
  num_local_devices = jax.local_device_count()
  if global_batch_size % num_hosts != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible with {num_hosts}.")
  per_host_batch_size = global_batch_size // num_hosts
  if per_host_batch_size % num_local_devices != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible with {num_hosts} hosts with a per host "
                     f"batch size of {per_host_batch_size} and "
                     f"{num_local_devices} local devices.")
  return [num_local_devices, per_host_batch_size // num_local_devices]
def _read_multi_object_records(path, decode_fn):
  """Reads a GZIP TFRecord file and maps it to {"image", "segmentations"}.

  Args:
    path: Path of the GZIP-compressed TFRecord file.
    decode_fn: Function that parses one serialized record into a feature dict
      containing at least "image" and "mask".

  Returns:
    A `tf.data.Dataset` of dicts produced by `_preprocess_minimal`.
  """
  ds = tf.data.TFRecordDataset(
      path, compression_type="GZIP", buffer_size=2 * (2**20))
  ds = ds.map(decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return ds.map(
      _preprocess_minimal, num_parallel_calls=tf.data.experimental.AUTOTUNE)


class _HeldOutSplitBuilder:
  """Minimal stand-in for the TFDS builder API over a single record stream.

  The first `num_validation` records form the "validation" split. The records
  after them form the "train" split, optionally truncated to `num_train`.
  """

  def __init__(self, ds, num_train, num_validation=100, truncate_train=False):
    """Initializes the builder.

    Args:
      ds: Dataset containing all records in file order.
      num_train: Number of training examples; asserted as the train split
        cardinality.
      num_validation: Number of leading records held out for validation.
      truncate_train: If True, keep only `num_train` records after the held-out
        ones; otherwise keep all remaining records.
    """
    self._ds = ds
    self._num_train = num_train
    self._num_validation = num_validation
    self._truncate_train = truncate_train

  def as_dataset(self, split, *unused_args, **unused_kwargs):
    """Simple function to conform to the builder api.

    Args:
      split: Either "train" or "validation".
      *unused_args: Ignored.
      **unused_kwargs: Ignored (e.g. shuffle_files / read_config / decoders).

    Returns:
      The requested split with its cardinality asserted.

    Raises:
      ValueError: For any other split name.
    """
    if split == "train":
      ds = self._ds.skip(self._num_validation)
      if self._truncate_train:
        ds = ds.take(self._num_train)
      return tf.data.experimental.assert_cardinality(self._num_train)(ds)
    elif split == "validation":
      ds = self._ds.take(self._num_validation)
      return tf.data.experimental.assert_cardinality(self._num_validation)(ds)
    raise ValueError("Invalid split.")


def _create_waymo_open_builder():
  """Creates a builder over the Waymo Open camera_1 TFRecord shards.

  Returns:
    An object whose `as_dataset(split)` returns the decoded "train" or
    "validation" dataset.
  """
  reader = functools.partial(
      tf.data.TFRecordDataset, compression_type="ZLIB", buffer_size=2 * (2**20))

  def load(pattern, shuffle_files):
    files = tf.data.Dataset.list_files(
        os.path.join(PATH_WAYMO_OPEN, pattern), shuffle=shuffle_files)
    ds = files.interleave(
        reader, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds.map(
        _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # NOTE(review): training shards keep list_files' default (unseeded) shuffle,
  # so the training file order is not reproducible across runs.
  train_dataset = load("training/camera_1/*tfrecords*", shuffle_files=True)
  # list_files shuffles by default with no seed. Evaluation reads shards in a
  # fixed order instead, so the eval split, including which examples land in
  # the final partial batch, is reproducible.
  eval_dataset = load("validation/camera_1/*tfrecords*", shuffle_files=False)
  # We need to set the dataset cardinality. We assume we have
  # the full dataset.
  train_dataset = train_dataset.apply(
      tf.data.experimental.assert_cardinality(158081))

  class WaymoOpenBuilder:
    """Builder exposing the prebuilt Waymo Open train/validation datasets."""

    def as_dataset(self, split, *unused_args, **unused_kwargs):
      if split == "train":
        return train_dataset
      elif split == "validation":
        return eval_dataset
      else:
        raise ValueError("Invalid split.")

  return WaymoOpenBuilder()


def _create_dataset_builder(config):
  """Returns a builder with `as_dataset(split=...)` for the configured dataset.

  Args:
    config: Configuration; `config.data.dataset_name` selects the dataset.

  Returns:
    A TFDS/sunds builder or a local object implementing `as_dataset`.

  Raises:
    ValueError: If `config.data.dataset_name` is not recognized.
  """
  dataset_name = config.data.dataset_name
  if dataset_name == "tetrominoes":
    # 100 validation examples; the next 512 examples are used for training.
    return _HeldOutSplitBuilder(
        _read_multi_object_records(PATH_TETROMINOES, _decode_tetrominoes),
        num_train=512, truncate_train=True)
  if dataset_name == "objects_room":
    # 100 validation examples; 1M - 100 training examples.
    return _HeldOutSplitBuilder(
        _read_multi_object_records(PATH_OBJECTS_ROOM, _decode_objects_room),
        num_train=999900)
  if dataset_name == "clevr_with_masks":
    # 100 validation examples; the remaining 99900 are used for training.
    return _HeldOutSplitBuilder(
        _read_multi_object_records(
            PATH_CLEVR_WITH_MASKS, _decode_clevr_with_masks),
        num_train=99900)
  if dataset_name == "waymo_open":
    return _create_waymo_open_builder()
  if dataset_name == "multishapenet_easy":
    builder = sunds.builder(
        name=config.get("tfds_name", "msn_easy"),
        data_dir=config.get(
            "data_dir", "gs://kubric-public/tfds"),
        try_gcs=True)
    builder.as_dataset = functools.partial(
        builder.as_dataset, task=_sunds_create_task())
    return builder
  if dataset_name == "tfds":
    return tfds.builder(
        config.data.tfds_name, data_dir=config.data.data_dir)
  raise ValueError("Please specify a valid dataset name.")


def _waymo_open_filter_fn(features):
  """Keeps examples whose segmentation contains more than one distinct value.

  Args:
    features: Preprocessed example containing `preprocessing.SEGMENTATIONS`.

  Returns:
    Boolean scalar tensor; False when only a single value (background) occurs.
  """
  unique_instances = tf.unique(
      tf.reshape(features[preprocessing.SEGMENTATIONS], (-1,)))[0]
  n_instances = tf.size(unique_instances, tf.int32)
  # n_instances == 1 means we only have the background.
  return 2 <= n_instances


def create_datasets(
    config,
    data_rng):
  """Create datasets for training and evaluation.

  For the same data_rng and config this will return the same datasets. The
  datasets only contain stateless operations. (Exception: the Waymo Open
  training shard order comes from an unseeded `list_files` shuffle.)

  Args:
    config: Configuration to use.
    data_rng: JAX PRNGKey for dataset pipeline.

  Returns:
    A tuple with the training dataset and the evaluation dataset.

  Raises:
    ValueError: If the dataset name is unknown or the batch size does not
      divide evenly across hosts/devices.
  """
  dataset_builder = _create_dataset_builder(config)
  batch_dims = get_batch_dims(config.batch_size)
  train_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_train)
  eval_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_eval)
  train_split_name = config.get("train_split", "train")
  eval_split_name = config.get("validation_split", "validation")

  train_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=train_split_name,
      rng=data_rng,
      preprocess_fn=train_preprocess_fn,
      cache=False,
      shuffle_buffer_size=config.data.shuffle_buffer_size,
      batch_dims=batch_dims,
      num_epochs=None,
      shuffle=True)

  is_waymo = config.data.dataset_name == "waymo_open"
  # Only the Waymo Open evaluation split is filtered for empty segmentation
  # masks (after preprocessing).
  eval_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=eval_split_name,
      rng=None,
      preprocess_fn=eval_preprocess_fn,
      filter_fn=_waymo_open_filter_fn if is_waymo else None,
      cache=False,
      batch_dims=batch_dims,
      num_epochs=1,
      shuffle=False,
      pad_up_to_batches=None)

  if is_waymo:
    # For the full dataset, 1872 eval examples survive the filter. The last
    # batch is not padded (pad_up_to_batches=None), hence floor division.
    # Plain integer math avoids a round trip through a JAX device array.
    eval_batch_size = batch_dims[0] * batch_dims[1]
    eval_num_batches = 1872 // (eval_batch_size * jax.process_count())
    eval_ds = eval_ds.apply(
        tf.data.experimental.assert_cardinality(eval_num_batches))
  return train_ds, eval_ds
|