Source code for beyondml.tflow.layers.MultiMaskedConv3D

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer


class MultiMaskedConv3D(Layer):
    """
    Masked multitask 3-dimensional convolutional layer.

    This layer implements multiple stacks of the convolutional architecture
    and implements masking consistent with the BeyondML API to support
    developing sparse multitask models. One kernel (and optionally one bias)
    is kept per input, and each is multiplied elementwise by a non-trainable
    mask of the same shape before it is used.
    """

    def __init__(
        self,
        filters,
        kernel_size=3,
        padding='same',
        strides=1,
        use_bias=True,
        activation=None,
        kernel_initializer='random_normal',
        bias_initializer='zeros',
        mask_initializer='ones',
        **kwargs
    ):
        """
        Parameters
        ----------
        filters : int
            The number of convolutional filters to apply
        kernel_size : int or tuple of ints (default 3)
            The kernel size along the three spatial dimensions. A single
            integer is expanded to the same size in all three dimensions
        padding : str (default 'same')
            Either 'same' or 'valid', the padding to use during convolution
        strides : int or tuple of ints (default 1)
            Stride lengths to use during convolution
        use_bias : bool (default True)
            Whether to use a bias calculation on the outputs
        activation : None, str, or function (default None)
            Activation function to use on the outputs
        kernel_initializer : str or keras initialization function (default 'random_normal')
            The initialization function to use for the weights
        bias_initializer : str or keras initialization function (default 'zeros')
            The initialization function to use for the bias
        mask_initializer : str or keras initialization function (default 'ones')
            The mask initialization function to use
        **kwargs
            Forwarded unchanged to the base ``Layer`` constructor
        """
        super().__init__(**kwargs)
        self.filters = filters if isinstance(filters, int) else int(filters)
        # Goes through the property setter below, which expands scalars.
        self.kernel_size = kernel_size
        self.padding = padding
        self.strides = tuple(strides) if isinstance(strides, list) else strides
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.mask_initializer = tf.keras.initializers.get(mask_initializer)

    @property
    def kernel_size(self):
        """The kernel size; a 3-tuple when it was given as a single integer."""
        return self._kernel_size

    @kernel_size.setter
    def kernel_size(self, value):
        # NumPy integer scalars are accepted as well as Python ints, so that a
        # size taken from an array is expanded instead of failing in build().
        if isinstance(value, (int, np.integer)):
            self._kernel_size = (int(value),) * 3
        else:
            self._kernel_size = value

    def build(self, input_shape):
        """
        Build the layer in preparation to be trained or called. Should not be
        called directly, but rather is called when the layer is added to a
        model.

        Parameters
        ----------
        input_shape : list of shapes
            One shape per input. The number of shapes sets the leading
            dimension of every weight, and the last entry of the first shape
            is used as the number of input channels for all inputs
        """
        try:
            input_shape = [
                tuple(shape.as_list()) for shape in input_shape
            ]
        except AttributeError:
            # Sometimes, input shapes come as tuples already
            pass

        num_inputs = len(input_shape)
        in_channels = input_shape[0][-1]

        self.w = self.add_weight(
            shape=(
                num_inputs,
                self.kernel_size[0],
                self.kernel_size[1],
                self.kernel_size[2],
                in_channels,
                self.filters,
            ),
            initializer=self.kernel_initializer,
            trainable=True,
            name='weights'
        )
        self.w_mask = self.add_weight(
            shape=self.w.shape,
            initializer=self.mask_initializer,
            trainable=False,
            name='weights_mask'
        )

        if self.use_bias:
            self.b = self.add_weight(
                shape=(num_inputs, self.filters),
                initializer=self.bias_initializer,
                trainable=True,
                name='bias'
            )
            self.b_mask = self.add_weight(
                shape=self.b.shape,
                initializer=self.mask_initializer,
                trainable=False,
                name='bias_mask'
            )

    def call(self, inputs):
        """
        This is where the layer's logic lives and is called upon inputs

        Parameters
        ----------
        inputs : list of TensorFlow Tensors or Tensor-likes
            The inputs to the layer, one per kernel stack, convolved with
            ``data_format='NDHWC'``

        Returns
        -------
        outputs : list of TensorFlow Tensors
            The activated output for each input, in the same order
        """
        outputs = []
        # Index-based on purpose: it only needs len() and indexing on
        # `inputs`, exactly as the layer has always required.
        for i in range(len(inputs)):
            result = tf.nn.convolution(
                inputs[i],
                self.w[i] * self.w_mask[i],
                padding=self.padding.upper(),
                strides=self.strides,
                data_format='NDHWC'
            )
            if self.use_bias:
                result = result + (self.b[i] * self.b_mask[i])
            outputs.append(self.activation(result))
        return outputs

    def get_config(self):
        """
        Return the configuration of the layer.

        Returns
        -------
        config : dict
            The base layer configuration extended with this layer's
            constructor arguments, with the activation and initializers in
            their serialized form
        """
        config = super().get_config().copy()
        config.update(
            {
                'filters': self.filters,
                'kernel_size': list(self.kernel_size),
                'padding': self.padding,
                'strides': self.strides,
                'activation': tf.keras.activations.serialize(self.activation),
                'use_bias': self.use_bias,
                'kernel_initializer': tf.keras.initializers.serialize(self.kernel_initializer),
                'bias_initializer': tf.keras.initializers.serialize(self.bias_initializer),
                'mask_initializer': tf.keras.initializers.serialize(self.mask_initializer)
            }
        )
        return config

    def set_masks(self, new_masks):
        """
        Set the masks for the layer and zero out the masked weights.

        Parameters
        ----------
        new_masks : list of array-likes
            The weight mask first, followed by the bias mask when the layer
            uses a bias. Each mask is cast to float32 and must match the
            shape of the weight it masks
        """
        weight_mask = np.asarray(new_masks[0], dtype=np.float32)
        if not self.use_bias:
            self.set_weights([self.w.numpy() * weight_mask, weight_mask])
            return

        bias_mask = np.asarray(new_masks[1], dtype=np.float32)
        # NOTE(review): this order assumes Keras lists trainable weights
        # before non-trainable ones in set_weights - confirm on upgrade.
        self.set_weights(
            [
                self.w.numpy() * weight_mask,
                self.b.numpy() * bias_mask,
                weight_mask,
                bias_mask,
            ]
        )

    @classmethod
    def from_config(cls, config):
        """
        Create a layer from a configuration dictionary.

        Parameters
        ----------
        config : dict
            A configuration as produced by ``get_config``

        Returns
        -------
        layer : MultiMaskedConv3D
            A new, unbuilt layer
        """
        # Delegate to the base implementation so that every entry written by
        # get_config (including the base layer's own, such as the name) is
        # passed back to the constructor rather than a hand-picked subset.
        return super().from_config(config)