Spaces:
Configuration error
Configuration error
Upload 5 files
Browse files- shape_utils.py +498 -0
- standard_fields.py +281 -0
- static_shape.py +90 -0
- string_int_label_map_pb2.py +123 -0
- tf_label_map.pbtxt +120 -0
shape_utils.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Utils used to manipulate tensor shapes."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
from __future__ import print_function
|
21 |
+
|
22 |
+
from six.moves import zip
|
23 |
+
import tensorflow as tf
|
24 |
+
|
25 |
+
import static_shape
|
26 |
+
|
27 |
+
|
28 |
+
get_dim_as_int = static_shape.get_dim_as_int
|
29 |
+
|
30 |
+
|
31 |
+
def _is_tensor(t):
  """Checks whether the input is a TensorFlow tensor-like object.

  Args:
    t: the object to inspect.

  Returns:
    True iff `t` is a tf.Tensor, tf.SparseTensor, or tf.Variable.
  """
  tensor_types = (tf.Tensor, tf.SparseTensor, tf.Variable)
  return isinstance(t, tensor_types)
|
41 |
+
|
42 |
+
|
43 |
+
def _set_dim_0(t, d0):
  """Statically records `d0` as the 0-th dimension of tensor `t`.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    d0: an integer indicating the 0-th dimension of the input tensor.

  Returns:
    the tensor t with the 0-th dimension set.
  """
  # Keep all trailing dimensions as-is; only pin the leading one.
  new_shape = [d0] + t.get_shape().as_list()[1:]
  t.set_shape(new_shape)
  return t
|
57 |
+
|
58 |
+
|
59 |
+
def pad_tensor(t, length):
  """Pads the input tensor with 0s along the first dimension up to the length.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: a tensor of shape [1] or an integer, indicating the first dimension
      of the input tensor t after padding, assuming length >= t.shape[0].

  Returns:
    padded_t: the padded tensor, whose first dimension is length. If the length
      is an integer, the first dimension of padded_t is set to length
      statically.
  """
  t_rank = tf.rank(t)
  t_shape = tf.shape(t)
  t_d0 = t_shape[0]
  # Number of rows of zeros to append along dimension 0.
  pad_d0 = tf.expand_dims(length - t_d0, 0)
  # For rank > 1 the zero block must match t's trailing dimensions; for rank 1
  # it is just the 1-D padding length.
  pad_shape = tf.cond(
      tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0),
      lambda: tf.expand_dims(length - t_d0, 0))
  padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0)
  if not _is_tensor(length):
    # Static length: also record it in the result's static shape.
    padded_t = _set_dim_0(padded_t, length)
  return padded_t
|
83 |
+
|
84 |
+
|
85 |
+
def clip_tensor(t, length):
  """Clips the input tensor along the first dimension up to the length.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: a tensor of shape [1] or an integer, indicating the first dimension
      of the input tensor t after clipping, assuming length <= t.shape[0].

  Returns:
    clipped_t: the clipped tensor, whose first dimension is length. If the
      length is an integer, the first dimension of clipped_t is set to length
      statically.
  """
  # Keep only the first `length` entries along dimension 0.
  result = tf.gather(t, tf.range(length))
  if _is_tensor(length):
    return result
  # Static length: also pin the result's static first dimension.
  return _set_dim_0(result, length)
|
102 |
+
|
103 |
+
|
104 |
+
def pad_or_clip_tensor(t, length):
  """Pad or clip the input tensor along the first dimension.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: a tensor of shape [1] or an integer, indicating the first dimension
      of the input tensor t after processing.

  Returns:
    processed_t: the processed tensor, whose first dimension is length. If the
      length is an integer, the first dimension of the processed tensor is set
      to length statically.
  """
  # Delegate to the N-d variant: only dimension 0 changes, the trailing
  # (static) dimensions are kept as-is.
  target_shape = [length] + t.shape.as_list()[1:]
  return pad_or_clip_nd(t, target_shape)
|
118 |
+
|
119 |
+
|
120 |
+
def pad_or_clip_nd(tensor, output_shape):
  """Pad or Clip given tensor to the output shape.

  Args:
    tensor: Input tensor to pad or clip.
    output_shape: A list of integers / scalar tensors (or None for dynamic dim)
      representing the size to pad or clip each dimension of the input tensor.

  Returns:
    Input tensor padded and clipped to the output shape.
  """
  tensor_shape = tf.shape(tensor)
  # Per dimension: the target size if the tensor is too large there, otherwise
  # -1 (tf.slice's "take everything remaining"). None also maps to -1 so
  # dynamic dimensions are left untouched by the clip.
  clip_size = [
      tf.where(tensor_shape[i] - shape > 0, shape, -1)
      if shape is not None else -1 for i, shape in enumerate(output_shape)
  ]
  clipped_tensor = tf.slice(
      tensor,
      begin=tf.zeros(len(clip_size), dtype=tf.int32),
      size=clip_size)

  # Pad tensor if the shape of clipped tensor is smaller than the expected
  # shape. After the clip no dimension exceeds its target, so each trailing
  # padding amount is >= 0.
  clipped_tensor_shape = tf.shape(clipped_tensor)
  trailing_paddings = [
      shape - clipped_tensor_shape[i] if shape is not None else 0
      for i, shape in enumerate(output_shape)
  ]
  # tf.pad expects [rank, 2] paddings: zeros before, trailing amounts after.
  paddings = tf.stack(
      [
          tf.zeros(len(trailing_paddings), dtype=tf.int32),
          trailing_paddings
      ],
      axis=1)
  padded_tensor = tf.pad(clipped_tensor, paddings=paddings)
  # Record whatever is statically known (integers) in the result's shape;
  # scalar-tensor dimensions stay dynamic (None).
  output_static_shape = [
      dim if not isinstance(dim, tf.Tensor) else None for dim in output_shape
  ]
  padded_tensor.set_shape(output_static_shape)
  return padded_tensor
|
160 |
+
|
161 |
+
|
162 |
+
def combined_static_and_dynamic_shape(tensor):
  """Returns a list containing static and dynamic values for the dimensions.

  Returns a list of static and dynamic values for shape dimensions. This is
  useful to preserve static shapes when available in reshape operation.

  Args:
    tensor: A tensor of any type.

  Returns:
    A list of size tensor.shape.ndims containing integers or a scalar tensor.
  """
  static_shape = tensor.shape.as_list()
  dynamic_shape = tf.shape(tensor)
  # Prefer the static (Python int) size; fall back to the runtime scalar for
  # any dimension that is unknown at graph-construction time.
  return [
      static_dim if static_dim is not None else dynamic_shape[index]
      for index, static_dim in enumerate(static_shape)
  ]
|
183 |
+
|
184 |
+
|
185 |
+
def static_or_dynamic_map_fn(fn, elems, dtype=None,
                             parallel_iterations=32, back_prop=True):
  """Runs map_fn as a (static) for loop when possible.

  This function rewrites the map_fn as an explicit unstack input -> for loop
  over function calls -> stack result combination. This allows our graphs to
  be acyclic when the batch size is static.
  For comparison, see https://www.tensorflow.org/api_docs/python/tf/map_fn.

  Note that `static_or_dynamic_map_fn` currently is not *fully* interchangeable
  with the default tf.map_fn function as it does not accept nested inputs (only
  Tensors or lists of Tensors). Likewise, the output of `fn` can only be a
  Tensor or list of Tensors.

  TODO(jonathanhuang): make this function fully interchangeable with tf.map_fn.

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same structure as elems. Its output must have the
      same structure as elems.
    elems: A tensor or list of tensors, each of which will
      be unpacked along their first dimension. The sequence of the
      resulting slices will be applied to fn.
    dtype: (optional) The output type(s) of fn. If fn returns a structure of
      Tensors differing from the structure of elems, then dtype is not optional
      and must have the same structure as the output of fn.
    parallel_iterations: (optional) number of batch items to process in
      parallel. This flag is only used if the native tf.map_fn is used
      and defaults to 32 instead of 10 (unlike the standard tf.map_fn default).
    back_prop: (optional) True enables support for back propagation.
      This flag is only used if the native tf.map_fn is used.

  Returns:
    A tensor or sequence of tensors. Each tensor packs the
    results of applying fn to tensors unpacked from elems along the first
    dimension, from first to last.
  Raises:
    ValueError: if `elems` is not a Tensor or a list of Tensors.
    ValueError: if `fn` does not return a Tensor or list of Tensors
  """
  if isinstance(elems, list):
    for elem in elems:
      if not isinstance(elem, tf.Tensor):
        raise ValueError('`elems` must be a Tensor or list of Tensors.')

    elem_shapes = [elem.shape.as_list() for elem in elems]
    # Fall back on tf.map_fn if shapes of each entry of `elems` are None or fail
    # to all be the same size along the batch dimension.
    for elem_shape in elem_shapes:
      if (not elem_shape or not elem_shape[0]
          or elem_shape[0] != elem_shapes[0][0]):
        return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop)
    # Static batch size: unroll into an explicit Python loop over slices.
    arg_tuples = zip(*[tf.unstack(elem) for elem in elems])
    outputs = [fn(arg_tuple) for arg_tuple in arg_tuples]
  else:
    if not isinstance(elems, tf.Tensor):
      raise ValueError('`elems` must be a Tensor or list of Tensors.')
    elems_shape = elems.shape.as_list()
    if not elems_shape or not elems_shape[0]:
      # Batch size unknown at graph construction time; use native map_fn.
      return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop)
    outputs = [fn(arg) for arg in tf.unstack(elems)]
  # Stack `outputs`, which is a list of Tensors or list of lists of Tensors
  if all([isinstance(output, tf.Tensor) for output in outputs]):
    return tf.stack(outputs)
  else:
    if all([isinstance(output, list) for output in outputs]):
      if all([all(
          [isinstance(entry, tf.Tensor) for entry in output_list])
              for output_list in outputs]):
        # fn returned a list per item: transpose the list-of-lists and stack
        # each position across the batch.
        return [tf.stack(output_tuple) for output_tuple in zip(*outputs)]
  raise ValueError('`fn` should return a Tensor or a list of Tensors.')
|
256 |
+
|
257 |
+
|
258 |
+
def check_min_image_dim(min_dim, image_tensor):
  """Checks that the image width/height are greater than some number.

  This function is used to check that the width and height of an image are above
  a certain value. If the image shape is static, this function will perform the
  check at graph construction time. Otherwise, if the image shape varies, an
  Assertion control dependency will be added to the graph.

  Args:
    min_dim: The minimum number of pixels along the width and height of the
      image.
    image_tensor: The image tensor to check size for.

  Returns:
    If `image_tensor` has dynamic size, return `image_tensor` with a Assert
    control dependency. Otherwise returns image_tensor.

  Raises:
    ValueError: if `image_tensor`'s' width or height is smaller than `min_dim`.
  """
  image_shape = image_tensor.get_shape()
  image_height = static_shape.get_height(image_shape)
  image_width = static_shape.get_width(image_shape)
  if image_height is None or image_width is None:
    # Dynamic shape: defer the check to run time. NOTE(review): the dynamic
    # branch reads dims [1] and [2], i.e. it assumes a batched NHWC layout —
    # confirm against callers.
    shape_assert = tf.Assert(
        tf.logical_and(tf.greater_equal(tf.shape(image_tensor)[1], min_dim),
                       tf.greater_equal(tf.shape(image_tensor)[2], min_dim)),
        ['image size must be >= {} in both height and width.'.format(min_dim)])
    with tf.control_dependencies([shape_assert]):
      return tf.identity(image_tensor)

  # Static shape: fail fast at graph construction time.
  if image_height < min_dim or image_width < min_dim:
    raise ValueError(
        'image size must be >= %d in both height and width; image dim = %d,%d' %
        (min_dim, image_height, image_width))

  return image_tensor
|
295 |
+
|
296 |
+
|
297 |
+
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes
  mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static and a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  a_is_static = all(isinstance(dim, int) for dim in shape_a)
  b_is_static = all(isinstance(dim, int) for dim in shape_b)
  if not (a_is_static and b_is_static):
    # At least one dimension is a tensor: compare at run time.
    return tf.assert_equal(shape_a, shape_b)
  # Fully static shapes: compare at graph construction time.
  if shape_a != shape_b:
    raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
  return tf.no_op()
|
324 |
+
|
325 |
+
|
326 |
+
def assert_shape_equal_along_first_dimension(shape_a, shape_b):
  """Asserts that shape_a and shape_b are the same along the 0th-dimension.

  If the shapes are static, raises a ValueError when the shapes
  mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static and a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  dim_a, dim_b = shape_a[0], shape_b[0]
  if not (isinstance(dim_a, int) and isinstance(dim_b, int)):
    # At least one leading dimension is dynamic: check at run time.
    return tf.assert_equal(dim_a, dim_b)
  # Both leading dimensions known statically: check immediately.
  if dim_a != dim_b:
    raise ValueError('Unequal first dimension {}, {}'.format(dim_a, dim_b))
  return tf.no_op()
|
353 |
+
|
354 |
+
|
355 |
+
def assert_box_normalized(boxes, maximum_normalized_coordinate=1.1):
  """Asserts the input box tensor is normalized.

  Args:
    boxes: a tensor of shape [N, 4] where N is the number of boxes.
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, default to 1.1.

  Returns:
    a tf.Assert op which fails when the input box tensor is not normalized.

  Raises:
    ValueError: When the input box tensor is not normalized.
  """
  # All coordinates must lie in [0, maximum_normalized_coordinate].
  within_upper_bound = tf.less_equal(
      tf.reduce_max(boxes), maximum_normalized_coordinate)
  within_lower_bound = tf.greater_equal(tf.reduce_min(boxes), 0)
  return tf.Assert(
      tf.logical_and(within_upper_bound, within_lower_bound), [boxes])
|
376 |
+
|
377 |
+
|
378 |
+
def flatten_dimensions(inputs, first, last):
  """Flattens `K-d` tensor along [first, last) dimensions.

  Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
  [D0, D1, ..., D(first) * D(first+1) * ... * D(last-1), D(last), ..., D(K-1)].

  Example:
  `inputs` is a tensor with initial shape [10, 5, 20, 20, 3].
  new_tensor = flatten_dimensions(inputs, first=1, last=3)
  new_tensor.shape -> [10, 100, 20, 3].

  Args:
    inputs: a tensor with shape [D0, D1, ..., D(K-1)].
    first: first value for the range of dimensions to flatten.
    last: last value for the range of dimensions to flatten. Note that the last
      dimension itself is excluded.

  Returns:
    a tensor with shape
    [D0, D1, ..., D(first) * D(first + 1) * ... * D(last - 1), D(last), ...,
    D(K-1)].

  Raises:
    ValueError: if first and last arguments are incorrect.
  """
  if first >= inputs.shape.ndims or last > inputs.shape.ndims:
    raise ValueError('`first` and `last` must be less than inputs.shape.ndims. '
                     'found {} and {} respectively while ndims is {}'.format(
                         first, last, inputs.shape.ndims))
  # Mix of Python ints and scalar tensors, so static sizes survive when known.
  shape = combined_static_and_dynamic_shape(inputs)
  # keepdims=True yields a length-1 vector so it can be concatenated below.
  flattened_dim_prod = tf.reduce_prod(shape[first:last],
                                      keepdims=True)
  new_shape = tf.concat([shape[:first], flattened_dim_prod,
                         shape[last:]], axis=0)
  return tf.reshape(inputs, new_shape)
|
413 |
+
|
414 |
+
|
415 |
+
def flatten_first_n_dimensions(inputs, n):
  """Flattens `K-d` tensor along first n dimension to be a `(K-n+1)-d` tensor.

  Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
  [D0 * D1 * ... * D(n-1), D(n), ... D(K-1)].

  Example:
  `inputs` is a tensor with initial shape [10, 5, 20, 20, 3].
  new_tensor = flatten_first_n_dimensions(inputs, 2)
  new_tensor.shape -> [50, 20, 20, 3].

  Args:
    inputs: a tensor with shape [D0, D1, ..., D(K-1)].
    n: The number of dimensions to flatten.

  Returns:
    a tensor with shape [D0 * D1 * ... * D(n-1), D(n), ... D(K-1)].
  """
  # Special case of flatten_dimensions: collapse the leading [0, n) range.
  return flatten_dimensions(inputs, 0, n)
|
434 |
+
|
435 |
+
|
436 |
+
def expand_first_dimension(inputs, dims):
  """Expands `K-d` tensor along first dimension to be a `(K+n-1)-d` tensor.

  Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
  [dims[0], dims[1], ..., dims[-1], D1, ..., D(k-1)].

  Example:
  `inputs` is a tensor with shape [50, 20, 20, 3].
  new_tensor = expand_first_dimension(inputs, [10, 5]).
  new_tensor.shape -> [10, 5, 20, 20, 3].

  Args:
    inputs: a tensor with shape [D0, D1, ..., D(K-1)].
    dims: List with new dimensions to expand first axis into. The length of
      `dims` is typically 2 or larger.

  Returns:
    a tensor with shape [dims[0], dims[1], ..., dims[-1], D1, ..., D(k-1)].
  """
  inputs_shape = combined_static_and_dynamic_shape(inputs)
  # New shape: the expansion dims followed by the original trailing dims.
  expanded_shape = tf.stack(dims + inputs_shape[1:])

  # Verify that it is possible to expand the first axis of inputs, i.e. that
  # prod(dims) == D0 at run time.
  assert_op = tf.assert_equal(
      inputs_shape[0], tf.reduce_prod(tf.stack(dims)),
      message=('First dimension of `inputs` cannot be expanded into provided '
               '`dims`'))

  with tf.control_dependencies([assert_op]):
    inputs_reshaped = tf.reshape(inputs, expanded_shape)

  return inputs_reshaped
|
468 |
+
|
469 |
+
|
470 |
+
def resize_images_and_return_shapes(inputs, image_resizer_fn):
  """Resizes images using the given function and returns their true shapes.

  Args:
    inputs: a float32 Tensor representing a batch of inputs of shape
      [batch_size, height, width, channels].
    image_resizer_fn: a function which takes in a single image and outputs
      a resized image and its original shape.

  Returns:
    resized_inputs: The inputs resized according to image_resizer_fn.
    true_image_shapes: A integer tensor of shape [batch_size, 3]
      representing the height, width and number of channels in inputs.

  Raises:
    ValueError: if `inputs` is not a tf.float32 tensor.
  """

  if inputs.dtype is not tf.float32:
    raise ValueError('`resize_images_and_return_shapes` expects a'
                     ' tf.float32 tensor')

  # TODO(jonathanhuang): revisit whether to always use batch size as
  # the number of parallel iterations vs allow for dynamic batching.
  # Maps the resizer over the batch; `dtype` matches the two outputs of
  # image_resizer_fn (resized image, true shape).
  outputs = static_or_dynamic_map_fn(
      image_resizer_fn,
      elems=inputs,
      dtype=[tf.float32, tf.int32])
  resized_inputs = outputs[0]
  true_image_shapes = outputs[1]

  return resized_inputs, true_image_shapes
|
standard_fields.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Contains classes specifying naming conventions used for object detection.
|
17 |
+
|
18 |
+
|
19 |
+
Specifies:
|
20 |
+
InputDataFields: standard fields used by reader/preprocessor/batcher.
|
21 |
+
DetectionResultFields: standard fields returned by object detector.
|
22 |
+
BoxListFields: standard field used by BoxList
|
23 |
+
TfExampleFields: standard fields for tf-example data format (go/tf-example).
|
24 |
+
"""
|
25 |
+
|
26 |
+
|
27 |
+
class InputDataFields(object):
  """Names for the input tensors.

  Holds the standard data field names to use for identifying input tensors. This
  should be used by the decoder to identify keys for the returned tensor_dict
  containing input tensors. And it should be used by the model to identify the
  tensors it needs.

  Attributes:
    image: image.
    image_additional_channels: additional channels.
    original_image: image in the original input size.
    original_image_spatial_shape: image in the original input size.
    key: unique key corresponding to image.
    source_id: source of the original image.
    filename: original filename of the dataset (without common path).
    groundtruth_image_classes: image-level class labels.
    groundtruth_image_confidences: image-level class confidences.
    groundtruth_labeled_classes: image-level annotation that indicates the
      classes for which an image has been labeled.
    groundtruth_boxes: coordinates of the ground truth boxes in the image.
    groundtruth_classes: box-level class labels.
    groundtruth_confidences: box-level class confidences. The shape should be
      the same as the shape of groundtruth_classes.
    groundtruth_label_types: box-level label types (e.g. explicit negative).
    groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead]
      is the groundtruth a single object or a crowd.
    groundtruth_area: area of a groundtruth segment.
    groundtruth_difficult: is a `difficult` object
    groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the
      same class, forming a connected group, where instances are heavily
      occluding each other.
    proposal_boxes: coordinates of object proposal boxes.
    proposal_objectness: objectness score of each proposal.
    groundtruth_instance_masks: ground truth instance masks.
    groundtruth_instance_boundaries: ground truth instance boundaries.
    groundtruth_instance_classes: instance mask-level class labels.
    groundtruth_keypoints: ground truth keypoints.
    groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
    groundtruth_keypoint_weights: groundtruth weight factor for keypoints.
    groundtruth_label_weights: groundtruth label weights.
    groundtruth_weights: groundtruth weight factor for bounding boxes.
    num_groundtruth_boxes: number of groundtruth boxes.
    is_annotated: whether an image has been labeled or not.
    true_image_shape: true shapes of images in the resized images, as resized
      images can be padded with zeros.
    multiclass_scores: the label score per class for each box.
    context_features: a flattened list of contextual features.
    context_feature_length: the fixed length of each feature in
      context_features, used for reshaping.
    valid_context_size: the valid context size, used in filtering the padded
      context features.
  """
  image = 'image'
  image_additional_channels = 'image_additional_channels'
  original_image = 'original_image'
  original_image_spatial_shape = 'original_image_spatial_shape'
  key = 'key'
  source_id = 'source_id'
  filename = 'filename'
  groundtruth_image_classes = 'groundtruth_image_classes'
  groundtruth_image_confidences = 'groundtruth_image_confidences'
  groundtruth_labeled_classes = 'groundtruth_labeled_classes'
  groundtruth_boxes = 'groundtruth_boxes'
  groundtruth_classes = 'groundtruth_classes'
  groundtruth_confidences = 'groundtruth_confidences'
  groundtruth_label_types = 'groundtruth_label_types'
  groundtruth_is_crowd = 'groundtruth_is_crowd'
  groundtruth_area = 'groundtruth_area'
  groundtruth_difficult = 'groundtruth_difficult'
  groundtruth_group_of = 'groundtruth_group_of'
  proposal_boxes = 'proposal_boxes'
  proposal_objectness = 'proposal_objectness'
  groundtruth_instance_masks = 'groundtruth_instance_masks'
  groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
  groundtruth_instance_classes = 'groundtruth_instance_classes'
  groundtruth_keypoints = 'groundtruth_keypoints'
  groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
  groundtruth_keypoint_weights = 'groundtruth_keypoint_weights'
  groundtruth_label_weights = 'groundtruth_label_weights'
  groundtruth_weights = 'groundtruth_weights'
  num_groundtruth_boxes = 'num_groundtruth_boxes'
  is_annotated = 'is_annotated'
  true_image_shape = 'true_image_shape'
  multiclass_scores = 'multiclass_scores'
  context_features = 'context_features'
  context_feature_length = 'context_feature_length'
  valid_context_size = 'valid_context_size'
|
115 |
+
|
116 |
+
|
117 |
+
class DetectionResultFields(object):
  """Naming conventions for storing the output of the detector.

  Each attribute value below is the canonical string key under which the
  corresponding tensor is stored in a detection-result dictionary.

  Attributes:
    source_id: source of the original image.
    key: unique key corresponding to image.
    detection_boxes: coordinates of the detection boxes in the image.
    detection_scores: detection scores for the detection boxes in the image.
    detection_multiclass_scores: class score distribution (including background)
      for detection boxes in the image including background class.
    detection_classes: detection-level class labels.
    detection_masks: contains a segmentation mask for each detection box.
    detection_boundaries: contains an object boundary for each detection box.
    detection_keypoints: contains detection keypoints for each detection box.
    detection_keypoint_scores: contains detection keypoint scores.
    num_detections: number of detections in the batch.
    raw_detection_boxes: contains decoded detection boxes without Non-Max
      suppression.
    raw_detection_scores: contains class score logits for raw detection boxes.
    detection_anchor_indices: The anchor indices of the detections after NMS.
    detection_features: contains extracted features for each detected box
      after NMS.
  """

  source_id = 'source_id'
  key = 'key'
  detection_boxes = 'detection_boxes'
  detection_scores = 'detection_scores'
  detection_multiclass_scores = 'detection_multiclass_scores'
  detection_features = 'detection_features'
  detection_classes = 'detection_classes'
  detection_masks = 'detection_masks'
  detection_boundaries = 'detection_boundaries'
  detection_keypoints = 'detection_keypoints'
  detection_keypoint_scores = 'detection_keypoint_scores'
  num_detections = 'num_detections'
  raw_detection_boxes = 'raw_detection_boxes'
  raw_detection_scores = 'raw_detection_scores'
  detection_anchor_indices = 'detection_anchor_indices'
|
156 |
+
|
157 |
+
|
158 |
+
class BoxListFields(object):
  """Naming conventions for BoxLists.

  Attributes:
    boxes: bounding box coordinates.
    classes: classes per bounding box.
    scores: scores per bounding box.
    weights: sample weights per bounding box.
    confidences: confidences per bounding box.
    objectness: objectness score per bounding box.
    masks: masks per bounding box.
    boundaries: boundaries per bounding box.
    keypoints: keypoints per bounding box.
    keypoint_visibilities: keypoint visibilities per bounding box.
    keypoint_heatmaps: keypoint heatmaps per bounding box.
    is_crowd: is_crowd annotation per bounding box.
  """
  boxes = 'boxes'
  classes = 'classes'
  scores = 'scores'
  weights = 'weights'
  confidences = 'confidences'
  objectness = 'objectness'
  masks = 'masks'
  boundaries = 'boundaries'
  keypoints = 'keypoints'
  keypoint_visibilities = 'keypoint_visibilities'
  keypoint_heatmaps = 'keypoint_heatmaps'
  is_crowd = 'is_crowd'
|
185 |
+
|
186 |
+
|
187 |
+
class PredictionFields(object):
  """Naming conventions for standardized prediction outputs.

  Each attribute value below is the canonical string key under which the
  corresponding prediction tensor is stored.

  Attributes:
    feature_maps: List of feature maps for prediction.
    anchors: Generated anchors.
    raw_detection_boxes: Decoded detection boxes without NMS.
    raw_detection_feature_map_indices: Feature map indices from which each raw
      detection box was produced.
  """
  feature_maps = 'feature_maps'
  anchors = 'anchors'
  raw_detection_boxes = 'raw_detection_boxes'
  raw_detection_feature_map_indices = 'raw_detection_feature_map_indices'
|
201 |
+
|
202 |
+
|
203 |
+
class TfExampleFields(object):
  """TF-example proto feature names for object detection.

  Holds the standard feature names to load from an Example proto for object
  detection. Each attribute value is the feature key string used inside the
  tf.Example proto.

  Attributes:
    image_encoded: JPEG encoded string
    image_format: image format, e.g. "JPEG"
    filename: filename
    channels: number of channels of image
    colorspace: colorspace, e.g. "RGB"
    height: height of image in pixels, e.g. 462
    width: width of image in pixels, e.g. 581
    source_id: original source of the image
    image_class_text: image-level label in text format
    image_class_label: image-level label in numerical format
    image_class_confidence: image-level confidence of the label
    object_class_text: labels in text format, e.g. ["person", "cat"]
    object_class_label: labels in numbers, e.g. [16, 8]
    object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30
    object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40
    object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50
    object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70
    object_view: viewpoint of object, e.g. ["frontal", "left"]
    object_truncated: is object truncated, e.g. [true, false]
    object_occluded: is object occluded, e.g. [true, false]
    object_difficult: is object difficult, e.g. [true, false]
    object_group_of: is object a single object or a group of objects
    object_depiction: is object a depiction
    object_is_crowd: [DEPRECATED, use object_group_of instead]
      is the object a single object or a crowd
    object_segment_area: the area of the segment.
    object_weight: a weight factor for the object's bounding box.
    instance_masks: instance segmentation masks.
    instance_boundaries: instance boundaries.
    instance_classes: Classes for each instance segmentation mask.
    detection_class_label: class label in numbers.
    detection_bbox_ymin: ymin coordinates of a detection box.
    detection_bbox_xmin: xmin coordinates of a detection box.
    detection_bbox_ymax: ymax coordinates of a detection box.
    detection_bbox_xmax: xmax coordinates of a detection box.
    detection_score: detection score for the class label and box.
  """
  image_encoded = 'image/encoded'
  image_format = 'image/format'  # format is reserved keyword
  filename = 'image/filename'
  channels = 'image/channels'
  colorspace = 'image/colorspace'
  height = 'image/height'
  width = 'image/width'
  source_id = 'image/source_id'
  image_class_text = 'image/class/text'
  image_class_label = 'image/class/label'
  image_class_confidence = 'image/class/confidence'
  object_class_text = 'image/object/class/text'
  object_class_label = 'image/object/class/label'
  object_bbox_ymin = 'image/object/bbox/ymin'
  object_bbox_xmin = 'image/object/bbox/xmin'
  object_bbox_ymax = 'image/object/bbox/ymax'
  object_bbox_xmax = 'image/object/bbox/xmax'
  object_view = 'image/object/view'
  object_truncated = 'image/object/truncated'
  object_occluded = 'image/object/occluded'
  object_difficult = 'image/object/difficult'
  object_group_of = 'image/object/group_of'
  object_depiction = 'image/object/depiction'
  object_is_crowd = 'image/object/is_crowd'
  object_segment_area = 'image/object/segment/area'
  object_weight = 'image/object/weight'
  instance_masks = 'image/segmentation/object'
  instance_boundaries = 'image/boundaries/object'
  instance_classes = 'image/segmentation/object/class'
  detection_class_label = 'image/detection/label'
  detection_bbox_ymin = 'image/detection/bbox/ymin'
  detection_bbox_xmin = 'image/detection/bbox/xmin'
  detection_bbox_ymax = 'image/detection/bbox/ymax'
  detection_bbox_xmax = 'image/detection/bbox/xmax'
  detection_score = 'image/detection/score'
|
static_shape.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Helper functions to access TensorShape values.
|
17 |
+
|
18 |
+
The rank 4 tensor_shape must be of the form [batch_size, height, width, depth].
|
19 |
+
"""
|
20 |
+
|
21 |
+
from __future__ import absolute_import
|
22 |
+
from __future__ import division
|
23 |
+
from __future__ import print_function
|
24 |
+
|
25 |
+
|
26 |
+
def get_dim_as_int(dim):
  """Utility to get v1 or v2 TensorShape dim as an int.

  Args:
    dim: The TensorShape dimension to get as an int.

  Returns:
    None or an int.
  """
  # A v1 dimension exposes its size via `.value`; a v2 dimension is already
  # a plain int (or None). getattr's default is returned on AttributeError,
  # so both cases collapse into one expression.
  return getattr(dim, 'value', dim)
|
39 |
+
|
40 |
+
|
41 |
+
def get_batch_size(tensor_shape):
  """Returns batch size from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape, laid out as
      [batch_size, height, width, depth].

  Returns:
    An integer representing the batch size of the tensor.
  """
  tensor_shape.assert_has_rank(4)
  batch_dim = tensor_shape[0]
  return get_dim_as_int(batch_dim)
|
52 |
+
|
53 |
+
|
54 |
+
def get_height(tensor_shape):
  """Returns height from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape, laid out as
      [batch_size, height, width, depth].

  Returns:
    An integer representing the height of the tensor.
  """
  tensor_shape.assert_has_rank(4)
  height_dim = tensor_shape[1]
  return get_dim_as_int(height_dim)
|
65 |
+
|
66 |
+
|
67 |
+
def get_width(tensor_shape):
  """Returns width from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape, laid out as
      [batch_size, height, width, depth].

  Returns:
    An integer representing the width of the tensor.
  """
  tensor_shape.assert_has_rank(4)
  width_dim = tensor_shape[2]
  return get_dim_as_int(width_dim)
|
78 |
+
|
79 |
+
|
80 |
+
def get_depth(tensor_shape):
  """Returns depth from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape, laid out as
      [batch_size, height, width, depth].

  Returns:
    An integer representing the depth of the tensor.
  """
  tensor_shape.assert_has_rank(4)
  depth_dim = tensor_shape[3]
  return get_dim_as_int(depth_dim)
|
string_int_label_map_pb2.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: object_detection/protos/string_int_label_map.proto
# NOTE(review): legacy generated module built against the pre-3.20 protobuf
# Python API (hand-constructed Descriptor objects). If the runtime protobuf
# version is upgraded, regenerate this file with protoc instead of editing it.

import sys
# Py2/Py3 shim: on Python 3, encode string literals to bytes (latin1).
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# NOTE(review): descriptor_pb2 appears unused below; a common artifact of the
# old generator — confirm before removing.
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# File-level descriptor; serialized_pb is the wire-format FileDescriptorProto.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/string_int_label_map.proto',
  package='object_detection.protos',
  syntax='proto2',
  serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem')
)




# Descriptor for StringIntLabelMapItem: optional name (string), id (int32),
# display_name (string).
_STRINGINTLABELMAPITEM = _descriptor.Descriptor(
  name='StringIntLabelMapItem',
  full_name='object_detection.protos.StringIntLabelMapItem',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None, file=DESCRIPTOR),
    _descriptor.FieldDescriptor(
      name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None, file=DESCRIPTOR),
    _descriptor.FieldDescriptor(
      name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2,
      number=3, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=79,
  serialized_end=150,
)


# Descriptor for StringIntLabelMap: repeated StringIntLabelMapItem item.
_STRINGINTLABELMAP = _descriptor.Descriptor(
  name='StringIntLabelMap',
  full_name='object_detection.protos.StringIntLabelMap',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=152,
  serialized_end=233,
)

# Cross-link the message-typed field and register both messages on the file.
_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM
DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM
DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

# Concrete message classes are built at import time via the reflection API.
StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict(
  DESCRIPTOR = _STRINGINTLABELMAPITEM,
  __module__ = 'object_detection.protos.string_int_label_map_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem)
  ))
_sym_db.RegisterMessage(StringIntLabelMapItem)

StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict(
  DESCRIPTOR = _STRINGINTLABELMAP,
  __module__ = 'object_detection.protos.string_int_label_map_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap)
  ))
_sym_db.RegisterMessage(StringIntLabelMap)


# @@protoc_insertion_point(module_scope)
|
tf_label_map.pbtxt
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
item {
|
2 |
+
id: 1
|
3 |
+
name: 'yellowTag'
|
4 |
+
}
|
5 |
+
item {
|
6 |
+
id: 2
|
7 |
+
name: '0'
|
8 |
+
}
|
9 |
+
item {
|
10 |
+
id: 3
|
11 |
+
name: '1'
|
12 |
+
}
|
13 |
+
item {
|
14 |
+
id: 4
|
15 |
+
name: '2'
|
16 |
+
}
|
17 |
+
item {
|
18 |
+
id: 5
|
19 |
+
name: '3'
|
20 |
+
}
|
21 |
+
item {
|
22 |
+
id: 6
|
23 |
+
name: '4'
|
24 |
+
}
|
25 |
+
item {
|
26 |
+
id: 7
|
27 |
+
name: '5'
|
28 |
+
}
|
29 |
+
item {
|
30 |
+
id: 8
|
31 |
+
name: '6'
|
32 |
+
}
|
33 |
+
item {
|
34 |
+
id: 9
|
35 |
+
name: '7'
|
36 |
+
}
|
37 |
+
item {
|
38 |
+
id: 10
|
39 |
+
name: '8'
|
40 |
+
}
|
41 |
+
item {
|
42 |
+
id: 11
|
43 |
+
name: '9'
|
44 |
+
}
|
45 |
+
item {
|
46 |
+
id: 12
|
47 |
+
name: 'P'
|
48 |
+
}
|
49 |
+
item {
|
50 |
+
id: 13
|
51 |
+
name: 'G'
|
52 |
+
}
|
53 |
+
item {
|
54 |
+
id: 14
|
55 |
+
name: 'E'
|
56 |
+
}
|
57 |
+
item {
|
58 |
+
id: 15
|
59 |
+
name: 'H'
|
60 |
+
}
|
61 |
+
item {
|
62 |
+
id: 16
|
63 |
+
name: 'N'
|
64 |
+
}
|
65 |
+
item {
|
66 |
+
id: 17
|
67 |
+
name: 'S'
|
68 |
+
}
|
69 |
+
item {
|
70 |
+
id: 18
|
71 |
+
name: 'B'
|
72 |
+
}
|
73 |
+
item {
|
74 |
+
id: 19
|
75 |
+
name: 'M'
|
76 |
+
}
|
77 |
+
item {
|
78 |
+
id: 20
|
79 |
+
name: 'C'
|
80 |
+
}
|
81 |
+
item {
|
82 |
+
id: 21
|
83 |
+
name: 'W'
|
84 |
+
}
|
85 |
+
item {
|
86 |
+
id: 22
|
87 |
+
name: 'T'
|
88 |
+
}
|
89 |
+
item {
|
90 |
+
id: 23
|
91 |
+
name: 'Y'
|
92 |
+
}
|
93 |
+
item {
|
94 |
+
id: 24
|
95 |
+
name: 'A'
|
96 |
+
}
|
97 |
+
item {
|
98 |
+
id: 25
|
99 |
+
name: 'F'
|
100 |
+
}
|
101 |
+
item {
|
102 |
+
id: 26
|
103 |
+
name: 'D'
|
104 |
+
}
|
105 |
+
item {
|
106 |
+
id: 27
|
107 |
+
name: 'L'
|
108 |
+
}
|
109 |
+
item {
|
110 |
+
id: 28
|
111 |
+
name: 'X'
|
112 |
+
}
|
113 |
+
item {
|
114 |
+
id: 29
|
115 |
+
name: 'J'
|
116 |
+
}
|
117 |
+
item {
|
118 |
+
id: 30
|
119 |
+
name: 'I'
|
120 |
+
}
|