from tensorflow.keras.layers import Layer
class SelectorLayer(Layer):
    """
    Layer which selects one of its inputs and returns it unchanged.

    Example:

    >>> # Create a model with two inputs and one SelectorLayer
    >>> input_1 = tf.keras.layers.Input(10)
    >>> input_2 = tf.keras.layers.Input(10)
    >>> selector = mann.layers.SelectorLayer(1)([input_1, input_2]) # 1 here indicates to select the second input and return it
    >>> model = tf.keras.models.Model([input_1, input_2], selector)
    >>> model.compile()
    >>> # Call the model
    >>> data1 = np.arange(10).reshape((1, 10))
    >>> data2 = 2*np.arange(10).reshape((1, 10))
    >>> model.predict([data1, data2])
    array([[ 0., 2., 4., 6., 8., 10., 12., 14., 16., 18.]], dtype=float32)
    """

    def __init__(
        self,
        sel_index,
        **kwargs
    ):
        """
        Parameters
        ----------
        sel_index : int
            The index of the input to be selected
        **kwargs
            Additional keyword arguments passed unchanged to the base
            ``Layer`` constructor
        """
        super(SelectorLayer, self).__init__(**kwargs)
        # Goes through the validating property setter below.
        self.sel_index = sel_index

    @property
    def sel_index(self):
        """int: Index of the input returned by ``call``."""
        return self._sel_index

    @sel_index.setter
    def sel_index(self, value):
        # ``bool`` is a subclass of ``int``; reject it explicitly so that
        # e.g. ``True`` is not silently treated as index 1.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(
                f'sel_index must be int, got {value}, type {type(value)}')
        self._sel_index = value

    def call(self, inputs):
        """
        This is where the layer's logic lives and is called upon inputs

        Parameters
        ----------
        inputs : sequence of TensorFlow Tensors
            The inputs to the layer; indexed directly with ``sel_index``

        Returns
        -------
        outputs : TensorFlow Tensor
            The element of ``inputs`` at position ``sel_index``
        """
        return inputs[self.sel_index]

    def get_config(self):
        """
        Return the layer configuration: the base ``Layer`` config plus
        ``sel_index``.
        """
        config = super().get_config().copy()
        config.update(
            {
                'sel_index': self.sel_index
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """
        Recreate a layer from the output of ``get_config``.

        Delegates to the base ``Layer.from_config`` so that every serialized
        key (including the base-layer settings from ``super().get_config()``),
        not just ``sel_index``, is passed back to the constructor.
        """
        return super().from_config(config)