From 3ee431dd4c720e67e35a449b453d3dc2b15ccfff Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 26 Mar 2020 21:34:15 -0400 Subject: [PATCH] [Bart/Memory] Two separate, smaller decoder attention masks (#3371) --- src/transformers/modeling_bart.py | 100 +++++++++++++++--------------- tests/test_modeling_bart.py | 53 +++++++--------- 2 files changed, 71 insertions(+), 82 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index e4885682be..fe374d79da 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -74,39 +74,37 @@ BART_INPUTS_DOCSTRING = r""" ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper. - decoder_attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, 1, tgt_seq_len, tgt_seq_len)`, `optional`, defaults to :obj:`None`): - Default behavior: generate a tensor that ignores pad tokens and future tokens, as in the paper. + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): + Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify. See diagram 1 in the paper for more info on the default strategy """ -LARGE_NEGATIVE = -1e8 + + +def invert_mask(attention_mask): + assert attention_mask.dim() == 2 + return attention_mask.eq(0) def _prepare_bart_decoder_inputs( - config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None, + config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32 ): - """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if + """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided. This mimics the default behavior in fairseq. To override it pass in masks. Note: this is not called during generation """ pad_token_id = config.pad_token_id - need_causal_mask = not config.output_past if decoder_input_ids is None: decoder_input_ids = shift_tokens_right(input_ids, pad_token_id) - bsz, tgt_len = decoder_input_ids.size()[:2] - if decoder_attn_mask is None: + bsz, tgt_len = decoder_input_ids.size() + if decoder_padding_mask is None: decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) - if need_causal_mask: - causal_lm_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1) - else: - causal_lm_mask = None - new_shape = (bsz, tgt_len, tgt_len) - # make it broadcastable so can just be added to the attention coefficients - decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device) - if mask_dtype is not None: - decoder_attn_mask = decoder_attn_mask.to(mask_dtype) - assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len) - return decoder_input_ids, decoder_attn_mask + else: + decoder_padding_mask = invert_mask(decoder_padding_mask) + causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to( + dtype=causal_mask_dtype, device=decoder_input_ids.device + ) + return decoder_input_ids, decoder_padding_mask, causal_mask class PretrainedBartModel(PreTrainedModel): @@ -130,12 +128,9 @@ class PretrainedBartModel(PreTrainedModel): def dummy_inputs(self): pad_token = self.config.pad_token_id input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) - decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids) dummy_inputs = { - "decoder_input_ids": decoder_input_ids, "attention_mask": input_ids.ne(pad_token), "input_ids": input_ids, - "decoder_attention_mask": decoder_attn_mask, } return dummy_inputs @@ -153,21 +148,6 @@ def _check_shapes(shape_1, shape2): raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2)) -def _combine_masks(key_padding_mask, causal_lm_mask, targ_size): - """Make one mask of shape (bsz, 1, tgt_len, src_len) """ - a = torch.zeros(targ_size) # targ_size is(bsz, tgt_len, src_len) - b = torch.zeros(targ_size) - if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size - _check_shapes(key_padding_mask.shape, targ_size[:2]) - reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size) - a[reshaped] = LARGE_NEGATIVE - - if causal_lm_mask is not None: # (tgt_len, src_len) -> targ_size - _check_shapes(causal_lm_mask.shape, targ_size[-2:]) - b = causal_lm_mask.unsqueeze(0).expand(*targ_size) - return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,) - - def shift_tokens_right(input_ids, pad_token_id): """Shift input ids one token to the right, and wrap the last non pad token (usually ).""" prev_output_tokens = input_ids.clone() @@ -281,8 +261,7 @@ class BartEncoder(nn.Module): """ # check attention mask and invert if attention_mask is not None: - assert attention_mask.dim() == 2 - attention_mask = attention_mask.eq(0) + attention_mask = invert_mask(attention_mask) inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input_ids) @@ -339,7 +318,13 @@ class DecoderLayer(nn.Module): self.final_layer_norm = LayerNorm(self.embed_dim) def forward( - self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None, + self, + x, + encoder_hidden_states, + encoder_attn_mask=None, + layer_state=None, + causal_mask=None, + decoder_padding_mask=None, ): residual = x @@ -347,7 +332,12 @@ class DecoderLayer(nn.Module): layer_state = {} # next line mutates layer state x, self_attn_weights = self.self_attn( - query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions + query=x, + key=x, + layer_state=layer_state, + key_padding_mask=decoder_padding_mask, + attn_mask=causal_mask, + need_weights=self.output_attentions, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -412,7 +402,8 @@ class BartDecoder(nn.Module): input_ids, encoder_hidden_states, encoder_padding_mask, - combined_mask, + decoder_padding_mask, + decoder_causal_mask, decoder_cached_states=None, generation_mode=False, **unused @@ -437,8 +428,7 @@ class BartDecoder(nn.Module): """ # check attention mask and invert if encoder_padding_mask is not None: - assert encoder_padding_mask.dim() == 2 - encoder_padding_mask = encoder_padding_mask.eq(0) + encoder_padding_mask = invert_mask(encoder_padding_mask) # embed positions positions = self.embed_positions(input_ids, generation_mode=generation_mode) @@ -458,7 +448,6 @@ class BartDecoder(nn.Module): all_hidden_states = () all_self_attns = () next_decoder_cache = [] - for i, decoder_layer in enumerate(self.layers): decoder_layer # type: DecoderLayer # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -468,7 +457,12 @@ class BartDecoder(nn.Module): layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None x, layer_self_attn, layer_past = decoder_layer( - x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask, + x, + encoder_hidden_states, + encoder_attn_mask=encoder_padding_mask, + decoder_padding_mask=decoder_padding_mask, + layer_state=layer_state, + causal_mask=decoder_causal_mask, ) if self.output_past: @@ -736,6 +730,8 @@ def _filter_out_falsey_values(tup) -> Tuple: # Public API +def _get_shape(t): + return getattr(t, "shape", None) @add_start_docstrings( @@ -769,13 +765,16 @@ class BartModel(PretrainedBartModel): # make masks if user doesn't supply if not generation_mode: - decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs( + decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs( self.config, input_ids, decoder_input_ids=decoder_input_ids, - decoder_attn_mask=decoder_attention_mask, - mask_dtype=self.shared.weight.dtype, + decoder_padding_mask=decoder_attention_mask, + causal_mask_dtype=self.shared.weight.dtype, ) + else: + decoder_padding_mask, causal_mask = None, None + assert decoder_input_ids is not None if encoder_outputs is None: encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) @@ -785,7 +784,8 @@ class BartModel(PretrainedBartModel): decoder_input_ids, encoder_outputs[0], attention_mask, - decoder_attention_mask, + decoder_padding_mask, + decoder_causal_mask=causal_mask, decoder_cached_states=decoder_cached_states, generation_mode=generation_mode, ) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 4a807286f3..c463e4df3b 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -36,8 +36,8 @@ if is_torch_available(): from transformers.modeling_bart import ( BART_PRETRAINED_MODEL_ARCHIVE_MAP, shift_tokens_right, + invert_mask, _prepare_bart_decoder_inputs, - LARGE_NEGATIVE, ) from transformers.tokenization_bart import BartTokenizer @@ -123,10 +123,9 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() - def test_advanced_inputs(self): + def test_initialization_more(self): # (config, input_ids, token_type_ids, input_mask, *unused) = \ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, inputs_dict["input_ids"]) model = BartModel(config) model.to(torch_device) model.eval() @@ -142,9 +141,17 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): _check_var(model.encoder.layers[0].fc1) _check_var(model.encoder.embed_positions) + def test_advanced_inputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict["input_ids"][:, -2:] = config.pad_token_id + decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs( + config, inputs_dict["input_ids"] + ) + model = BartModel(config).to(torch_device).eval() + decoder_features_with_created_mask = model(**inputs_dict)[0] decoder_features_with_passed_mask = model( - decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict + decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict )[0] _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask) useless_mask = torch.zeros_like(decoder_attn_mask) @@ -238,7 +245,7 @@ class BartHeadTests(unittest.TestCase): lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device) lm_model = BartForConditionalGeneration(config) lm_model.to(torch_device) - loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids) + loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels) expected_shape = (batch_size, input_ids.shape[1], config.vocab_size) self.assertEqual(logits.shape, expected_shape) self.assertIsInstance(loss.item(), float) @@ -336,41 +343,23 @@ class BartHeadTests(unittest.TestCase): model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) def test_dummy_inputs(self): - config, *_ = self._get_config_and_data(output_past=True) + config, *_ = self._get_config_and_data() model = BartForConditionalGeneration(config).eval().to(torch_device) model(**model.dummy_inputs) def test_prepare_bart_decoder_inputs(self): config, *_ = self._get_config_and_data(output_past=False) - input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed + input_ids = _long_tensor(([4, 4, 2])) decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]]) - ignore = LARGE_NEGATIVE - decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids) - expected_mask = torch.tensor( - [ - [0, ignore, ignore], - [0, 0, ignore], - [ignore, ignore, ignore], # never attend to the final token, because its pad - ] - ).to(input_ids.device) - self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3)) - self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all()) - - # Test no causal mask - config, *_ = self._get_config_and_data(output_past=True) - expected_just_padding_mask = torch.tensor( - [[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]] # never attend to the final token, because its pad - ).to(input_ids.device) - _, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids) - self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3)) - self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all()) - - decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]]) - # Attend to everything if no pad tokens and no causal mask - _, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs( + ignore = float("-inf") + decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs( config, input_ids, decoder_input_ids ) - self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all()) + expected_causal_mask = torch.tensor( + [[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad + ).to(input_ids.device) + self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size()) + self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all()) def test_resize_tokens_embeddings_more(self): config, input_ids, _ = self._get_config_and_data()