TF: properly handle kwargs in encoder_decoder architectures (#16465)
* properly handle kwargs in encoder_decoder architectures
* make fixup
This commit is contained in:
@@ -96,6 +96,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
pixel_values=pixel_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
@@ -124,6 +125,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
pixel_values=pixel_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
@@ -137,6 +139,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
@@ -164,6 +167,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
return_dict=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
@@ -189,6 +193,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
pixel_values=pixel_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
out_2 = np.array(outputs[0])
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
@@ -201,6 +206,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
pixel_values=pixel_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
out_1 = np.array(after_outputs[0])
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
@@ -226,6 +232,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
labels=labels,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Make sure `loss` exist
|
||||
@@ -257,6 +264,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
|
||||
Reference in New Issue
Block a user