# DDMR: ddmr/layers/depthwise_conv_3d.py
# Last change by andreped: "Renamed module to ddmr" (commit a27d55f)
# SRC: https://github.com/alexandrosstergiou/keras-DepthwiseConv3D
'''
This is a modification of the SeparableConv3D code in Keras,
to perform just the Depthwise Convolution (1st step) of the
Depthwise Separable Convolution layer.
'''
from __future__ import absolute_import
from tensorflow.keras import backend as K
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import constraints
import tensorflow.keras.utils as conv_utils
from tensorflow.keras.layers import Conv3D, InputSpec
from tensorflow.python.keras.backend import _preprocess_padding, _preprocess_conv3d_input
import tensorflow as tf
class DepthwiseConv3D(Conv3D):
    """Depthwise (optionally grouped) 3D convolution.

    Performs only the first step of a depthwise separable convolution.
    The input channels are split into `groups` equal groups. Each group is
    convolved with its own slice of the kernel, producing `depth_multiplier`
    output channels per group. No pointwise (1x1x1) convolution follows.

    # Arguments
        kernel_size: An integer or tuple/list of 3 integers, specifying the
            depth, height and width of the 3D convolution window.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        strides: An integer or tuple/list of 3 integers, specifying the
            strides of the convolution along depth, height and width.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        padding: one of `"valid"` or `"same"` (case-insensitive).
        depth_multiplier: Number of output channels produced by each group.
            The layer outputs `groups * depth_multiplier` channels.
        groups: Number of channel groups. Must be positive, must not exceed
            the number of input channels, and must divide it. `None`
            (default) means one group per input channel, i.e. a plain
            depthwise convolution. It is resolved in `build`.
        data_format: A string, one of `channels_last` (default) or
            `channels_first`. The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, depth, rows, cols, channels)`, while `channels_first`
            corresponds to `(batch, channels, depth, rows, cols)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        activation: Activation function to use.
            If you don't specify anything, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector
            (one value per output channel).
        depthwise_initializer: Initializer for the depthwise kernel.
        bias_initializer: Initializer for the bias vector.
        dilation_rate: An integer or tuple/list of 3 integers. The
            dilation factor for each spatial dimension. Defaults to (1, 1, 1).
        depthwise_regularizer: Regularizer function applied to the
            depthwise kernel.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
        depthwise_constraint: Constraint function applied to the
            depthwise kernel.
        bias_constraint: Constraint function applied to the bias vector.

    # Input shape
        5D tensor with shape:
        `(batch, channels, depth, rows, cols)` if data_format='channels_first'
        or 5D tensor with shape:
        `(batch, depth, rows, cols, channels)` if data_format='channels_last'.

    # Output shape
        5D tensor with shape:
        `(batch, groups * depth_multiplier, new_depth, new_rows, new_cols)`
        if data_format='channels_first'
        or 5D tensor with shape:
        `(batch, new_depth, new_rows, new_cols, groups * depth_multiplier)`
        if data_format='channels_last'.
        `new_depth`, `new_rows` and `new_cols` depend on kernel size,
        padding, strides and dilation.
    """

    def __init__(self,
                 kernel_size,
                 strides=(1, 1, 1),
                 padding='valid',
                 depth_multiplier=1,
                 groups=None,
                 data_format=None,
                 activation=None,
                 use_bias=True,
                 depthwise_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 dilation_rate=(1, 1, 1),
                 depthwise_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 depthwise_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # `filters=None`: the output width is derived from `groups` and
        # `depth_multiplier` once the input channel count is known in build().
        super(DepthwiseConv3D, self).__init__(
            filters=None,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            bias_regularizer=bias_regularizer,
            dilation_rate=dilation_rate,
            activity_regularizer=activity_regularizer,
            bias_constraint=bias_constraint,
            **kwargs)
        self.depth_multiplier = depth_multiplier
        # None is resolved to the number of input channels in build().
        self.groups = groups
        self.depthwise_initializer = initializers.get(depthwise_initializer)
        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
        self.depthwise_constraint = constraints.get(depthwise_constraint)
        self.bias_initializer = initializers.get(bias_initializer)
        # `self.dilation_rate` is deliberately NOT reassigned from the raw
        # argument. Conv3D.__init__ has already normalised it to a 3-tuple;
        # an int or list here would break the tuple concatenation in call().
        self._padding = _preprocess_padding(self.padding)
        # The convolution always runs in NDHWC (channels_first inputs are
        # transposed in call()), so strides use the NDHWC layout.
        self._strides = (1,) + tuple(self.strides) + (1,)
        self._data_format = 'NDHWC'
        self.input_dim = None

    def build(self, input_shape):
        """Validate the input shape and create the kernel and bias weights.

        Raises:
            ValueError: if the input is not rank 5, the channel dimension is
                undefined, or `groups` is incompatible with it.
        """
        shape = tf.TensorShape(input_shape)
        if shape.rank != 5:
            raise ValueError('Inputs to `DepthwiseConv3D` should have rank 5. '
                             'Received input shape: ' + str(input_shape))
        shape = shape.as_list()
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        if shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs to '
                             '`DepthwiseConv3D` '
                             'should be defined. Found `None`.')
        self.input_dim = int(shape[channel_axis])
        if self.groups is None:
            self.groups = self.input_dim
        if self.groups < 1:
            raise ValueError('The number of groups must be a positive integer')
        if self.groups > self.input_dim:
            raise ValueError('The number of groups cannot exceed the number of channels')
        if self.input_dim % self.groups != 0:
            raise ValueError('Warning! The channels dimension is not divisible by the group size chosen')
        # Kernel slice [:, :, :, i:i + input_dim // groups, :] serves group i.
        depthwise_kernel_shape = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  self.kernel_size[2],
                                  self.input_dim,
                                  self.depth_multiplier)
        self.depthwise_kernel = self.add_weight(
            shape=depthwise_kernel_shape,
            initializer=self.depthwise_initializer,
            name='depthwise_kernel',
            regularizer=self.depthwise_regularizer,
            constraint=self.depthwise_constraint)
        if self.use_bias:
            # One bias value per output channel (groups * depth_multiplier).
            self.bias = self.add_weight(shape=(self.groups * self.depth_multiplier,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim})
        self.built = True

    def call(self, inputs, training=None):
        """Apply the grouped depthwise convolution.

        `training` is accepted for API compatibility and is not used.
        """
        channels_first = self.data_format == 'channels_first'
        if channels_first:
            # Run the convolution in NDHWC, which every device supports, and
            # transpose the result back at the end.
            inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))
        group_size = self.input_dim // self.groups
        dilations = (1,) + tuple(self.dilation_rate) + (1,)
        group_outputs = [
            tf.nn.conv3d(inputs[..., start:start + group_size],
                         self.depthwise_kernel[:, :, :, start:start + group_size, :],
                         strides=self._strides,
                         padding=self._padding,
                         dilations=dilations,
                         data_format=self._data_format)
            for start in range(0, self.input_dim, group_size)]
        outputs = tf.concat(group_outputs, axis=-1)
        if self.bias is not None:
            # Still NDHWC at this point, so the bias goes on the last axis.
            outputs = K.bias_add(outputs, self.bias, data_format='channels_last')
        if channels_first:
            outputs = tf.transpose(outputs, (0, 4, 1, 2, 3))
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    @staticmethod
    def _conv_output_length(input_length, filter_size, padding, stride, dilation=1):
        """Return the output length of a 'same'/'valid' conv along one axis."""
        if input_length is None:
            return None
        dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
        if padding == 'same':
            output_length = input_length
        else:  # 'valid'
            output_length = input_length - dilated_filter_size + 1
        return (output_length + stride - 1) // stride

    def compute_output_shape(self, input_shape):
        """Return the static output shape as a tuple (unknown dims are None)."""
        shape = tf.TensorShape(input_shape).as_list()
        channels_first = self.data_format == 'channels_first'
        spatial = shape[2:5] if channels_first else shape[1:4]
        groups = self.groups
        if groups is None:
            # Not built yet: default grouping is one group per input channel.
            groups = shape[1] if channels_first else shape[-1]
        out_filters = None if groups is None else groups * self.depth_multiplier
        new_spatial = tuple(
            self._conv_output_length(length, size, self.padding, stride, rate)
            for length, size, stride, rate in zip(
                spatial, self.kernel_size, self.strides, self.dilation_rate))
        if channels_first:
            return (shape[0], out_filters) + new_spatial
        return (shape[0],) + new_spatial + (out_filters,)

    def get_config(self):
        """Return the layer config, replacing Conv3D kernel keys with depthwise ones."""
        config = super(DepthwiseConv3D, self).get_config()
        # Conv3D's `filters`/kernel settings do not apply to this layer. Use a
        # default so versions that omit one of these keys do not raise.
        for key in ('filters', 'kernel_initializer',
                    'kernel_regularizer', 'kernel_constraint'):
            config.pop(key, None)
        config['depth_multiplier'] = self.depth_multiplier
        config['groups'] = self.groups
        config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
        config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
        config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
        return config

    def __call__(self, inputs, training=True):
        """Invoke `call` directly.

        This override bypasses `Layer.__call__`, which is what normally
        triggers `build`. Build lazily here so the weights exist.
        """
        if not self.built:
            self.build(inputs.shape)
        return self.call(inputs, training)
# Alternate public name for `DepthwiseConv3D`; both names refer to the same class.
DepthwiseConvolution3D = DepthwiseConv3D