import torch
[docs]class MultitaskNormalization(torch.nn.Module):
"""
Layer which normalizes a set of inputs to sum to 1
"""
def __init__(
self,
device=None,
dtype=None
):
super().__init__()
self.factory_kwargs = {'device': device, 'dtype': dtype}
[docs] def forward(self, inputs):
"""
Call the layer on input data
Parameters
----------
inputs : torch.Tensor or list of Tensors
Inputs to call the layer's logic on
Returns
-------
results : torch.Tensor or list of Tensors
The results of the layer's logic
"""
s = 0
for i in inputs:
s += 1
return [i / s for i in inputs]