TF: unpack_inputs decorator independent from main_input_name (#18110)
This commit is contained in:
@@ -404,9 +404,7 @@ def unpack_inputs(func):
|
|||||||
fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
|
fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
|
||||||
|
|
||||||
# process the inputs and call the wrapped function
|
# process the inputs and call the wrapped function
|
||||||
main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
|
unpacked_inputs = input_processing(func, self.config, **fn_args_and_kwargs)
|
||||||
main_input = fn_args_and_kwargs.pop(main_input_name, None)
|
|
||||||
unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
|
|
||||||
return func(self, **unpacked_inputs)
|
return func(self, **unpacked_inputs)
|
||||||
|
|
||||||
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
|
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
|
||||||
@@ -417,7 +415,7 @@ def unpack_inputs(func):
|
|||||||
return run_call_with_unpacked_inputs
|
return run_call_with_unpacked_inputs
|
||||||
|
|
||||||
|
|
||||||
def input_processing(func, config, input_ids, **kwargs):
|
def input_processing(func, config, **kwargs):
|
||||||
"""
|
"""
|
||||||
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
|
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
|
||||||
has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
|
has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
|
||||||
@@ -438,6 +436,8 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
has_kwargs = bool(signature.pop("kwargs", None))
|
has_kwargs = bool(signature.pop("kwargs", None))
|
||||||
signature.pop("self", None)
|
signature.pop("self", None)
|
||||||
parameter_names = list(signature.keys())
|
parameter_names = list(signature.keys())
|
||||||
|
main_input_name = parameter_names[0]
|
||||||
|
main_input = kwargs.pop(main_input_name, None)
|
||||||
output = {}
|
output = {}
|
||||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
|
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
|
||||||
|
|
||||||
@@ -483,8 +483,8 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||||
|
|
||||||
if isinstance(input_ids, (tuple, list)):
|
if isinstance(main_input, (tuple, list)):
|
||||||
for i, input in enumerate(input_ids):
|
for i, input in enumerate(main_input):
|
||||||
# EagerTensors don't allow to use the .name property so we check for a real Tensor
|
# EagerTensors don't allow to use the .name property so we check for a real Tensor
|
||||||
if type(input) == tf.Tensor:
|
if type(input) == tf.Tensor:
|
||||||
# Tensor names have always the pattern `name:id` then we check only the
|
# Tensor names have always the pattern `name:id` then we check only the
|
||||||
@@ -502,25 +502,25 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
|
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
|
||||||
f" {parameter_names[i]}."
|
f" {parameter_names[i]}."
|
||||||
)
|
)
|
||||||
elif isinstance(input_ids, Mapping):
|
elif isinstance(main_input, Mapping):
|
||||||
if "inputs" in input_ids:
|
if "inputs" in main_input:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
|
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
|
||||||
" instead.",
|
" instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
output["input_ids"] = input_ids.pop("inputs")
|
output["input_ids"] = main_input.pop("inputs")
|
||||||
|
|
||||||
if "decoder_cached_states" in input_ids:
|
if "decoder_cached_states" in main_input:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
|
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
|
||||||
" `past_key_values` instead.",
|
" `past_key_values` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
output["past_key_values"] = input_ids.pop("decoder_cached_states")
|
output["past_key_values"] = main_input.pop("decoder_cached_states")
|
||||||
|
|
||||||
for k, v in dict(input_ids).items():
|
for k, v in dict(main_input).items():
|
||||||
if isinstance(v, allowed_types) or v is None:
|
if isinstance(v, allowed_types) or v is None:
|
||||||
output[k] = v
|
output[k] = v
|
||||||
elif k not in parameter_names and "args" not in parameter_names:
|
elif k not in parameter_names and "args" not in parameter_names:
|
||||||
@@ -531,12 +531,12 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||||
else:
|
else:
|
||||||
if isinstance(input_ids, (tf.Tensor, KerasTensor)) or input_ids is None:
|
if isinstance(main_input, (tf.Tensor, KerasTensor)) or main_input is None:
|
||||||
output[parameter_names[0]] = input_ids
|
output[main_input_name] = main_input
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for"
|
f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for"
|
||||||
f" {parameter_names[0]}."
|
f" {main_input_name}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Populates any unspecified argument with their default value, according to the signature.
|
# Populates any unspecified argument with their default value, according to the signature.
|
||||||
|
|||||||
@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
|
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
|
||||||
self.config = PretrainedConfig(**config_kwargs)
|
self.config = PretrainedConfig(**config_kwargs)
|
||||||
|
self.main_input_name = "input_ids"
|
||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
):
|
):
|
||||||
return input_ids, past, output_attentions, output_hidden_states, return_dict
|
return input_ids, past, output_attentions, output_hidden_states, return_dict
|
||||||
|
|
||||||
|
@unpack_inputs
|
||||||
|
def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
|
||||||
|
return pixel_values, output_attentions, output_hidden_states, return_dict
|
||||||
|
|
||||||
dummy_model = DummyModel()
|
dummy_model = DummyModel()
|
||||||
input_ids = tf.constant([0, 1, 2, 3])
|
input_ids = tf.constant([0, 1, 2, 3])
|
||||||
past = tf.constant([4, 5, 6, 7])
|
past = tf.constant([4, 5, 6, 7])
|
||||||
|
pixel_values = tf.constant([8, 9, 10, 11])
|
||||||
|
|
||||||
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
||||||
output = dummy_model.call(input_ids=input_ids, past=past)
|
output = dummy_model.call(input_ids=input_ids, past=past)
|
||||||
@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
self.assertFalse(output[3])
|
self.assertFalse(output[3])
|
||||||
self.assertFalse(output[4])
|
self.assertFalse(output[4])
|
||||||
|
|
||||||
|
# test case 7: the decorator is independent from `main_input_name` -- it treats the first argument of the
|
||||||
|
# decorated function as its main input.
|
||||||
|
output = dummy_model.foo(pixel_values=pixel_values)
|
||||||
|
tf.debugging.assert_equal(output[0], pixel_values)
|
||||||
|
self.assertFalse(output[1])
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
|
||||||
# Tests whether the stable softmax is stable on CPU, with and without XLA
|
# Tests whether the stable softmax is stable on CPU, with and without XLA
|
||||||
def test_xla_stable_softmax(self):
|
def test_xla_stable_softmax(self):
|
||||||
large_penalty = -1e9
|
large_penalty = -1e9
|
||||||
|
|||||||
Reference in New Issue
Block a user