from torch.nn import Module
[docs]class SelectorLayer(Module):
    """
    Layer which selects an individual input based on index and only returns that one
    """
    def __init__(
        self,
        sel_index
    ):
        """
        Parameters
        ----------
        sel_index : int
            The index of inputs to select
        """
        super().__init__()
        self.sel_index = sel_index
    @property
    def sel_index(self):
        return self._sel_index
    @sel_index.setter
    def sel_index(self, value):
        if not isinstance(value, int):
            raise TypeError('sel_index must be integer-valued')
        self._sel_index = value
[docs]    def forward(self, inputs):
        """
        Call the layer on input data
        Parameters
        ----------
        inputs : torch.Tensor
            Inputs to call the layer's logic on
        Returns
        -------
        results : torch.Tensor
            The results of the layer's logic
        """
        return inputs[self.sel_index]