import torch


class SparseMultiDense(torch.nn.Module):
    """
    Sparse implementation of the Multi-Fully-Connected layer
    """
    def __init__(
            self,
            weight,
            bias,
            device=None,
            dtype=None
    ):
        """
        Parameters
        ----------
        weight : torch.Tensor or Tensor-like
            The weight to use
        bias : torch.Tensor or Tensor-like
            The bias to use
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        for i in range(weight.shape[0]):
            # Register each sub-layer's weight and bias as sparse (COO)
            # buffers so they move with the module but are not trained.
            self.register_buffer(
                f'w_{i}',
                torch.as_tensor(weight[i], **factory_kwargs).to_sparse()
            )
            self.register_buffer(
                f'b_{i}',
                torch.as_tensor(bias[i], **factory_kwargs).to_sparse()
            )
    def forward(self, inputs):
        """
        Call the layer on input data
        Parameters
        ----------
        inputs : torch.Tensor
            Inputs to call the layer's logic on
        Returns
        -------
        results : torch.Tensor
            The results of the layer's logic
        """
        outputs = []
        for i in range(len(inputs)):
            # torch.sparse.mm expects the sparse operand first, so compute
            # (W_i^T @ x_i^T)^T, which equals x_i @ W_i.
            out = torch.sparse.mm(
                self.get_buffer(f'w_{i}').t(),
                inputs[i].t()
            ).t()
            # Densify the bias so it broadcasts over the batch dimension.
            out = torch.add(
                out,
                self.get_buffer(f'b_{i}').to_dense()
            )
            outputs.append(out)
        return outputs
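

# A minimal usage sketch, not part of the original module: it assumes two
# sub-layers mapping 4 features to 3 and a batch of 5 samples per sub-layer;
# the names and shapes below are illustrative assumptions only.
if __name__ == '__main__':
    num_layers, in_features, out_features, batch = 2, 4, 3, 5
    weight = torch.randn(num_layers, in_features, out_features)
    bias = torch.randn(num_layers, out_features)
    layer = SparseMultiDense(weight, bias)
    inputs = [torch.randn(batch, in_features) for _ in range(num_layers)]
    outputs = layer(inputs)
    for out in outputs:
        print(out.shape)  # torch.Size([5, 3]) for each sub-layer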