Add head_mask and decoder_head_mask to PyTorch LED (#9856)
* Add {decoder_,}head_mask to LED
* Fix create_custom_forward signatue in encoder
* Add head_mask to longformer
* Add head_mask to longformer to fix dependencies
of LED on Longformer.
* Not working yet
* Add mising one input in longofrmer_modeling.py
* make fix-copies
This commit is contained in:
@@ -473,7 +473,6 @@ class ModelTesterMixin:
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
||||
inputs["decoder_head_mask"] = head_mask
|
||||
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
# Test that we can get a gradient back for importance score computation
|
||||
|
||||
@@ -49,16 +49,24 @@ def prepare_led_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -160,9 +168,10 @@ class LEDModelTester:
|
||||
model = LEDModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -258,7 +267,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -273,7 +273,6 @@ class LongformerModelTester:
|
||||
@require_torch
|
||||
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False # pruning is not supported
|
||||
test_headmasking = False # head masking is not supported
|
||||
test_torchscript = False
|
||||
|
||||
all_model_classes = (
|
||||
|
||||
Reference in New Issue
Block a user