TF: properly handle kwargs in encoder_decoder architectures (#16465)
* properly handle kwargs in encoder_decoder architectures
* make fixup
This commit is contained in:
@@ -569,13 +569,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
"output_hidden_states": output_hidden_states,
|
"output_hidden_states": output_hidden_states,
|
||||||
"return_dict": return_dict,
|
"return_dict": return_dict,
|
||||||
"training": training,
|
"training": training,
|
||||||
"kwargs_call": kwargs_encoder,
|
"kwargs_call": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add arguments to encoder from `kwargs_encoder`
|
# Add arguments to encoder from `kwargs_encoder`
|
||||||
for k, v in kwargs_encoder.items():
|
for k, v in kwargs_encoder.items():
|
||||||
encoder_processing_inputs[k] = v
|
encoder_processing_inputs[k] = v
|
||||||
kwargs_encoder = {}
|
|
||||||
|
|
||||||
encoder_inputs = input_processing(**encoder_processing_inputs)
|
encoder_inputs = input_processing(**encoder_processing_inputs)
|
||||||
|
|
||||||
@@ -622,13 +621,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"return_dict": return_dict,
|
"return_dict": return_dict,
|
||||||
"training": training,
|
"training": training,
|
||||||
"kwargs_call": kwargs_decoder,
|
"kwargs_call": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add arguments to decoder from `kwargs_decoder`
|
# Add arguments to decoder from `kwargs_decoder`
|
||||||
for k, v in kwargs_decoder.items():
|
for k, v in kwargs_decoder.items():
|
||||||
decoder_processing_inputs[k] = v
|
decoder_processing_inputs[k] = v
|
||||||
kwargs_decoder = {}
|
|
||||||
|
|
||||||
decoder_inputs = input_processing(**decoder_processing_inputs)
|
decoder_inputs = input_processing(**decoder_processing_inputs)
|
||||||
decoder_outputs = self.decoder(**decoder_inputs)
|
decoder_outputs = self.decoder(**decoder_inputs)
|
||||||
|
|||||||
@@ -593,12 +593,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
"output_hidden_states": output_hidden_states,
|
"output_hidden_states": output_hidden_states,
|
||||||
"return_dict": return_dict,
|
"return_dict": return_dict,
|
||||||
"training": training,
|
"training": training,
|
||||||
"kwargs_call": kwargs_encoder,
|
"kwargs_call": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add arguments to encoder from `kwargs_encoder`
|
# Add arguments to encoder from `kwargs_encoder`
|
||||||
encoder_processing_inputs.update(kwargs_encoder)
|
encoder_processing_inputs.update(kwargs_encoder)
|
||||||
kwargs_encoder = {}
|
|
||||||
|
|
||||||
encoder_inputs = input_processing(**encoder_processing_inputs)
|
encoder_inputs = input_processing(**encoder_processing_inputs)
|
||||||
|
|
||||||
@@ -654,12 +653,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"return_dict": return_dict,
|
"return_dict": return_dict,
|
||||||
"training": training,
|
"training": training,
|
||||||
"kwargs_call": kwargs_decoder,
|
"kwargs_call": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add arguments to decoder from `kwargs_decoder`
|
# Add arguments to decoder from `kwargs_decoder`
|
||||||
decoder_processing_inputs.update(kwargs_decoder)
|
decoder_processing_inputs.update(kwargs_decoder)
|
||||||
kwargs_decoder = {}
|
|
||||||
|
|
||||||
decoder_inputs = input_processing(**decoder_processing_inputs)
|
decoder_inputs = input_processing(**decoder_processing_inputs)
|
||||||
decoder_outputs = self.decoder(**decoder_inputs)
|
decoder_outputs = self.decoder(**decoder_inputs)
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ class TFEncoderDecoderMixin:
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -122,6 +123,7 @@ class TFEncoderDecoderMixin:
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
@@ -137,6 +139,7 @@ class TFEncoderDecoderMixin:
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -167,6 +170,7 @@ class TFEncoderDecoderMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -195,6 +199,7 @@ class TFEncoderDecoderMixin:
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
out_2 = np.array(outputs[0])
|
out_2 = np.array(outputs[0])
|
||||||
out_2[np.isnan(out_2)] = 0
|
out_2[np.isnan(out_2)] = 0
|
||||||
@@ -208,6 +213,7 @@ class TFEncoderDecoderMixin:
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
out_1 = np.array(after_outputs[0])
|
out_1 = np.array(after_outputs[0])
|
||||||
out_1[np.isnan(out_1)] = 0
|
out_1[np.isnan(out_1)] = 0
|
||||||
@@ -235,6 +241,7 @@ class TFEncoderDecoderMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure `loss` exist
|
# Make sure `loss` exist
|
||||||
@@ -269,6 +276,7 @@ class TFEncoderDecoderMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -124,6 +125,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
@@ -137,6 +139,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -164,6 +167,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -189,6 +193,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
out_2 = np.array(outputs[0])
|
out_2 = np.array(outputs[0])
|
||||||
out_2[np.isnan(out_2)] = 0
|
out_2[np.isnan(out_2)] = 0
|
||||||
@@ -201,6 +206,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
out_1 = np.array(after_outputs[0])
|
out_1 = np.array(after_outputs[0])
|
||||||
out_1[np.isnan(out_1)] = 0
|
out_1[np.isnan(out_1)] = 0
|
||||||
@@ -226,6 +232,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure `loss` exist
|
# Make sure `loss` exist
|
||||||
@@ -257,6 +264,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||||
|
|||||||
Reference in New Issue
Block a user