Add head_mask/decoder_head_mask for BART (#9569)
* Add head_mask/decoder_head_mask for BART
This branch implements head_mask and decoder_head_mask
for BART-based models. Full list below:
- BART
- MBart
- Blenderbot
- BlenderbotSmall
- Marian
- Pegasus
Everything is accompanied by updated tests.
* Fix test_headmasking for BART models
* Fix test_headmasking for BART-like models
which have only 2 layers in each module.
The condition
```
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
is, therefore, invalid for encoder-decoder models considering
the `head_mask`
```
head_mask = torch.ones(
self.model_tester.num_hidden_layers,
self.model_tester.num_attention_heads,
device=torch_device,
)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0
```
specified in the `test_headmasking` test/function.
* Adjust test_modeling_common.py to reflect T5 input args
* Update tests/test_modeling_common.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* make style
* make fix-copies
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -53,16 +53,24 @@ def prepare_bart_inputs_dict(
|
||||
decoder_input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -142,9 +150,10 @@ class BartModelTester:
|
||||
model = BartModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -393,7 +402,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -40,16 +40,24 @@ def prepare_blenderbot_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -129,9 +137,10 @@ class BlenderbotModelTester:
|
||||
model = BlenderbotModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -197,7 +206,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -48,16 +48,24 @@ def prepare_blenderbot_small_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -137,9 +145,10 @@ class BlenderbotSmallModelTester:
|
||||
model = BlenderbotSmallModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -205,7 +214,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
|
||||
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -204,9 +204,13 @@ class ModelTesterMixin:
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"encoder_outputs",
|
||||
]
|
||||
self.assertListEqual(arg_names[:5], expected_arg_names)
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" in arg_names
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
else:
|
||||
expected_arg_names = ["input_ids"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
@@ -395,7 +399,6 @@ class ModelTesterMixin:
|
||||
attention_mask = inputs["attention_mask"]
|
||||
decoder_input_ids = inputs["decoder_input_ids"]
|
||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||
|
||||
traced_model = torch.jit.trace(
|
||||
model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
)
|
||||
@@ -465,6 +468,11 @@ class ModelTesterMixin:
|
||||
head_mask.requires_grad_(requires_grad=True)
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
|
||||
inputs["head_mask"] = head_mask
|
||||
if model.config.is_encoder_decoder:
|
||||
signature = inspect.signature(model.forward)
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
if "decoder_head_mask" in arg_names: # necessary differentiation because of T5 model
|
||||
inputs["decoder_head_mask"] = head_mask
|
||||
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
@@ -474,24 +482,31 @@ class ModelTesterMixin:
|
||||
output.backward()
|
||||
multihead_outputs = head_mask.grad
|
||||
|
||||
attentions = outputs[-1]
|
||||
|
||||
# Remove Nan
|
||||
for t in attentions:
|
||||
self.assertLess(
|
||||
torch.sum(torch.isnan(t)), t.numel() / 4
|
||||
) # Check we don't have more than 25% nans (arbitrary)
|
||||
attentions = [
|
||||
t.masked_fill(torch.isnan(t), 0.0) for t in attentions
|
||||
] # remove them (the test is less complete)
|
||||
|
||||
self.assertIsNotNone(multihead_outputs)
|
||||
self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)
|
||||
self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
|
||||
self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
|
||||
self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
|
||||
self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||
|
||||
def check_attentions_validity(attentions):
|
||||
# Remove Nan
|
||||
for t in attentions:
|
||||
self.assertLess(
|
||||
torch.sum(torch.isnan(t)), t.numel() / 4
|
||||
) # Check we don't have more than 25% nans (arbitrary)
|
||||
attentions = [
|
||||
t.masked_fill(torch.isnan(t), 0.0) for t in attentions
|
||||
] # remove them (the test is less complete)
|
||||
|
||||
self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
|
||||
self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||
if len(attentions) > 2: # encoder-decoder models have only 2 layers in each module
|
||||
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
|
||||
self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
|
||||
self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
check_attentions_validity(outputs.encoder_attentions)
|
||||
check_attentions_validity(outputs.decoder_attentions)
|
||||
else:
|
||||
check_attentions_validity(outputs.attentions)
|
||||
|
||||
def test_head_pruning(self):
|
||||
if not self.test_pruning:
|
||||
|
||||
@@ -54,16 +54,24 @@ def prepare_marian_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -146,9 +154,10 @@ class MarianModelTester:
|
||||
model = MarianModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -214,7 +223,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -49,16 +49,24 @@ def prepare_mbart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -138,9 +146,10 @@ class MBartModelTester:
|
||||
model = MBartModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -210,7 +219,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -41,16 +41,24 @@ def prepare_pegasus_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -130,9 +138,10 @@ class PegasusModelTester:
|
||||
model = PegasusModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -198,7 +207,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user