from .MaskedDense import MaskedDense
import numpy as np
import torch
class MaskedMultiHeadAttention(torch.nn.Module):
    """
    Masked Multi-Headed Attention Layer
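    Examples
    --------
    A minimal usage sketch (shapes are illustrative):

    >>> layer = MaskedMultiHeadAttention(embed_dim=8, num_heads=2)
    >>> x = torch.rand(5, 3, 8)  # (seq_len, batch, embed_dim)
    >>> out, attn = layer(x, x, x)
    >>> out.shape
    torch.Size([5, 3, 8])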
    """
    def __init__(
            self,
            embed_dim,
            num_heads,
            dropout=0,
            batch_first=False,
            device=None,
            dtype=None
    ):
        """
        Parameters
        ----------
        embed_dim : int
            The embedding dimension
        num_heads : int
            The number of attention heads
        dropout : float (default 0)
            The dropout rate to apply
        batch_first : bool (default False)
            Whether the batch dimension is first
        device : None or torch device (default None)
            The device on which to create the parameters and masks
        dtype : None or torch dtype (default None)
            The dtype of the parameters and masks
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        if self.head_dim * self.num_heads != embed_dim:
            raise ValueError('num_heads must evenly divide embed_dim')
        in_proj_weight = torch.empty(
            (3 * embed_dim, embed_dim), **factory_kwargs)
        in_proj_weight = torch.nn.init.xavier_uniform_(in_proj_weight)
        self.in_proj_weight = torch.nn.Parameter(in_proj_weight)
        # Masks are registered as buffers so they persist in state_dict but
        # are not trained
        self.register_buffer('in_proj_weight_mask', torch.ones_like(
            self.in_proj_weight, **factory_kwargs))
        self.in_proj_bias = torch.nn.Parameter(
            torch.zeros((3 * embed_dim), **factory_kwargs))
        self.register_buffer('in_proj_bias_mask', torch.ones_like(
            self.in_proj_bias, **factory_kwargs))
        # The output projection is a MaskedDense layer with its own masks;
        # expose its tensors under attention-style names
        self.out_proj = MaskedDense(
            embed_dim, embed_dim, **factory_kwargs)
        self.out_proj_weight = self.out_proj.w
        self.out_proj_weight_mask = self.out_proj.w_mask
        self.out_proj_bias = self.out_proj.b
        self.out_proj_bias_mask = self.out_proj.b_mask
    def forward(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            average_attn_weights=True
    ):
        """
        Call the layer on input data
        Parameters
        ----------
        query : torch Tensor
            Query tensor
        key : torch Tensor
            Key tensor
        value : torch Tensor
            Value tensor
        key_padding_mask : None or torch Tensor (default None)
            If specified, a mask indicating which elements in ``key`` to ignore
        need_weights : bool (default True)
            If True, return ``attn_output_weights`` in addition to ``attn_output``
        attn_mask : None or torch Tensor (default None)
            If specified, a 2D or 3D mask preventing attention
        average_attn_weights : bool (default True)
            If True, the returned ``attn_output_weights`` are averaged across heads
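        Returns
        -------
        attn_output : torch Tensor
            The attention-weighted output, in the same batch layout as ``query``
        attn_output_weights : torch Tensor or None
            The attention weights (averaged across heads when
            ``average_attn_weights`` is True), or ``None`` if ``need_weights`` is False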
        """
        is_batched = query.dim() == 3
        # The functional API expects (seq, batch, embed); transpose batch-first
        # inputs, transposing shared tensors only once
        if self.batch_first and is_batched:
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(
                    1, 0) for x in (query, key, value)]
        # Apply the pruning masks here so inactive weights contribute nothing
        # and receive zero gradient
        attn_output, attn_output_weights = torch.nn.functional.multi_head_attention_forward(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight * self.in_proj_weight_mask,
            in_proj_bias=self.in_proj_bias * self.in_proj_bias_mask,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=self.dropout,
            out_proj_weight=self.out_proj_weight * self.out_proj_weight_mask,
            out_proj_bias=self.out_proj_bias * self.out_proj_bias_mask,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights
        )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights 
    def prune(self, percentile):
        """
        Prune the layer by updating the layer's mask
        Parameters
        ----------
        percentile : int
            Integer between 0 and 99 giving the percentage of weights, by
            smallest absolute value, to be made inactive
        Notes
        -----
        Acts on the layer in place
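        Examples
        --------
        A minimal sketch; the smallest half of the weights (by magnitude) go inactive:

        >>> layer = MaskedMultiHeadAttention(embed_dim=8, num_heads=2)
        >>> layer.prune(50)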
        """
        # Rank weights by magnitude and zero mask entries below the cutoff;
        # weights pruned on earlier calls are already zero, so they stay pruned
        w_copy = np.abs(self.in_proj_weight.detach().cpu().numpy())
        b_copy = np.abs(self.in_proj_bias.detach().cpu().numpy())
        w_percentile = np.percentile(w_copy, percentile)
        b_percentile = np.percentile(b_copy, percentile)
        new_w_mask = torch.from_numpy(
            (w_copy >= w_percentile).astype(np.float32))
        new_b_mask = torch.from_numpy(
            (b_copy >= b_percentile).astype(np.float32))
        self.in_proj_weight_mask[:] = new_w_mask
        self.in_proj_bias_mask[:] = new_b_mask
        # Reassign the parameters so pruned weights are stored as zeros
        self.in_proj_weight = torch.nn.Parameter(
            self.in_proj_weight * self.in_proj_weight_mask
        )
        self.in_proj_bias = torch.nn.Parameter(
            self.in_proj_bias * self.in_proj_bias_mask
        )
        # The output projection is pruned by its own MaskedDense mask
        self.out_proj.prune(percentile)
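
# A minimal end-to-end sketch (values are illustrative; running this file
# directly would require package context for the relative import above):
# build the layer, run a batch-first forward pass, then prune 20% of the
# weights by magnitude.
if __name__ == '__main__':
    layer = MaskedMultiHeadAttention(embed_dim=16, num_heads=4, batch_first=True)
    x = torch.rand(2, 5, 16)  # (batch, seq_len, embed_dim)
    out, attn = layer(x, x, x)
    print(out.shape)   # torch.Size([2, 5, 16])
    print(attn.shape)  # torch.Size([2, 5, 5]) since average_attn_weights=True
    layer.prune(20)
    # About 80% of the in-projection weights remain active after one prune step
    print(float(layer.in_proj_weight_mask.mean()))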