[Bart/Memory] Two separate, smaller decoder attention masks (#3371)
This commit is contained in:
@@ -36,8 +36,8 @@ if is_torch_available():
|
||||
from transformers.modeling_bart import (
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
shift_tokens_right,
|
||||
invert_mask,
|
||||
_prepare_bart_decoder_inputs,
|
||||
LARGE_NEGATIVE,
|
||||
)
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
|
||||
@@ -123,10 +123,9 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_advanced_inputs(self):
|
||||
def test_initialization_more(self):
|
||||
# (config, input_ids, token_type_ids, input_mask, *unused) = \
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, inputs_dict["input_ids"])
|
||||
model = BartModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -142,9 +141,17 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
_check_var(model.encoder.layers[0].fc1)
|
||||
_check_var(model.encoder.embed_positions)
|
||||
|
||||
def test_advanced_inputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict["input_ids"][:, -2:] = config.pad_token_id
|
||||
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
|
||||
config, inputs_dict["input_ids"]
|
||||
)
|
||||
model = BartModel(config).to(torch_device).eval()
|
||||
|
||||
decoder_features_with_created_mask = model(**inputs_dict)[0]
|
||||
decoder_features_with_passed_mask = model(
|
||||
decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
|
||||
decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
|
||||
useless_mask = torch.zeros_like(decoder_attn_mask)
|
||||
@@ -238,7 +245,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
|
||||
lm_model = BartForConditionalGeneration(config)
|
||||
lm_model.to(torch_device)
|
||||
loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids)
|
||||
loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels)
|
||||
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
@@ -336,41 +343,23 @@ class BartHeadTests(unittest.TestCase):
|
||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||
|
||||
def test_dummy_inputs(self):
|
||||
config, *_ = self._get_config_and_data(output_past=True)
|
||||
config, *_ = self._get_config_and_data()
|
||||
model = BartForConditionalGeneration(config).eval().to(torch_device)
|
||||
model(**model.dummy_inputs)
|
||||
|
||||
def test_prepare_bart_decoder_inputs(self):
|
||||
config, *_ = self._get_config_and_data(output_past=False)
|
||||
input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed
|
||||
input_ids = _long_tensor(([4, 4, 2]))
|
||||
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
|
||||
ignore = LARGE_NEGATIVE
|
||||
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[0, ignore, ignore],
|
||||
[0, 0, ignore],
|
||||
[ignore, ignore, ignore], # never attend to the final token, because its pad
|
||||
]
|
||||
).to(input_ids.device)
|
||||
self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
|
||||
self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())
|
||||
|
||||
# Test no causal mask
|
||||
config, *_ = self._get_config_and_data(output_past=True)
|
||||
expected_just_padding_mask = torch.tensor(
|
||||
[[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]] # never attend to the final token, because its pad
|
||||
).to(input_ids.device)
|
||||
_, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
|
||||
self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
|
||||
self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())
|
||||
|
||||
decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
|
||||
# Attend to everything if no pad tokens and no causal mask
|
||||
_, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
|
||||
ignore = float("-inf")
|
||||
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
|
||||
config, input_ids, decoder_input_ids
|
||||
)
|
||||
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
|
||||
expected_causal_mask = torch.tensor(
|
||||
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad
|
||||
).to(input_ids.device)
|
||||
self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
|
||||
self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
|
||||
|
||||
def test_resize_tokens_embeddings_more(self):
|
||||
config, input_ids, _ = self._get_config_and_data()
|
||||
|
||||
Reference in New Issue
Block a user