Add head_mask and decoder_head_mask to PyTorch LED (#9856)

* Add {decoder_,}head_mask to LED

* Fix create_custom_forward signatue in encoder

* Add head_mask to longformer

* Add head_mask to longformer to fix dependencies
of LED on Longformer.

* Not working yet

* Add mising one input in longofrmer_modeling.py

* make fix-copies
This commit is contained in:
Daniel Stancl
2021-02-02 20:06:52 +01:00
committed by GitHub
parent d6217fb30c
commit 71bdc076dd
5 changed files with 238 additions and 7 deletions

View File

@@ -473,7 +473,6 @@ class ModelTesterMixin:
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
inputs["decoder_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True)
# Test that we can get a gradient back for importance score computation

View File

@@ -49,16 +49,24 @@ def prepare_led_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -160,9 +168,10 @@ class LEDModelTester:
model = LEDModel(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
@@ -258,7 +267,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False
def setUp(self):

View File

@@ -273,7 +273,6 @@ class LongformerModelTester:
@require_torch
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False # pruning is not supported
test_headmasking = False # head masking is not supported
test_torchscript = False
all_model_classes = (