TF: properly handle kwargs in encoder_decoder architectures (#16465)

* properly handle kwargs in encoder_decoder architectures

* make fixup
This commit is contained in:
Joao Gante
2022-03-29 18:17:47 +01:00
committed by GitHub
parent 0540d1b6c0
commit 7a9ef8181c
4 changed files with 20 additions and 8 deletions

View File

@@ -91,6 +91,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
@@ -122,6 +123,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@@ -137,6 +139,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
@@ -167,6 +170,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
kwargs=kwargs,
)
self.assertEqual(
@@ -195,6 +199,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
@@ -208,6 +213,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
@@ -235,6 +241,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
kwargs=kwargs,
)
# Make sure `loss` exist
@@ -269,6 +276,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
kwargs=kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]