From e4c07faf0ae125559c30998c3513dd85f012c88e Mon Sep 17 00:00:00 2001 From: Victor SANH Date: Wed, 27 May 2020 18:24:23 -0400 Subject: [PATCH] add sparsity modules --- .../movement-pruning/emmental/__init__.py | 11 + .../emmental/configuration_bert_masked.py | 72 ++ .../emmental/modeling_bert_masked.py | 1019 +++++++++++++++++ .../emmental/modules/__init__.py | 2 + .../emmental/modules/binarizer.py | 144 +++ .../emmental/modules/masked_nn.py | 107 ++ 6 files changed, 1355 insertions(+) create mode 100644 examples/movement-pruning/emmental/__init__.py create mode 100644 examples/movement-pruning/emmental/configuration_bert_masked.py create mode 100644 examples/movement-pruning/emmental/modeling_bert_masked.py create mode 100644 examples/movement-pruning/emmental/modules/__init__.py create mode 100644 examples/movement-pruning/emmental/modules/binarizer.py create mode 100644 examples/movement-pruning/emmental/modules/masked_nn.py diff --git a/examples/movement-pruning/emmental/__init__.py b/examples/movement-pruning/emmental/__init__.py new file mode 100644 index 0000000000..ee0f1a1334 --- /dev/null +++ b/examples/movement-pruning/emmental/__init__.py @@ -0,0 +1,11 @@ +from .modules import * + +from .configuration_bert_masked import MaskedBertConfig + +from .modeling_bert_masked import ( + MaskedBertModel, + MaskedBertForQuestionAnswering, + MaskedBertForSequenceClassification, + MaskedBertForTokenClassification, + MaskedBertForMultipleChoice, +) diff --git a/examples/movement-pruning/emmental/configuration_bert_masked.py b/examples/movement-pruning/emmental/configuration_bert_masked.py new file mode 100644 index 0000000000..bc17d4e05e --- /dev/null +++ b/examples/movement-pruning/emmental/configuration_bert_masked.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Masked BERT model configuration. It replicates the class `~transformers.BertConfig` +and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init` and `mask_scale`.""" + + +import logging + +from transformers.configuration_utils import PretrainedConfig +from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP + +logger = logging.getLogger(__name__) + + +class MaskedBertConfig(PretrainedConfig): + """ + A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration. + """ + + pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "masked_bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + pruning_method="topK", + mask_init="constant", + mask_scale=0.0, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.pruning_method = pruning_method + self.mask_init = mask_init + self.mask_scale = mask_scale diff --git a/examples/movement-pruning/emmental/modeling_bert_masked.py b/examples/movement-pruning/emmental/modeling_bert_masked.py new file mode 100644 index 0000000000..4915265290 --- /dev/null +++ b/examples/movement-pruning/emmental/modeling_bert_masked.py @@ -0,0 +1,1019 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Masked Version of BERT. It replaces the `torch.nn.Linear` layers with +:class:`~emmental.MaskedLinear` and add an additional parameters in the forward pass to +compute the adaptive mask. +Built on top of `transformers.modeling_bert`""" + + +import logging +import math + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable +from transformers.modeling_utils import PreTrainedModel, prune_linear_layer +from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP +from transformers.modeling_bert import load_tf_weights_in_bert, ACT2FN, BertLayerNorm + +from emmental import MaskedLinear +from emmental import MaskedBertConfig + +logger = logging.getLogger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = MaskedLinear( + config.hidden_size, + self.all_head_size, + pruning_method=config.pruning_method, + mask_init=config.mask_init, + mask_scale=config.mask_scale, + ) + self.key = MaskedLinear( + config.hidden_size, + self.all_head_size, + pruning_method=config.pruning_method, + mask_init=config.mask_init, + mask_scale=config.mask_scale, + ) + self.value = MaskedLinear( + config.hidden_size, + self.all_head_size, + pruning_method=config.pruning_method, + mask_init=config.mask_init, + mask_scale=config.mask_scale, + ) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + threshold=None, + ): + mixed_query_layer = self.query(hidden_states, threshold=threshold) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + if encoder_hidden_states is not None: + mixed_key_layer = self.key(encoder_hidden_states, threshold=threshold) + mixed_value_layer = self.value(encoder_hidden_states, threshold=threshold) + attention_mask = encoder_attention_mask + else: + mixed_key_layer = self.key(hidden_states, threshold=threshold) + mixed_value_layer = self.value(hidden_states, threshold=threshold) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = MaskedLinear( + config.hidden_size, + config.hidden_size, + pruning_method=config.pruning_method, + mask_init=config.mask_init, + mask_scale=config.mask_scale, + ) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor, threshold): + hidden_states = self.dense(hidden_states, threshold=threshold) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + threshold=None, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + threshold=threshold, + ) + attention_output = self.output(self_outputs[0], hidden_states, threshold=threshold) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = MaskedLinear( + config.hidden_size, + config.intermediate_size, + pruning_method=config.pruning_method, + mask_init=config.mask_init, + mask_scale=config.mask_scale, + ) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states, threshold): + hidden_states = self.dense(hidden_states, threshold=threshold) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = MaskedLinear( + config.intermediate_size, + config.hidden_size, + pruning_method=config.pruning_method, + mask_init=config.mask_init, + mask_scale=config.mask_scale, + ) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor, threshold): + hidden_states = self.dense(hidden_states, threshold=threshold) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = BertAttention(config) + self.is_decoder = config.is_decoder + if self.is_decoder: + self.crossattention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + threshold=None, + ): + self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask, threshold=threshold) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + intermediate_output = self.intermediate(attention_output, threshold=threshold) + layer_output = self.output(intermediate_output, attention_output, threshold=threshold) + outputs = (layer_output,) + outputs + return outputs + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + threshold=None, + ): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + threshold=threshold, + ) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MaskedBertPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = MaskedBertConfig + pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +MASKED_BERT_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~emmental.MaskedBertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +MASKED_BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.BertTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.encode_plus` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. +""" + + +@add_start_docstrings( + "The bare Masked Bert Model transformer outputting raw hidden-states without any specific head on top.", + MASKED_BERT_START_DOCSTRING, +) +class MaskedBertModel(MaskedBertPreTrainedModel): + """ + The `MaskedBertModel` class replicates the :class:`~transformers.BertModel` class + and adds specific inputs to compute the adaptive mask on the fly. + Note that we freeze the embeddings modules from their pre-trained values. + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.embeddings.requires_grad_(requires_grad=False) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_callable(MASKED_BERT_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + threshold=None, + ): + r""" + threshold (:obj:`float`): + Threshold value (see :class:`~emmental.MaskedLinear`). + + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~emmental.MaskedBertConfig`) and inputs: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during pre-training. + + This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + causal_mask = causal_mask.to( + attention_mask.dtype + ) # causal and attention masks must have same type with pytorch version < 1.3 + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + elif encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format( + encoder_hidden_shape, encoder_attention_mask.shape + ) + ) + + encoder_extended_attention_mask = encoder_extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = ( + head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + threshold=threshold, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[ + 1: + ] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + +@add_start_docstrings( + """Masked Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. """, + MASKED_BERT_START_DOCSTRING, +) +class MaskedBertForSequenceClassification(MaskedBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = MaskedBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(MASKED_BERT_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + threshold=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + threshold (:obj:`float`): + Threshold value (see :class:`~emmental.MaskedLinear`). + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~emmental.MaskedBertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + threshold=threshold, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), logits, (hidden_states), (attentions) + + +@add_start_docstrings( + """Masked Bert Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, + MASKED_BERT_START_DOCSTRING, +) +class MaskedBertForMultipleChoice(MaskedBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = MaskedBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.init_weights() + + @add_start_docstrings_to_callable(MASKED_BERT_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + threshold=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the multiple choice classification loss. + Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above) + threshold (:obj:`float`): + Threshold value (see :class:`~emmental.MaskedLinear`). + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~emmental.MaskedBertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): + `num_choices` is the second dimension of the input tensors. (see `input_ids` above). + + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + """ + num_choices = input_ids.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + threshold=threshold, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + outputs = (loss,) + outputs + + return outputs # (loss), reshaped_logits, (hidden_states), (attentions) + + +@add_start_docstrings( + """Masked Bert Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, + MASKED_BERT_START_DOCSTRING, +) +class MaskedBertForTokenClassification(MaskedBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = MaskedBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(MASKED_BERT_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + threshold=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the token classification loss. + Indices should be in ``[0, ..., config.num_labels - 1]``. + threshold (:obj:`float`): + Threshold value (see :class:`~emmental.MaskedLinear`). + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~emmental.MaskedBertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) : + Classification loss. + scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`) + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + threshold=threshold, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """Masked Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, + MASKED_BERT_START_DOCSTRING, +) +class MaskedBertForQuestionAnswering(MaskedBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = MaskedBertModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(MASKED_BERT_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + threshold=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + threshold (:obj:`float`): + Threshold value (see :class:`~emmental.MaskedLinear`). + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~emmental.MaskedBertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-start scores (before SoftMax). + end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-end scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + threshold=threshold, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + outputs = (start_logits, end_logits,) + outputs[2:] + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + outputs = (total_loss,) + outputs + + return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) diff --git a/examples/movement-pruning/emmental/modules/__init__.py b/examples/movement-pruning/emmental/modules/__init__.py new file mode 100644 index 0000000000..0eb6d7bbdf --- /dev/null +++ b/examples/movement-pruning/emmental/modules/__init__.py @@ -0,0 +1,2 @@ +from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer +from .masked_nn import MaskedLinear diff --git a/examples/movement-pruning/emmental/modules/binarizer.py b/examples/movement-pruning/emmental/modules/binarizer.py new file mode 100644 index 0000000000..f6c6a732c4 --- /dev/null +++ b/examples/movement-pruning/emmental/modules/binarizer.py @@ -0,0 +1,144 @@ +# coding=utf-8 +# Copyright 2020-present, AllenAI Authors, University of Illinois Urbana-Champaign, +# Intel Nervana Systems and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Binarizers take a (real value) matrice as input and produce a binary (values in {0,1}) mask of the same shape. +""" + +import torch +from torch import autograd + + +class ThresholdBinarizer(autograd.Function): + """ + Thresholdd binarizer. + Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j} > \tau` + where `\tau` is a real value threshold. + + Implementation is inspired from: + https://github.com/arunmallya/piggyback + Piggyback: Adapting a Single Network to Multiple Tasks by Learning to Mask Weights + Arun Mallya, Dillon Davis, Svetlana Lazebnik + """ + + @staticmethod + def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool): + """ + Args: + inputs (`torch.FloatTensor`) + The input matrix from which the binarizer computes the binary mask. + threshold (`float`) + The threshold value (in R). + sigmoid (`bool`) + If set to ``True``, we apply the sigmoid function to the `inputs` matrix before comparing to `threshold`. + In this case, `threshold` should be a value between 0 and 1. + Returns: + mask (`torch.FloatTensor`) + Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is + retained, 0 - the associated weight is pruned). + """ + nb_elems = inputs.numel() + nb_min = int(0.005 * nb_elems) + 1 + if sigmoid: + mask = (torch.sigmoid(inputs) > threshold).type(inputs.type()) + else: + mask = (inputs > threshold).type(inputs.type()) + if mask.sum() < nb_min: + # We limit the pruning so that at least 0.5% (half a percent) of the weights are remaining + k_threshold = inputs.flatten().kthvalue(max(nb_elems - nb_min, 1)).values + mask = (inputs > k_threshold).type(inputs.type()) + return mask + + @staticmethod + def backward(ctx, gradOutput): + return gradOutput, None, None + + +class TopKBinarizer(autograd.Function): + """ + Top-k Binarizer. + Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}` + is among the k% highest values of S. + + Implementation is inspired from: + https://github.com/allenai/hidden-networks + What's hidden in a randomly weighted neural network? + Vivek Ramanujan*, Mitchell Wortsman*, Aniruddha Kembhavi, Ali Farhadi, Mohammad Rastegari + """ + + @staticmethod + def forward(ctx, inputs: torch.tensor, threshold: float): + """ + Args: + inputs (`torch.FloatTensor`) + The input matrix from which the binarizer computes the binary mask. + threshold (`float`) + The percentage of weights to keep (the rest is pruned). + `threshold` is a float between 0 and 1. + Returns: + mask (`torch.FloatTensor`) + Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is + retained, 0 - the associated weight is pruned). + """ + # Get the subnetwork by sorting the inputs and using the top threshold % + mask = inputs.clone() + _, idx = inputs.flatten().sort(descending=True) + j = int(threshold * inputs.numel()) + + # flat_out and mask access the same memory. + flat_out = mask.flatten() + flat_out[idx[j:]] = 0 + flat_out[idx[:j]] = 1 + return mask + + @staticmethod + def backward(ctx, gradOutput): + return gradOutput, None + + +class MagnitudeBinarizer(object): + """ + Magnitude Binarizer. + Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}` + is among the k% highest values of |S| (absolute value). + + Implementation is inspired from https://github.com/NervanaSystems/distiller/blob/2291fdcc2ea642a98d4e20629acb5a9e2e04b4e6/distiller/pruning/automated_gradual_pruner.py#L24 + """ + + @staticmethod + def apply(inputs: torch.tensor, threshold: float): + """ + Args: + inputs (`torch.FloatTensor`) + The input matrix from which the binarizer computes the binary mask. + This input marix is typically the weight matrix. + threshold (`float`) + The percentage of weights to keep (the rest is pruned). + `threshold` is a float between 0 and 1. + Returns: + mask (`torch.FloatTensor`) + Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is + retained, 0 - the associated weight is pruned). + """ + # Get the subnetwork by sorting the inputs and using the top threshold % + mask = inputs.clone() + _, idx = inputs.abs().flatten().sort(descending=True) + j = int(threshold * inputs.numel()) + + # flat_out and mask access the same memory. + flat_out = mask.flatten() + flat_out[idx[j:]] = 0 + flat_out[idx[:j]] = 1 + return mask diff --git a/examples/movement-pruning/emmental/modules/masked_nn.py b/examples/movement-pruning/emmental/modules/masked_nn.py new file mode 100644 index 0000000000..2caf76b389 --- /dev/null +++ b/examples/movement-pruning/emmental/modules/masked_nn.py @@ -0,0 +1,107 @@ +# coding=utf-8 +# Copyright 2020-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Masked Linear module: A fully connected layer that computes an adaptive binary mask on the fly. +The mask (binary or not) is computed at each forward pass and multiplied against +the weight matrix to prune a portion of the weights. +The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added). +""" + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import init + +import math + +from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer + + +class MaskedLinear(nn.Linear): + """ + Fully Connected layer with on the fly adaptive mask. + If needed, a score matrix is created to store the importance of each associated weight. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + mask_init: str = "constant", + mask_scale: float = 0.0, + pruning_method: str = "topK", + ): + """ + Args: + in_features (`int`) + Size of each input sample + out_features (`int`) + Size of each output sample + bias (`bool`) + If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + mask_init (`str`) + The initialization method for the score matrix if a score matrix is needed. + Choices: ["constant", "uniform", "kaiming"] + Default: ``constant`` + mask_scale (`float`) + The initialization parameter for the chosen initialization method `mask_init`. + Default: ``0.`` + pruning_method (`str`) + Method to compute the mask. + Choices: ["topK", "threshold", "sigmoied_threshold", "magnitude", "l0"] + Default: ``topK`` + """ + super(MaskedLinear, self).__init__(in_features=in_features, out_features=out_features, bias=bias) + assert pruning_method in ["topK", "threshold", "sigmoied_threshold", "magnitude", "l0"] + self.pruning_method = pruning_method + + if self.pruning_method in ["topK", "threshold", "sigmoied_threshold", "l0"]: + self.mask_scale = mask_scale + self.mask_init = mask_init + self.mask_scores = nn.Parameter(torch.Tensor(self.weight.size())) + self.init_mask() + + def init_mask(self): + if self.mask_init == "constant": + init.constant_(self.mask_scores, val=self.mask_scale) + elif self.mask_init == "uniform": + init.uniform_(self.mask_scores, a=-self.mask_scale, b=self.mask_scale) + elif self.mask_init == "kaiming": + init.kaiming_uniform_(self.mask_scores, a=math.sqrt(5)) + + def forward(self, input: torch.tensor, threshold: float): + # Get the mask + if self.pruning_method == "topK": + mask = TopKBinarizer.apply(self.mask_scores, threshold) + elif self.pruning_method in ["threshold", "sigmoied_threshold"]: + sig = "sigmoied" in self.pruning_method + mask = ThresholdBinarizer.apply(self.mask_scores, threshold, sig) + elif self.pruning_method == "magnitude": + mask = MagnitudeBinarizer.apply(self.weight, threshold) + elif self.pruning_method == "l0": + l, r, b = -0.1, 1.1, 2 / 3 + if self.training: + u = torch.zeros_like(self.mask_scores).uniform_().clamp(0.0001, 0.9999) + s = torch.sigmoid((u.log() - (1 - u).log() + self.mask_scores) / b) + else: + s = torch.sigmoid(self.mask_scores) + s_bar = s * (r - l) + l + mask = s_bar.clamp(min=0.0, max=1.0) + # Mask weights with computed mask + weight_thresholded = mask * self.weight + # Compute output (linear layer) with masked weights + return F.linear(input, weight_thresholded, self.bias)