TF: properly handle kwargs in encoder_decoder architectures (#16465)
* properly handle kwargs in encoder_decoder architectures * make fixup
This commit is contained in:
@@ -91,6 +91,7 @@ class TFEncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
@@ -122,6 +123,7 @@ class TFEncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
@@ -137,6 +139,7 @@ class TFEncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
@@ -167,6 +170,7 @@ class TFEncoderDecoderMixin:
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
return_dict=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
@@ -195,6 +199,7 @@ class TFEncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
out_2 = np.array(outputs[0])
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
@@ -208,6 +213,7 @@ class TFEncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
out_1 = np.array(after_outputs[0])
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
@@ -235,6 +241,7 @@ class TFEncoderDecoderMixin:
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
labels=labels,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Make sure `loss` exist
|
||||
@@ -269,6 +276,7 @@ class TFEncoderDecoderMixin:
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
|
||||
Reference in New Issue
Block a user