import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
class MultiMaskedConv3D(Layer):
    """
    Masked multitask 3-dimensional convolutional layer. This layer implements
    multiple stacks of the convolutional architecture and implements masking
    consistent with the BeyondML API to support developing sparse multitask models.

    The layer takes a list of inputs (one per task) and returns a list of
    outputs of the same length. Each task has its own kernel (and optional
    bias), and every kernel/bias is multiplied elementwise by a non-trainable
    mask of the same shape before being applied.
    """

    def __init__(
        self,
        filters,
        kernel_size=3,
        padding='same',
        strides=1,
        use_bias=True,
        activation=None,
        kernel_initializer='random_normal',
        bias_initializer='zeros',
        mask_initializer='ones',
        **kwargs
    ):
        """
        Parameters
        ----------
        filters : int
            The number of convolutional filters to apply
        kernel_size : int or tuple of ints (default 3)
            The kernel size along the three spatial dimensions. A single int
            is expanded to the same size along all three dimensions
        padding : str (default 'same')
            Either 'same' or 'valid', the padding to use during convolution
        strides : int or tuple of ints (default 1)
            Stride lengths to use during convolution
        use_bias : bool (default True)
            Whether to use a bias calculation on the outputs
        activation : None, str, or function (default None)
            Activation function to use on the outputs
        kernel_initializer : str or keras initialization function (default 'random_normal')
            The initialization function to use for the weights
        bias_initializer : str or keras initialization function (default 'zeros')
            The initialization function to use for the bias
        mask_initializer : str or keras initialization function (default 'ones')
            The mask initialization function to use
        **kwargs
            Forwarded unchanged to the base ``Layer`` constructor
        """
        super().__init__(**kwargs)
        self.filters = int(filters)
        # Goes through the property setter below, which normalises the value.
        self.kernel_size = kernel_size
        self.padding = padding
        self.strides = tuple(strides) if isinstance(strides, list) else strides
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.mask_initializer = tf.keras.initializers.get(mask_initializer)

    @property
    def kernel_size(self):
        """The kernel size as a tuple, one entry per spatial dimension."""
        return self._kernel_size

    @kernel_size.setter
    def kernel_size(self, value):
        if isinstance(value, int):
            self._kernel_size = (value, value, value)
        else:
            # get_config() emits a list; store a tuple so a layer rebuilt from
            # its config holds the same type as one built from a tuple.
            self._kernel_size = tuple(value)

    def build(self, input_shape):
        """
        Build the layer in preparation to be trained or called. Should not be called directly,
        but rather is called when the layer is added to a model

        Parameters
        ----------
        input_shape : list of shapes
            One shape per task. The number of entries determines how many
            kernels are created, and the last dimension of the first entry is
            used as the input channel count for every task
        """
        try:
            input_shape = [
                tuple(shape.as_list()) for shape in input_shape
            ]
        except AttributeError:
            # Sometimes, input shapes come as tuples already
            pass

        num_tasks = len(input_shape)
        # Every task's kernel is sized from the first input's channel count.
        in_channels = input_shape[0][-1]

        self.w = self.add_weight(
            shape=(num_tasks, *self.kernel_size[:3], in_channels, self.filters),
            initializer=self.kernel_initializer,
            trainable=True,
            name='weights'
        )
        self.w_mask = self.add_weight(
            shape=self.w.shape,
            initializer=self.mask_initializer,
            trainable=False,
            name='weights_mask'
        )

        if self.use_bias:
            self.b = self.add_weight(
                shape=(num_tasks, self.filters),
                initializer=self.bias_initializer,
                trainable=True,
                name='bias'
            )
            self.b_mask = self.add_weight(
                shape=self.b.shape,
                initializer=self.mask_initializer,
                trainable=False,
                name='bias_mask'
            )

    def call(self, inputs):
        """
        This is where the layer's logic lives and is called upon inputs

        Parameters
        ----------
        inputs : list of TensorFlow Tensors or Tensor-likes
            The inputs to the layer, one per task, in channels-last
            ('NDHWC') layout

        Returns
        -------
        outputs : list of TensorFlow Tensors
            The outputs of the layer's logic, one per input
        """
        outputs = [
            tf.nn.convolution(
                task_input,
                self.w[i] * self.w_mask[i],
                padding=self.padding.upper(),
                strides=self.strides,
                data_format='NDHWC'
            )
            for i, task_input in enumerate(inputs)
        ]

        if self.use_bias:
            outputs = [
                output + (self.b[i] * self.b_mask[i])
                for i, output in enumerate(outputs)
            ]

        return [self.activation(output) for output in outputs]

    def get_config(self):
        """
        Return the layer configuration: the base layer's config extended with
        this layer's constructor arguments in serializable form
        """
        config = super().get_config().copy()
        config.update(
            {
                'filters': self.filters,
                'kernel_size': list(self.kernel_size),
                'padding': self.padding,
                'strides': self.strides,
                'activation': tf.keras.activations.serialize(self.activation),
                'use_bias': self.use_bias,
                'kernel_initializer': tf.keras.initializers.serialize(self.kernel_initializer),
                'bias_initializer': tf.keras.initializers.serialize(self.bias_initializer),
                'mask_initializer': tf.keras.initializers.serialize(self.mask_initializer)
            }
        )
        return config

    def set_masks(self, new_masks):
        """
        Replace the layer's masks and zero out the masked weights

        Parameters
        ----------
        new_masks : list of arrays
            ``[weight_mask]`` when the layer has no bias, otherwise
            ``[weight_mask, bias_mask]``. Each entry must support
            ``.astype`` and match the shape of the weight it masks
        """
        weight_mask = new_masks[0].astype(np.float32)
        masked_weights = self.w.numpy() * weight_mask

        # NOTE: the list order must match what set_weights expects; the
        # ordering used here (trainable values first, then masks) is kept
        # exactly as in the original implementation.
        if not self.use_bias:
            self.set_weights([masked_weights, weight_mask])
        else:
            bias_mask = new_masks[1].astype(np.float32)
            masked_bias = self.b.numpy() * bias_mask
            self.set_weights(
                [masked_weights, masked_bias, weight_mask, bias_mask]
            )

    @classmethod
    def from_config(cls, config):
        """
        Recreate a layer from the dictionary produced by ``get_config``

        All entries are forwarded to the constructor so that base-layer
        settings stored in the config (for example the layer name) are
        restored along with the convolution settings
        """
        return cls(**config)