Source code for beyondml.pt.layers.MaskedMultiHeadAttention

from .MaskedDense import MaskedDense
import numpy as np
import torch


class MaskedMultiHeadAttention(torch.nn.Module):
    """
    Masked Multi-Headed Attention Layer
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0,
        batch_first=False,
        device=None,
        dtype=None
    ):
        """
        Parameters
        ----------
        embed_dim : int
            The embedding dimension
        num_heads : int
            The number of attention heads
        dropout : float (default 0)
            The dropout rate to apply
        batch_first : bool (default False)
            Whether the batch dimension is first
        device : torch device, str, or None (default None)
            The device to place the layer's parameters on
        dtype : torch dtype or None (default None)
            The dtype of the layer's parameters
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads

        if self.head_dim * self.num_heads != embed_dim:
            raise ValueError('num_heads must evenly divide embed_dim')

        # Combined query/key/value input projection and its mask
        in_proj_weight = torch.Tensor(
            3 * embed_dim, embed_dim).to(**factory_kwargs)
        in_proj_weight = torch.nn.init.xavier_uniform_(in_proj_weight)
        self.in_proj_weight = torch.nn.Parameter(in_proj_weight)
        self.register_buffer('in_proj_weight_mask', torch.ones_like(
            self.in_proj_weight, **factory_kwargs))

        self.in_proj_bias = torch.nn.Parameter(
            torch.zeros((3 * embed_dim), **factory_kwargs))
        self.register_buffer('in_proj_bias_mask', torch.ones_like(
            self.in_proj_bias, **factory_kwargs))

        # Output projection is a MaskedDense layer; expose its weights and masks
        self.out_proj = MaskedDense(
            embed_dim, embed_dim, **factory_kwargs)
        self.out_proj_weight = self.out_proj.w
        self.out_proj_weight_mask = self.out_proj.w_mask
        self.out_proj_bias = self.out_proj.b
        self.out_proj_bias_mask = self.out_proj.b_mask
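    # Illustrative sketch (not part of the original source): the masks created
    # above start as all ones, so the effective projection weights used in
    # ``forward`` initially equal the raw parameters.
    #
    #     layer = MaskedMultiHeadAttention(embed_dim=8, num_heads=2)
    #     effective_w = layer.in_proj_weight * layer.in_proj_weight_mask
    #     assert torch.equal(effective_w, layer.in_proj_weight)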
    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True
    ):
        """
        Call the layer on input data

        Parameters
        ----------
        query : torch Tensor
            Query tensor
        key : torch Tensor
            Key tensor
        value : torch Tensor
            Value tensor
        key_padding_mask : None or torch Tensor (default None)
            If specified, a mask indicating which elements in ``key`` to ignore
        need_weights : bool (default True)
            If True, returns ``attn_output_weights`` as well as ``attn_outputs``
        attn_mask : None or torch Tensor (default None)
            If specified, a 2D or 3D mask preventing attention to certain positions
        average_attn_weights : bool (default True)
            If True, indicates that returned ``attn_weights`` should be averaged
            across heads
        """
        is_batched = query.dim() == 3

        # Move the batch dimension to position 1, as expected by the
        # functional attention implementation
        if self.batch_first and is_batched:
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(
                    1, 0) for x in (query, key, value)]

        # Apply the masks to the projection weights before computing attention
        attn_output, attn_output_weights = torch.nn.functional.multi_head_attention_forward(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight * self.in_proj_weight_mask,
            in_proj_bias=self.in_proj_bias * self.in_proj_bias_mask,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=self.dropout,
            out_proj_weight=self.out_proj_weight * self.out_proj_weight_mask,
            out_proj_bias=self.out_proj_bias * self.out_proj_bias_mask,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights
        )

        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
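    # Usage sketch (illustrative, not part of the original source): with the
    # default ``batch_first=False``, inputs are shaped (seq_len, batch, embed_dim)
    # and self-attention is computed by passing the same tensor three times.
    #
    #     layer = MaskedMultiHeadAttention(embed_dim=16, num_heads=4)
    #     x = torch.randn(10, 2, 16)      # (seq_len, batch, embed_dim)
    #     out, attn = layer(x, x, x)      # out: (10, 2, 16); attn: (2, 10, 10)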
    def prune(self, percentile):
        """
        Prune the layer by updating the layer's mask

        Parameters
        ----------
        percentile : int
            Integer between 0 and 99 which represents the proportion of weights
            to be made inactive

        Notes
        -----
        Acts on the layer in place
        """
        w_copy = np.abs(self.in_proj_weight.detach().cpu().numpy())
        b_copy = np.abs(self.in_proj_bias.detach().cpu().numpy())
        w_percentile = np.percentile(w_copy, percentile)
        b_percentile = np.percentile(b_copy, percentile)

        # Keep only weights whose magnitude is at or above the cutoff
        new_w_mask = torch.Tensor(
            (w_copy >= w_percentile).astype(int))
        new_b_mask = torch.Tensor(
            (b_copy >= b_percentile).astype(int))

        self.in_proj_weight_mask[:] = new_w_mask
        self.in_proj_bias_mask[:] = new_b_mask

        # Zero out the pruned weights in the parameters themselves
        self.in_proj_weight = torch.nn.Parameter(
            self.in_proj_weight * self.in_proj_weight_mask
        )
        self.in_proj_bias = torch.nn.Parameter(
            self.in_proj_bias * self.in_proj_bias_mask
        )

        self.out_proj.prune(percentile)
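# Pruning sketch (illustrative, not part of the original source): ``prune``
# recomputes the masks so that roughly ``percentile`` percent of the
# smallest-magnitude in-projection weights become inactive, and the output
# projection is pruned the same way via ``MaskedDense.prune``.
#
#     layer = MaskedMultiHeadAttention(embed_dim=16, num_heads=4)
#     layer.prune(50)
#     active_frac = layer.in_proj_weight_mask.mean().item()   # approximately 0.5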