Source code for beyondml.pt.layers.FilterLayer

import torch


class FilterLayer(torch.nn.Module):
    """
    Layer which filters input data, either returning the inputs unchanged
    or a tensor of zeros of the same shape, depending on its on/off state
    """

    def __init__(
        self,
        is_on=True,
        device=None,
        dtype=None
    ):
        """
        Parameters
        ----------
        is_on : bool (default True)
            Whether the layer is on or off
        device : torch.device or None (default None)
            Device passed to ``torch.zeros_like`` when the layer is off
        dtype : torch.dtype or None (default None)
            Dtype passed to ``torch.zeros_like`` when the layer is off

        Raises
        ------
        TypeError
            If ``is_on`` is not a bool
        """
        super().__init__()
        # Assigning through the property runs the bool validation below.
        self.is_on = is_on
        # Only consulted in the "off" branch of ``forward``; the "on" branch
        # returns the inputs untouched regardless of these values.
        self.factory_kwargs = {'device': device, 'dtype': dtype}

    @property
    def is_on(self):
        """bool: Whether the layer currently passes its inputs through"""
        return self._is_on

    @is_on.setter
    def is_on(self, value):
        # Strict bool check: truthy non-bool values such as 1 are rejected.
        if not isinstance(value, bool):
            raise TypeError('is_on must be Boolean')
        self._is_on = value

    def forward(self, inputs):
        """
        Call the layer on input data

        Parameters
        ----------
        inputs : torch.Tensor
            Inputs to call the layer's logic on

        Returns
        -------
        results : torch.Tensor
            ``inputs`` itself (the same object, not a copy) when the layer
            is on; otherwise a tensor of zeros shaped like ``inputs``,
            created with the configured ``device`` and ``dtype``
        """
        if self.is_on:
            return inputs
        # ``None`` entries in factory_kwargs let torch.zeros_like fall back
        # to the device/dtype of ``inputs``.
        return torch.zeros_like(inputs, **self.factory_kwargs)

    def turn_on(self):
        """
        Turn on the layer
        """
        self.is_on = True

    def turn_off(self):
        """
        Turn off the layer
        """
        self.is_on = False