Source code for beyondml.pt.layers.MultiMaskedDense

import numpy as np
import torch


class MultiMaskedDense(torch.nn.Module):
    """
    Multi-Fully-Connected layer which supports masking and pruning

    The layer holds one weight matrix and one bias vector per task. On every
    forward pass each task's parameters are multiplied elementwise by that
    task's slice of the ``w_mask`` / ``b_mask`` buffers; ``prune`` recomputes
    those masks from parameter magnitudes.
    """

    def __init__(
        self,
        in_features,
        out_features,
        num_tasks,
        device=None,
        dtype=None
    ):
        """
        Parameters
        ----------
        in_features : int
            The number of input features
        out_features : int
            The number of output features. Also known as the number of
            artificial neurons
        num_tasks : int
            The number of tasks to initialize for
        device : torch.device, str or None (default None)
            Device on which parameters and mask buffers are allocated
        dtype : torch.dtype or None (default None)
            Dtype of parameters and mask buffers
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.num_tasks = num_tasks

        # Stored as (num_tasks, in_features, out_features) so that
        # ``inputs[i] @ w[i]`` maps in_features -> out_features for task i.
        # Uninitialized storage is fine: kaiming_normal_ overwrites it.
        weight = torch.empty(
            num_tasks, in_features, out_features, **factory_kwargs
        )
        # NOTE(review): for a 3-D tensor, kaiming_normal_ computes fan_in as
        # size(1) * size(2) = in_features * out_features, not in_features
        # alone -- confirm this initialization scale is intended.
        weight = torch.nn.init.kaiming_normal_(weight, a=np.sqrt(5))
        self.w = torch.nn.Parameter(weight)
        # Masks are buffers, not parameters: they follow .to() and live in
        # state_dict but are never updated by an optimizer.
        # All ones means nothing is pruned yet.
        self.register_buffer(
            'w_mask', torch.ones_like(self.w, **factory_kwargs))

        bias = torch.zeros(num_tasks, out_features, **factory_kwargs)
        self.b = torch.nn.Parameter(bias)
        self.register_buffer('b_mask', torch.ones_like(bias, **factory_kwargs))

    def forward(self, inputs):
        """
        Call the layer on input data

        Parameters
        ----------
        inputs : sequence of torch.Tensor
            One 2-D tensor per task, each of shape (batch, in_features).
            ``inputs[i]`` is processed with task ``i``'s masked weights and
            bias, so more than ``num_tasks`` inputs raises IndexError

        Returns
        -------
        results : list of torch.Tensor
            One tensor of shape (batch, out_features) per input, in order
        """
        outputs = []
        for task_idx, task_input in enumerate(inputs):
            masked_w = self.w[task_idx] * self.w_mask[task_idx]
            masked_b = self.b[task_idx] * self.b_mask[task_idx]
            outputs.append(
                torch.add(torch.mm(task_input, masked_w), masked_b)
            )
        return outputs

    def prune(self, percentile):
        """
        Prune the layer by updating the layer's mask

        Tasks are processed in order. For each task, positions already kept
        by an earlier task are treated as zero-magnitude. The task then keeps
        every position whose absolute value is at or above the given
        percentile of that task's (adjusted) magnitudes. Weights and biases
        are thresholded independently.

        Parameters
        ----------
        percentile : int
            Integer between 0 and 99 which represents the proportion of
            weights to be inactive

        Notes
        -----
        Acts on the layer in place. ``self.w`` and ``self.b`` keep their
        identity: they are multiplied by the new masks in place rather than
        replaced. An optimizer constructed before pruning therefore keeps
        updating the module's actual parameters.
        """
        new_w_mask = self._percentile_masks(self.w, percentile)
        new_b_mask = self._percentile_masks(self.b, percentile)
        with torch.no_grad():
            # copy_ converts to the buffers' existing dtype and device.
            self.w_mask.copy_(torch.as_tensor(new_w_mask))
            self.b_mask.copy_(torch.as_tensor(new_b_mask))
            # In-place, so Parameter objects (and optimizer refs) survive.
            self.w.mul_(self.w_mask)
            self.b.mul_(self.b_mask)

    def _percentile_masks(self, param, percentile):
        """Return a NumPy 0/1 array shaped like ``param`` per ``prune``'s rule."""
        magnitudes = np.abs(param.detach().cpu().numpy())
        masks = np.zeros_like(magnitudes)
        # Positions kept by any earlier task (empty for task 0).
        claimed = np.zeros(magnitudes.shape[1:], dtype=bool)
        for task_num in range(self.num_tasks):
            magnitudes[task_num][claimed] = 0
            threshold = np.percentile(magnitudes[task_num], percentile)
            # NOTE(review): if threshold is 0 (claimed positions exceed
            # `percentile`% of entries, or all values are zero as with
            # freshly initialized biases), `>=` re-selects claimed positions
            # and masks overlap across tasks. Kept as-is; confirm intent.
            masks[task_num] = magnitudes[task_num] >= threshold
            claimed |= masks[task_num] == 1
        return masks