import torch
[docs]class FilterLayer(torch.nn.Module):
    """
    Layer which filters input data, either returning values or all zeros depending on state
    """
    def __init__(
        self,
        is_on=True,
        device=None,
        dtype=None
    ):
        """
        Parameters
        ----------
        is_on : bool (default False)
            Whether the layer is on or off
        """
        super().__init__()
        self.is_on = is_on
        self.factory_kwargs = {'device': device, 'dtype': dtype}
    @property
    def is_on(self):
        return self._is_on
    @is_on.setter
    def is_on(self, value):
        if not isinstance(value, bool):
            raise TypeError('is_on must be Boolean')
        self._is_on = value
[docs]    def forward(self, inputs):
        """
        Call the layer on input data
        Parameters
        ----------
        inputs : torch.Tensor
            Inputs to call the layer's logic on
        Returns
        -------
        results : torch.Tensor
            The results of the layer's logic
        """
        if self.is_on:
            return inputs
        else:
            return torch.zeros_like(inputs, **self.factory_kwargs) 
[docs]    def turn_on(self):
        """
        Turn on the layer
        """
        self.is_on = True 
[docs]    def turn_off(self):
        """
        Turn off the layer
        """
        self.is_on = False