import tensorflow as tf
from tensorflow.keras.layers import Layer
[docs]class MultiMaxPool3D(Layer):
    """
    Multitask 3D Max Pooling Layer. This layer implements the Max Pooling
    algorithm across multiple inputs for developing multitask models
    """
    def __init__(
        self,
        pool_size=(3, 3, 3),
        strides=(1, 1, 1),
        padding='same',
        **kwargs
    ):
        """
        Parameters
        ----------
        pool_size : integer or tuple of 3 integers (default (3, 3, 3))
            Window size over which to take the maximum
        strides : integer or tuple of 3 integers (default (1, 1, 1))
            Stride values to move the pooling window after each step
        padding : str (default 'same')
            One of either 'same' or 'valid', case-insensitive. The
            padding to apply to the inputs
        """
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.strides = strides
        self.padding = padding
[docs]    def call(self, inputs):
        """
        This is where the layer's logic lives and is called upon inputs
        Parameters
        ----------
        inputs : TensorFlow Tensor or Tensor-like
            The inputs to the layer
        Returns
        -------
        outputs : TensorFlow Tensor
            The outputs of the layer's logic
        """
        return [
            tf.nn.max_pool3d(
                input=inputs[i],
                ksize=self.pool_size,
                strides=self.strides,
                padding=self.padding.upper(),
                data_format='NDHWC'
            ) for i in range(len(inputs))
        ] 
[docs]    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'pool_size': self.pool_size,
                'strides': self.strides,
                'padding': self.padding
            }
        )
        return config 
[docs]    @classmethod
    def from_config(cls, config):
        return cls(
            pool_size=config['pool_size'],
            strides=config['strides'],
            padding=config['padding']
        )