# Source code for beyondml.tflow.layers.FilterLayer

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer


class FilterLayer(Layer):
    """
    Layer which filters inputs based on status of `on` or `off`

    When the layer is `on`, inputs are returned unchanged. When it is `off`,
    an all-zero tensor with the same shape and dtype as the inputs is
    returned instead.

    Parameters
    ----------
    is_on : bool (default True)
        Whether the layer starts in the `on` state
    **kwargs : keyword arguments
        Passed through unchanged to ``tf.keras.layers.Layer``

    Notes
    -----
    ``call`` branches on the plain Python attribute ``is_on`` rather than on a
    tensor, so a graph that has already been traced keeps whichever branch was
    active when it was traced. After calling ``turn_on`` or ``turn_off``, a
    compiled model must be recompiled for the change to take effect.

    Example:
        >>> # Create a model with just a FilterLayer
        >>> input_layer = tf.keras.layers.Input((10,))
        >>> filter_layer = FilterLayer()(input_layer)
        >>> model = tf.keras.models.Model(input_layer, filter_layer)
        >>> model.compile()
        >>> # Call the model with the layer turned on
        >>> data = np.arange(10).reshape((1, 10))
        >>> model.predict(data)
        array([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]], dtype=float32)
        >>> # Turn off the FilterLayer and call it again
        >>> model.layers[-1].turn_off()
        >>> # Model must be recompiled after turning the layer on or off
        >>> model.compile()
        >>> model.predict(data)
        array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
    """

    def __init__(self, is_on=True, **kwargs):
        super().__init__(**kwargs)
        self.is_on = is_on

    def call(self, inputs):
        """
        This is where the layer's logic lives and is called upon inputs

        Parameters
        ----------
        inputs : TensorFlow Tensor or Tensor-like
            The inputs to the layer

        Returns
        -------
        outputs : TensorFlow Tensor
            ``inputs`` unchanged when the layer is on, otherwise a zero tensor
            shaped and typed like ``inputs``
        """
        # Python-level branch: evaluated at trace time, hence the
        # recompile requirement documented on the class.
        if self.is_on:
            return inputs
        return tf.zeros_like(inputs)

    def get_config(self):
        """
        Return the layer configuration, including the current `is_on` state,
        so the layer can be rebuilt with ``from_config``
        """
        config = super().get_config().copy()
        config.update({'is_on': self.is_on})
        return config

    def turn_on(self):
        """Turn the layer `on` so inputs are returned unchanged as outputs"""
        self.is_on = True

    def turn_off(self):
        """Turn the layer `off` so inputs are destroyed and all-zero tensors are output"""
        self.is_on = False

    @classmethod
    def from_config(cls, config):
        """Create a new layer from a config dict produced by ``get_config``"""
        return cls(**config)