Source code for beyondml.pt.layers.MultitaskNormalization

import torch


class MultitaskNormalization(torch.nn.Module):
    """
    Layer which normalizes a set of inputs to sum to 1

    Each output is the corresponding input divided by the element-wise sum
    of all inputs, so that the outputs add up (element-wise) to 1.

    Parameters
    ----------
    device : torch.device or None (default None)
        Stored in ``self.factory_kwargs``; not used by ``forward``
    dtype : torch.dtype or None (default None)
        Stored in ``self.factory_kwargs``; not used by ``forward``
    """

    def __init__(
        self,
        device=None,
        dtype=None
    ):
        super().__init__()
        # This layer creates no tensors of its own; the kwargs are only stored.
        self.factory_kwargs = {'device': device, 'dtype': dtype}

    def forward(self, inputs):
        """
        Call the layer on input data

        Parameters
        ----------
        inputs : torch.Tensor or iterable of Tensors
            Inputs to normalize. Tensors must be broadcast-compatible with
            one another. A single tensor is iterated along its first
            dimension, so its slices are normalized against each other.

        Returns
        -------
        results : list of Tensors
            ``[i / s for i in inputs]``, where ``s`` is the element-wise sum
            of all inputs. An empty input yields an empty list. Positions
            where ``s`` is zero follow torch division semantics (inf/nan).
        """
        # Materialize once: the inputs are traversed twice (sum, then divide),
        # which would silently break for a one-shot iterable such as a generator.
        inputs = list(inputs)

        # Accumulate the element-wise sum of the inputs (not their count) so
        # that the normalized outputs sum to 1. Out-of-place addition lets the
        # running total grow to the full broadcast shape; an in-place `+=`
        # would fail when a later input is larger than the first.
        total = 0
        for tensor in inputs:
            total = total + tensor

        return [tensor / total for tensor in inputs]