TF: properly handle kwargs in encoder_decoder architectures (#16465)

* properly handle kwargs in encoder_decoder architectures

* make fixup
This commit is contained in:
Joao Gante
2022-03-29 18:17:47 +01:00
committed by GitHub
parent 0540d1b6c0
commit 7a9ef8181c
4 changed files with 20 additions and 8 deletions

View File

@@ -96,6 +96,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
@@ -124,6 +125,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@@ -137,6 +139,7 @@ class TFVisionEncoderDecoderMixin:
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
@@ -164,6 +167,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
kwargs=kwargs,
)
self.assertEqual(
@@ -189,6 +193,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
@@ -201,6 +206,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
@@ -226,6 +232,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
kwargs=kwargs,
)
# Make sure `loss` exist
@@ -257,6 +264,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
kwargs=kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]