import numpy as np
import torch
[docs]class MaskedDense(torch.nn.Module):
"""
Masked fully-connected layer
"""
def __init__(
self,
in_features,
out_features,
device=None,
dtype=None
):
"""
Parameters
----------
in_features : int
The number of features input to the layer
out_features : int
The number of features to be output by the layer.
Also considered the number of artificial neurons
"""
super().__init__()
factory_kwargs = {'device': device, 'dtype': dtype}
self.in_features = in_features
self.out_features = out_features
weight = torch.Tensor(
in_features,
out_features,
).to(**factory_kwargs)
weight = torch.nn.init.kaiming_normal_(weight, a=np.sqrt(5))
self.w = torch.nn.Parameter(weight)
self.register_buffer(
'w_mask', torch.ones_like(self.w, **factory_kwargs))
bias = torch.zeros(out_features, **factory_kwargs)
self.b = torch.nn.Parameter(bias)
self.register_buffer('b_mask', torch.ones_like(bias, **factory_kwargs))
[docs] def forward(self, inputs):
"""
Call the layer on input data
Parameters
----------
inputs : torch.Tensor
Inputs to call the layer's logic on
Returns
-------
results : torch.Tensor
The results of the layer's logic
"""
weight = self.w * self.w_mask
bias = self.b * self.b_mask
out = torch.matmul(inputs, weight)
out = torch.add(out, bias)
return out
[docs] def prune(self, percentile):
"""
Prune the layer by updating the layer's mask
Parameters
----------
percentile : int
Integer between 0 and 99 which represents the proportion of weights to be inactive
Notes
-----
Acts on the layer in place
"""
w_copy = np.abs(self.w.detach().cpu().numpy())
b_copy = np.abs(self.b.detach().cpu().numpy())
w_percentile = np.percentile(w_copy, percentile)
b_percentile = np.percentile(b_copy, percentile)
new_w_mask = torch.Tensor(
(w_copy >= w_percentile).astype(int))
new_b_mask = torch.Tensor(
(b_copy >= b_percentile).astype(int))
self.w_mask[:] = new_w_mask
self.b_mask[:] = new_b_mask
self.w = torch.nn.Parameter(
self.w.detach() * self.w_mask
)
self.b = torch.nn.Parameter(
self.b.detach() * self.b_mask
)