import numpy as np
import torch
class MaskedConv3D(torch.nn.Module):
    """
    Masked 3D Convolutional layer

    A 3D convolution whose weights and biases are multiplied elementwise by
    binary masks before being applied.  The masks start as all ones (nothing
    masked) and are updated by :meth:`prune`.

    Attributes
    ----------
    w : torch.nn.Parameter
        Convolution filters, shape ``(out_channels, in_channels, k0, k1, k2)``
    b : torch.nn.Parameter
        Bias, shape ``(out_channels,)``
    w_mask, b_mask : torch.Tensor
        Buffers with the same shapes as ``w`` and ``b``; entries equal to
        zero deactivate the corresponding weight or bias
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            padding='same',
            strides=1,
            device=None,
            dtype=None
    ):
        """
        Parameters
        ----------
        in_channels : int
            The number of channels for input data
        out_channels : int
            The number of filters to use
        kernel_size : int or tuple (default 3)
            The kernel size to use.  An int is expanded to a 3-tuple
        padding : int or str (default 'same')
            Padding to use; forwarded unchanged to
            ``torch.nn.functional.conv3d``
        strides : int or tuple (default 1)
            The number of strides to use; forwarded unchanged to
            ``torch.nn.functional.conv3d``
        device : torch.device or str, optional
            Device on which parameters and masks are created
        dtype : torch.dtype, optional
            Data type of parameters and masks

        Raises
        ------
        TypeError
            If ``in_channels`` or ``out_channels`` is not an int, or
            ``kernel_size`` is neither an int nor a tuple
        ValueError
            If ``kernel_size`` is a tuple that is not three integers
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        # These assignments run the validating property setters below.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.strides = strides

        # Allocate directly with the requested device/dtype, then initialize.
        filters = torch.empty(
            (self.out_channels, self.in_channels, *self.kernel_size),
            **factory_kwargs
        )
        filters = torch.nn.init.kaiming_normal_(filters, a=np.sqrt(5))
        self.w = torch.nn.Parameter(filters)
        self.register_buffer('w_mask', torch.ones_like(self.w))

        # The bias must share the weight's device/dtype, otherwise conv3d
        # rejects the call whenever a non-default dtype or device is used.
        bias = torch.zeros(self.out_channels, **factory_kwargs)
        self.b = torch.nn.Parameter(bias)
        self.register_buffer('b_mask', torch.ones_like(self.b))

    @property
    def in_channels(self):
        """int : number of input channels."""
        return self._in_channels

    @in_channels.setter
    def in_channels(self, value):
        if not isinstance(value, int):
            raise TypeError('in_channels must be int')
        self._in_channels = value

    @property
    def out_channels(self):
        """int : number of output channels (filters)."""
        return self._out_channels

    @out_channels.setter
    def out_channels(self, value):
        if not isinstance(value, int):
            raise TypeError('out_channels must be int')
        self._out_channels = value

    @property
    def kernel_size(self):
        """tuple of int : kernel size, always stored as a 3-tuple."""
        return self._kernel_size

    @kernel_size.setter
    def kernel_size(self, value):
        if isinstance(value, int):
            value = (value, value, value)
        elif isinstance(value, tuple):
            # Reject when EITHER the length is wrong OR an element is not an
            # int; checking only one of the two lets malformed tuples through.
            if len(value) != 3 or not all(isinstance(val, int) for val in value):
                raise ValueError(
                    'If tuple, kernel_size must be three integers')
        else:
            raise TypeError('kernel_size must be int or tuple')
        self._kernel_size = value

    def forward(self, inputs):
        """
        Call the layer on input data

        Parameters
        ----------
        inputs : torch.Tensor
            Inputs to call the layer's logic on

        Returns
        -------
        results : torch.Tensor
            The results of the layer's logic
        """
        # Masks are applied on every call, so masked entries contribute
        # nothing to the output and receive zero gradient.
        return torch.nn.functional.conv3d(
            inputs,
            self.w * self.w_mask,
            self.b * self.b_mask,
            stride=self.strides,
            padding=self.padding
        )

    def prune(self, percentile):
        """
        Prune the layer by updating the layer's masks

        Entries of ``w`` (and, separately, ``b``) whose absolute value falls
        below the given percentile of absolute values are masked out and set
        to zero.

        Parameters
        ----------
        percentile : int
            Integer between 0 and 99 which represents the proportion of
            weights to be inactive

        Notes
        -----
        Acts on the layer in place.  The existing ``w`` and ``b`` parameter
        objects are modified rather than replaced, so optimizers created
        before pruning keep referring to the live parameters.
        """
        w_magnitude = np.abs(self.w.detach().cpu().numpy())
        b_magnitude = np.abs(self.b.detach().cpu().numpy())
        w_keep = w_magnitude >= np.percentile(w_magnitude, percentile)
        b_keep = b_magnitude >= np.percentile(b_magnitude, percentile)

        with torch.no_grad():
            # copy_ converts the boolean CPU mask to the buffer's own
            # dtype and device.
            self.w_mask.copy_(torch.from_numpy(w_keep))
            self.b_mask.copy_(torch.from_numpy(b_keep))
            # Zero the pruned entries in place to preserve parameter identity.
            self.w.mul_(self.w_mask)
            self.b.mul_(self.b_mask)