From 58918c76f495aef6bc6bd0c9bc315c1e5c238b74 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Jun 2020 11:35:42 -0400 Subject: [PATCH] [bart] add config.extra_pos_embeddings to facilitate reuse (#5190) --- src/transformers/configuration_bart.py | 4 ++++ src/transformers/modeling_bart.py | 31 +++++++++++++------------- src/transformers/modeling_roberta.py | 15 ++++++++++++- src/transformers/modeling_utils.py | 14 ------------ tests/test_modeling_roberta.py | 3 +-- 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 5f13f8ccb5..be87589923 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -41,6 +41,7 @@ class BartConfig(PretrainedConfig): def __init__( self, activation_dropout=0.0, + extra_pos_embeddings=2, activation_function="gelu", vocab_size=50265, d_model=1024, @@ -118,6 +119,9 @@ class BartConfig(PretrainedConfig): # Classifier stuff self.classif_dropout = classifier_dropout + # pos embedding offset + self.extra_pos_embeddings = self.pad_token_id + 1 + @property def num_attention_heads(self) -> int: return self.encoder_attention_heads diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 9e1812be8e..42ea80ce11 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss from .activations import ACT2FN from .configuration_bart import BartConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids +from .modeling_utils import PreTrainedModel logger = logging.getLogger(__name__) @@ -96,6 +96,7 @@ BART_INPUTS_DOCSTRING = r""" def invert_mask(attention_mask): + """Turns 1->0, 0->1, False->True, True-> False""" assert attention_mask.dim() == 2 return attention_mask.eq(0) @@ -261,7 +262,7 @@ 
class BartEncoder(nn.Module): ) else: self.embed_positions = LearnedPositionalEmbedding( - config.max_position_embeddings, embed_dim, self.padding_idx, + config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings, ) self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() @@ -435,7 +436,7 @@ class BartDecoder(nn.Module): ) else: self.embed_positions = LearnedPositionalEmbedding( - config.max_position_embeddings, config.d_model, self.padding_idx, + config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings, ) self.layers = nn.ModuleList( [DecoderLayer(config) for _ in range(config.decoder_layers)] @@ -745,23 +746,23 @@ class LearnedPositionalEmbedding(nn.Embedding): position ids are passed to the forward function. """ - def __init__( - self, num_embeddings: int, embedding_dim: int, padding_idx: int, - ): - # if padding_idx is specified then offset the embedding ids by - # this index and adjust num_embeddings appropriately + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = offset assert padding_idx is not None - num_embeddings += padding_idx + 1 # WHY? 
+ num_embeddings += offset super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) - def forward(self, input, use_cache=False): + def forward(self, input_ids, use_cache=False): """Input is expected to be of size [bsz x seqlen].""" - if use_cache: # the position is our current step in the decoded sequence - pos = int(self.padding_idx + input.size(1)) - positions = input.data.new(1, 1).fill_(pos) + bsz, seq_len = input_ids.shape[:2] + if use_cache: + positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing else: - positions = create_position_ids_from_input_ids(input, self.padding_idx) - return super().forward(positions) + # starts at 0, ends at seq_len-1 + positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device) + return super().forward(positions + self.offset) def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index be98fb9cdd..13452b46ae 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -26,7 +26,6 @@ from torch.nn import CrossEntropyLoss, MSELoss from .configuration_roberta import RobertaConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu -from .modeling_utils import create_position_ids_from_input_ids logger = logging.getLogger(__name__) @@ -733,3 +732,17 @@ class RobertaForQuestionAnswering(BertPreTrainedModel): outputs = (total_loss,) + outputs return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ Replace non-padding symbols with their position numbers. Position numbers begin at + padding_idx+1. Padding symbols are ignored. This is modified from fairseq's + `utils.make_positions`. 
+ + :param torch.Tensor x: + :return torch.Tensor: + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask + return incremental_indices.long() + padding_idx diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e167abd089..befb577317 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2090,20 +2090,6 @@ class SequenceSummary(nn.Module): return output -def create_position_ids_from_input_ids(input_ids, padding_idx): - """ Replace non-padding symbols with their position numbers. Position numbers begin at - padding_idx+1. Padding symbols are ignored. This is modified from fairseq's - `utils.make_positions`. - - :param torch.Tensor x: - :return torch.Tensor: - """ - # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. - mask = input_ids.ne(padding_idx).int() - incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask - return incremental_indices.long() + padding_idx - - def prune_linear_layer(layer, index, dim=0): """ Prune a linear layer (a model parameters) to keep only entries in index. Return the pruned layer as a new layer with requires_grad=True. 
diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index e096ad4f29..516670b517 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -34,9 +34,8 @@ if is_torch_available(): RobertaForSequenceClassification, RobertaForTokenClassification, ) - from transformers.modeling_roberta import RobertaEmbeddings + from transformers.modeling_roberta import RobertaEmbeddings, create_position_ids_from_input_ids from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST - from transformers.modeling_utils import create_position_ids_from_input_ids class RobertaModelTester: