import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
class FilterLayer(Layer):
    """
    Layer which filters inputs based on status of `on` or `off`

    When the layer is `on`, inputs are returned unchanged. When it is `off`,
    an all-zero tensor with the same shape and dtype as the inputs is
    returned instead.

    Parameters
    ----------
    is_on : bool, default True
        Initial state of the layer
    **kwargs
        Additional keyword arguments forwarded to the base ``Layer``
        constructor

    Example:

    >>> # Create a model with just a FilterLayer
    >>> input_layer = tf.keras.layers.Input(10)
    >>> filter_layer = mann.layers.FilterLayer()(input_layer)
    >>> model = tf.keras.models.Model(input_layer, filter_layer)
    >>> model.compile()
    >>> # Call the model with the layer turned on
    >>> data = np.arange(10).reshape((1, 10))
    >>> model.predict(data)
    array([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]], dtype=float32)
    >>> # Turn off the FilterLayer and call it again
    >>> model.layers[-1].turn_off()
    >>> # Model must be recompiled after turning the layer on or off
    >>> model.compile()
    >>> model.predict(data)
    array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
    """

    def __init__(
            self,
            is_on=True,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.is_on = is_on

    def call(self, inputs):
        """
        This is where the layer's logic lives and is called upon inputs

        Parameters
        ----------
        inputs : TensorFlow Tensor or Tensor-like
            The inputs to the layer

        Returns
        -------
        outputs : TensorFlow Tensor
            The inputs themselves when the layer is on, otherwise a tensor of
            zeros produced by ``tf.zeros_like(inputs)``
        """
        # NOTE(review): this is a plain Python branch on an attribute, so the
        # choice is made whenever `call` itself executes. Presumably that is
        # why the class example recompiles the model after toggling -- confirm
        # against the Keras version in use.
        if self.is_on:
            return inputs
        return tf.zeros_like(inputs)

    def get_config(self):
        """
        Return the layer configuration, including its on/off state

        Returns
        -------
        config : dict
            A copy of the base layer configuration with an added ``'is_on'``
            entry
        """
        config = super().get_config().copy()
        config.update({'is_on': self.is_on})
        return config

    def turn_on(self):
        """Turn the layer `on` so inputs are returned unchanged as outputs"""
        self.is_on = True

    def turn_off(self):
        """Turn the layer `off` so inputs are destroyed and all-zero tensors are output"""
        self.is_on = False

    @classmethod
    def from_config(cls, config):
        """
        Create a layer from a configuration dictionary

        Parameters
        ----------
        config : dict
            Configuration whose entries are passed to the constructor as
            keyword arguments, such as the output of ``get_config``

        Returns
        -------
        layer : FilterLayer
            A new instance built with ``cls(**config)``
        """
        return cls(**config)