TF: use the correct config with (...)EncoderDecoder models (#18097)
This commit is contained in:
@@ -403,8 +403,13 @@ def unpack_inputs(func):
|
|||||||
# move any arg into kwargs, if they exist
|
# move any arg into kwargs, if they exist
|
||||||
fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
|
fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
|
||||||
|
|
||||||
# process the inputs and call the wrapped function
|
# Encoder Decoder models delegate the application of the configuration options to their inner models.
|
||||||
unpacked_inputs = input_processing(func, self.config, **fn_args_and_kwargs)
|
if "encoder_decoder" in str(self).lower():
|
||||||
|
config = None
|
||||||
|
else:
|
||||||
|
config = self.config
|
||||||
|
|
||||||
|
unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
|
||||||
return func(self, **unpacked_inputs)
|
return func(self, **unpacked_inputs)
|
||||||
|
|
||||||
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
|
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
|
||||||
@@ -559,18 +564,19 @@ def input_processing(func, config, **kwargs):
|
|||||||
if "kwargs" in output:
|
if "kwargs" in output:
|
||||||
del output["kwargs"]
|
del output["kwargs"]
|
||||||
|
|
||||||
boolean_dict = {
|
if config is not None:
|
||||||
k: v
|
boolean_dict = {
|
||||||
for k, v in output.items()
|
k: v
|
||||||
if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
|
for k, v in output.items()
|
||||||
}
|
if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
|
||||||
|
}
|
||||||
|
|
||||||
output.update(
|
output.update(
|
||||||
booleans_processing(
|
booleans_processing(
|
||||||
config=config,
|
config=config,
|
||||||
**boolean_dict,
|
**boolean_dict,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@@ -630,13 +630,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
||||||
loss = self.hf_compute_loss(labels, logits)
|
loss = self.hf_compute_loss(labels, logits)
|
||||||
|
|
||||||
past_key_values = None
|
if not return_dict:
|
||||||
if decoder_inputs["use_cache"]:
|
past_key_values = None
|
||||||
past_key_values = decoder_outputs[1]
|
if use_cache:
|
||||||
# The starting index of the remaining elements in `decoder_outputs`
|
past_key_values = decoder_outputs[1]
|
||||||
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
# The starting index of the remaining elements in `decoder_outputs`
|
||||||
|
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
||||||
|
|
||||||
if not decoder_inputs["return_dict"]:
|
|
||||||
if not isinstance(encoder_outputs, tuple):
|
if not isinstance(encoder_outputs, tuple):
|
||||||
encoder_outputs = encoder_outputs.to_tuple()
|
encoder_outputs = encoder_outputs.to_tuple()
|
||||||
output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
|
output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
|
||||||
@@ -646,7 +646,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
return TFSeq2SeqLMOutput(
|
return TFSeq2SeqLMOutput(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=decoder_outputs.logits,
|
logits=decoder_outputs.logits,
|
||||||
past_key_values=past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
cross_attentions=decoder_outputs.cross_attentions,
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
|
|||||||
@@ -663,13 +663,13 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
||||||
loss = self.hf_compute_loss(labels, logits)
|
loss = self.hf_compute_loss(labels, logits)
|
||||||
|
|
||||||
past_key_values = None
|
if not return_dict:
|
||||||
if decoder_inputs["use_cache"]:
|
past_key_values = None
|
||||||
past_key_values = decoder_outputs[1]
|
if use_cache:
|
||||||
# The starting index of the remaining elements in `decoder_outputs`
|
past_key_values = decoder_outputs[1]
|
||||||
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
# The starting index of the remaining elements in `decoder_outputs`
|
||||||
|
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
||||||
|
|
||||||
if not decoder_inputs["return_dict"]:
|
|
||||||
if not isinstance(encoder_outputs, tuple):
|
if not isinstance(encoder_outputs, tuple):
|
||||||
encoder_outputs = encoder_outputs.to_tuple()
|
encoder_outputs = encoder_outputs.to_tuple()
|
||||||
output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
|
output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
|
||||||
@@ -679,7 +679,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
return TFSeq2SeqLMOutput(
|
return TFSeq2SeqLMOutput(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=decoder_outputs.logits,
|
logits=decoder_outputs.logits,
|
||||||
past_key_values=past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
cross_attentions=decoder_outputs.cross_attentions,
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
|
|||||||
@@ -351,32 +351,9 @@ class EncoderDecoderMixin:
|
|||||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_encoder_decoder_model_output_attentions(
|
def _check_output_with_attentions(
|
||||||
self,
|
self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||||
config,
|
|
||||||
input_ids,
|
|
||||||
attention_mask,
|
|
||||||
encoder_hidden_states,
|
|
||||||
decoder_config,
|
|
||||||
decoder_input_ids,
|
|
||||||
decoder_attention_mask,
|
|
||||||
labels,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
|
||||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
|
||||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
|
||||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
|
||||||
enc_dec_model.to(torch_device)
|
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
|
||||||
output_attentions=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||||
|
|
||||||
@@ -408,6 +385,85 @@ class EncoderDecoderMixin:
|
|||||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
|
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_output_attentions(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
labels,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||||
|
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||||
|
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
enc_dec_model.to(torch_device)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
output_attentions=True,
|
||||||
|
)
|
||||||
|
self._check_output_with_attentions(
|
||||||
|
outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_output_attentions_from_config(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
labels,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
|
||||||
|
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
|
||||||
|
# from the inner models' configurations.
|
||||||
|
|
||||||
|
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||||
|
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
enc_dec_model.config.output_attentions = True # model config -> won't work
|
||||||
|
enc_dec_model.to(torch_device)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
all(
|
||||||
|
key not in outputs_encoder_decoder
|
||||||
|
for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
config.output_attentions = True # inner model config -> will work
|
||||||
|
decoder_config.output_attentions = True
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
enc_dec_model.to(torch_device)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
self._check_output_with_attentions(
|
||||||
|
outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||||
|
)
|
||||||
|
|
||||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
@@ -543,6 +599,10 @@ class EncoderDecoderMixin:
|
|||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_output_attentions_from_config(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
|
||||||
|
|
||||||
def test_encoder_decoder_model_generate(self):
|
def test_encoder_decoder_model_generate(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||||
|
|||||||
@@ -255,31 +255,9 @@ class TFEncoderDecoderMixin:
|
|||||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_encoder_decoder_model_output_attentions(
|
def _check_output_with_attentions(
|
||||||
self,
|
self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||||
config,
|
|
||||||
input_ids,
|
|
||||||
attention_mask,
|
|
||||||
encoder_hidden_states,
|
|
||||||
decoder_config,
|
|
||||||
decoder_input_ids,
|
|
||||||
decoder_attention_mask,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
|
||||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
|
||||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
|
||||||
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
|
||||||
output_attentions=True,
|
|
||||||
kwargs=kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||||
|
|
||||||
@@ -311,6 +289,83 @@ class TFEncoderDecoderMixin:
|
|||||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
|
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_output_attentions(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||||
|
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||||
|
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
output_attentions=True,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
self._check_output_with_attentions(
|
||||||
|
outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_output_attentions_from_config(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
|
||||||
|
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
|
||||||
|
# from the inner models' configurations.
|
||||||
|
|
||||||
|
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||||
|
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
enc_dec_model.config.output_attentions = True # model config -> won't work
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
all(
|
||||||
|
key not in outputs_encoder_decoder
|
||||||
|
for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
config.output_attentions = True # inner model config -> will work
|
||||||
|
decoder_config.output_attentions = True
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
self._check_output_with_attentions(
|
||||||
|
outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
|
||||||
|
)
|
||||||
|
|
||||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
@@ -570,6 +625,10 @@ class TFEncoderDecoderMixin:
|
|||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_output_attentions_from_config(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
|
||||||
|
|
||||||
def test_encoder_decoder_model_generate(self):
|
def test_encoder_decoder_model_generate(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user