# Source code for beyondml.pt.layers.MaskedTransformerEncoderLayer
import torch
from typing import Optional, Any, Union, Callable
import torch
from torch import Tensor
from torch.nn import Dropout, LayerNorm
from beyondml.pt.layers import MaskedDense
from torch.nn import functional as F
from .MaskedMultiHeadAttention import MaskedMultiHeadAttention
class MaskedTransformerEncoderLayer(torch.nn.Module):
    """Transformer encoder layer made up of self-attention and a feedforward
    network, built from maskable sublayers so the whole layer can be pruned.

    Structure follows ``torch.nn.TransformerEncoderLayer``: residual
    connections around both sublayers, with LayerNorm applied before
    (``norm_first=True``) or after (``norm_first=False``) each sublayer.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
    """

    __constants__ = ['batch_first', 'norm_first']

    def __init__(self,
                 d_model: int,
                 nhead: int,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor],
                                                 Tensor]] = torch.nn.functional.relu,
                 layer_norm_eps: float = 1e-5,
                 batch_first: bool = False,
                 norm_first: bool = False,
                 device=None,
                 dtype=None
                 ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MaskedMultiHeadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            **factory_kwargs
        )
        # Feedforward model: d_model -> dim_feedforward -> d_model.
        self.linear1 = MaskedDense(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        # BUG FIX: the second projection must map dim_feedforward back to
        # d_model; the original had the two arguments swapped, which broke
        # _ff_block with a shape mismatch.
        self.linear2 = MaskedDense(dim_feedforward, d_model, **factory_kwargs)
        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        # BUG FIX: the original accepted `activation` but never stored it, so
        # _ff_block raised AttributeError on any freshly constructed layer.
        if isinstance(activation, str):
            if activation == 'relu':
                self.activation = F.relu
            elif activation == 'gelu':
                self.activation = F.gelu
            else:
                raise ValueError(
                    f'activation should be "relu" or "gelu", got {activation}')
        else:
            self.activation = activation

    def __setstate__(self, state):
        # Restore a default activation for states pickled before the
        # attribute existed.
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(self,
                src: Tensor,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: optional attention mask for the self-attention block.
            src_key_padding_mask: optional padding mask for ``src``.
        """
        x = src
        # BUG FIX: the original called _sa_block(x) without the mask arguments
        # it required (TypeError on every call), and never applied the
        # residual / LayerNorm connections even though norm1, norm2 and
        # norm_first were constructed. This mirrors
        # torch.nn.TransformerEncoderLayer's pre-/post-norm structure.
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor] = None,
                  key_padding_mask: Optional[Tensor] = None) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def prune(self, percentile):
        """Prune the attention and both feedforward sublayers at `percentile`."""
        self.self_attn.prune(percentile)
        self.linear1.prune(percentile)
        self.linear2.prune(percentile)