TF - Fix interchangeable past/past_key_values and revert output variable name in GPT2 (#16332)
* revert tf gpt2
* add test for unpack_inputs and fix test case
* add changes to vision encoder decoder
This commit is contained in:
@@ -372,7 +372,7 @@ def unpack_inputs(func):
|
|||||||
|
|
||||||
# process the inputs and call the wrapped function
|
# process the inputs and call the wrapped function
|
||||||
main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
|
main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
|
||||||
main_input = fn_args_and_kwargs.pop(main_input_name)
|
main_input = fn_args_and_kwargs.pop(main_input_name, None)
|
||||||
unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
|
unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
|
||||||
return func(self, **unpacked_inputs)
|
return func(self, **unpacked_inputs)
|
||||||
|
|
||||||
@@ -423,13 +423,13 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
)
|
)
|
||||||
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
|
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
|
||||||
|
|
||||||
if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
|
if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
|
kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
|
||||||
elif "past_key_values" in kwargs["kwargs_call"] and "past" in kwargs:
|
elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
|
||||||
kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
|
kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
|
||||||
|
|
||||||
if len(kwargs["kwargs_call"]) > 0:
|
if len(kwargs["kwargs_call"]) > 0:
|
||||||
@@ -497,6 +497,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
|
f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Populates any unspecified argument with their default value, according to the signature.
|
||||||
for name in parameter_names:
|
for name in parameter_names:
|
||||||
if name not in list(output.keys()) and name != "args":
|
if name not in list(output.keys()) and name != "args":
|
||||||
output[name] = kwargs.pop(name, signature[name].default)
|
output[name] = kwargs.pop(name, signature[name].default)
|
||||||
|
|||||||
@@ -694,6 +694,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
):
|
):
|
||||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
||||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||||
|
past_key_values = decoder_inputs.get("past_key_values")
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = decoder_inputs.get("past") # e.g. on TF GPT2
|
||||||
input_dict = {
|
input_dict = {
|
||||||
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
|
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
@@ -701,7 +704,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||||
# TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
|
# TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
|
||||||
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
|
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
|
||||||
"past_key_values": decoder_inputs["past_key_values"],
|
"past_key_values": past_key_values,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
return input_dict
|
return input_dict
|
||||||
|
|||||||
@@ -878,7 +878,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
"input_ids": inputs,
|
"input_ids": inputs,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"past_key_values": past,
|
"past": past,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -725,6 +725,9 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
):
|
):
|
||||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
||||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||||
|
past_key_values = decoder_inputs.get("past_key_values")
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = decoder_inputs.get("past") # e.g. on TF GPT2
|
||||||
input_dict = {
|
input_dict = {
|
||||||
"pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy
|
"pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
@@ -732,7 +735,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||||
# TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
|
# TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
|
||||||
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
|
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
|
||||||
"past_key_values": decoder_inputs["past_key_values"],
|
"past_key_values": past_key_values,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
return input_dict
|
return input_dict
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from typing import List, Tuple
|
|||||||
from huggingface_hub import delete_repo, login
|
from huggingface_hub import delete_repo, login
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import tooslow # noqa: F401
|
from transformers.testing_utils import tooslow # noqa: F401
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -80,6 +81,7 @@ if is_tf_available():
|
|||||||
TFSampleDecoderOnlyOutput,
|
TFSampleDecoderOnlyOutput,
|
||||||
TFSampleEncoderDecoderOutput,
|
TFSampleEncoderDecoderOutput,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_tf_utils import unpack_inputs
|
||||||
|
|
||||||
if _tf_gpu_memory_limit is not None:
|
if _tf_gpu_memory_limit is not None:
|
||||||
gpus = tf.config.list_physical_devices("GPU")
|
gpus = tf.config.list_physical_devices("GPU")
|
||||||
@@ -1553,6 +1555,68 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
|
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
|
||||||
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
|
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
|
||||||
|
|
||||||
|
# tests whether the unpack_inputs function behaves as expected
|
||||||
|
def test_unpack_inputs(self):
|
||||||
|
class DummyModel:
|
||||||
|
def __init__(self):
|
||||||
|
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
|
||||||
|
self.config = PretrainedConfig(**config_kwargs)
|
||||||
|
|
||||||
|
@unpack_inputs
|
||||||
|
def call(
|
||||||
|
self, input_ids=None, past=None, output_attentions=None, output_hidden_states=None, return_dict=None
|
||||||
|
):
|
||||||
|
return input_ids, past, output_attentions, output_hidden_states, return_dict
|
||||||
|
|
||||||
|
dummy_model = DummyModel()
|
||||||
|
input_ids = tf.constant([0, 1, 2, 3])
|
||||||
|
past = tf.constant([4, 5, 6, 7])
|
||||||
|
|
||||||
|
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
||||||
|
output = dummy_model.call(input_ids=input_ids, past=past)
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertFalse(output[4])
|
||||||
|
|
||||||
|
# test case 2: Same as above, but with positional arguments.
|
||||||
|
output = dummy_model.call(input_ids, past)
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertFalse(output[4])
|
||||||
|
|
||||||
|
# test case 3: We can also pack everything in the first input.
|
||||||
|
output = dummy_model.call(input_ids={"input_ids": input_ids, "past": past})
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertFalse(output[4])
|
||||||
|
|
||||||
|
# test case 4: Explicit boolean arguments should override the config.
|
||||||
|
output = dummy_model.call(input_ids=input_ids, past=past, output_attentions=False, return_dict=True)
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertTrue(output[4])
|
||||||
|
|
||||||
|
# test case 5: Unexpected arguments should raise an exception.
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
output = dummy_model.call(input_ids=input_ids, past=past, foo="bar")
|
||||||
|
|
||||||
|
# test case 6: Despite the above, `past_key_values` should be interchangeable with `past`
|
||||||
|
# (the decorator moves it to `past`, or vice-versa, depending on the signature).
|
||||||
|
output = dummy_model.call(input_ids=input_ids, past_key_values=past)
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertFalse(output[4])
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user