Fix usage of head masks by PT encoder-decoder models' generate() function (#11621)
* Add missing head masking for generate() function * Add head_mask, decoder_head_mask and cross_attn_head_mask into prepare_inputs_for_generation for generate() function for multiple encoder-decoder models. * Add test_genereate_with_head_masking * [WIP] Update the new test and handle special cases * make style * Omit ProphetNet test so far * make fix-copies
This commit is contained in:
@@ -409,7 +409,9 @@ class GenerationMixin:
|
|||||||
# retrieve encoder hidden states
|
# retrieve encoder hidden states
|
||||||
encoder = self.get_encoder()
|
encoder = self.get_encoder()
|
||||||
encoder_kwargs = {
|
encoder_kwargs = {
|
||||||
argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
|
argument: value
|
||||||
|
for argument, value in model_kwargs.items()
|
||||||
|
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
|
||||||
}
|
}
|
||||||
model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
|
model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|||||||
@@ -1327,6 +1327,8 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -1342,6 +1344,8 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2530,6 +2530,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -2545,6 +2547,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1321,6 +1321,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -1336,6 +1338,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1296,6 +1296,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -1311,6 +1313,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1215,7 +1215,16 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||||
@@ -1223,6 +1232,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2356,6 +2356,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -2371,6 +2373,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1324,6 +1324,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -1339,6 +1341,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1309,6 +1309,8 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -1324,6 +1326,8 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1327,7 +1327,16 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past is not None:
|
||||||
@@ -1339,6 +1348,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1312,6 +1312,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -1327,6 +1329,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2020,6 +2020,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -2036,6 +2038,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1655,7 +1655,16 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
@@ -1667,6 +1676,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import inspect
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
@@ -1072,6 +1073,40 @@ class GenerationTesterMixin:
|
|||||||
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
|
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_generate_with_head_masking(self):
|
||||||
|
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||||
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
|
model = model_class(config)
|
||||||
|
# We want to test only encoder-decoder models
|
||||||
|
if not config.is_encoder_decoder:
|
||||||
|
continue
|
||||||
|
|
||||||
|
head_masking = {
|
||||||
|
"head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads),
|
||||||
|
"decoder_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
|
||||||
|
"cross_attn_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
|
||||||
|
}
|
||||||
|
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# We want to test only models where encoder/decoder head masking is implemented
|
||||||
|
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||||
|
out = model.generate(
|
||||||
|
input_ids,
|
||||||
|
num_beams=1,
|
||||||
|
max_length=max_length,
|
||||||
|
output_attentions=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
**{name: mask},
|
||||||
|
)
|
||||||
|
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||||
|
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||||
|
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||||
|
|
||||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
num_sequences_in_output = batch_size * num_return_sequences
|
num_sequences_in_output = batch_size * num_return_sequences
|
||||||
|
|||||||
@@ -1088,6 +1088,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||||||
self.assertIsNotNone(encoder_hidden_states.grad)
|
self.assertIsNotNone(encoder_hidden_states.grad)
|
||||||
self.assertIsNotNone(encoder_attentions.grad)
|
self.assertIsNotNone(encoder_attentions.grad)
|
||||||
|
|
||||||
|
def test_generate_with_head_masking(self):
|
||||||
|
"""Generating with head_masking has not been implemented for ProphetNet models yet."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
|||||||
@@ -600,6 +600,37 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
input_names=["input_ids", "decoder_input_ids"],
|
input_names=["input_ids", "decoder_input_ids"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_generate_with_head_masking(self):
|
||||||
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
config = config_and_inputs[0]
|
||||||
|
max_length = config_and_inputs[1].shape[-1] + 3
|
||||||
|
model = T5ForConditionalGeneration(config)
|
||||||
|
|
||||||
|
head_masking = {
|
||||||
|
"head_mask": torch.zeros(config.num_layers, config.num_heads),
|
||||||
|
"decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads),
|
||||||
|
"cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads),
|
||||||
|
}
|
||||||
|
|
||||||
|
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||||
|
head_masks = {name: mask}
|
||||||
|
# Explicitly pass decoder_head_mask as it is required from T5 model when head_mask specified
|
||||||
|
if name == "head_mask":
|
||||||
|
head_masks["decoder_head_mask"] = torch.ones(config.num_decoder_layers, config.num_heads)
|
||||||
|
|
||||||
|
out = model.generate(
|
||||||
|
config_and_inputs[1],
|
||||||
|
num_beams=1,
|
||||||
|
max_length=max_length,
|
||||||
|
output_attentions=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
**head_masks,
|
||||||
|
)
|
||||||
|
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||||
|
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||||
|
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||||
|
|
||||||
|
|
||||||
class T5EncoderOnlyModelTester:
|
class T5EncoderOnlyModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user