TF: use the correct config with (...)EncoderDecoder models (#18097)
This commit is contained in:
@@ -351,32 +351,9 @@ class EncoderDecoderMixin:
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
labels,
|
||||
**kwargs
|
||||
def _check_output_with_attentions(
|
||||
self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
@@ -408,6 +385,85 @@ class EncoderDecoderMixin:
|
||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
labels,
|
||||
**kwargs
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
)
|
||||
self._check_output_with_attentions(
|
||||
outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions_from_config(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
labels,
|
||||
**kwargs
|
||||
):
|
||||
# Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
|
||||
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
|
||||
# from the inner models' configurations.
|
||||
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.config.output_attentions = True # model config -> won't work
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
self.assertTrue(
|
||||
all(
|
||||
key not in outputs_encoder_decoder
|
||||
for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
)
|
||||
)
|
||||
|
||||
config.output_attentions = True # inner model config -> will work
|
||||
decoder_config.output_attentions = True
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
self._check_output_with_attentions(
|
||||
outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -543,6 +599,10 @@ class EncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_output_attentions_from_config(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_generate(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
@@ -255,31 +255,9 @@ class TFEncoderDecoderMixin:
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
def _check_output_with_attentions(
|
||||
self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
@@ -311,6 +289,83 @@ class TFEncoderDecoderMixin:
|
||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
self._check_output_with_attentions(
|
||||
outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions_from_config(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
):
|
||||
# Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
|
||||
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
|
||||
# from the inner models' configurations.
|
||||
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.config.output_attentions = True # model config -> won't work
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
self.assertTrue(
|
||||
all(
|
||||
key not in outputs_encoder_decoder
|
||||
for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
)
|
||||
)
|
||||
|
||||
config.output_attentions = True # inner model config -> will work
|
||||
decoder_config.output_attentions = True
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
self._check_output_with_attentions(
|
||||
outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -570,6 +625,10 @@ class TFEncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_output_attentions_from_config(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_generate(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
Reference in New Issue
Block a user