TF: unpack_inputs decorator independent from main_input_name (#18110)
This commit is contained in:
@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
def __init__(self):
|
||||
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
|
||||
self.config = PretrainedConfig(**config_kwargs)
|
||||
self.main_input_name = "input_ids"
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
):
|
||||
return input_ids, past, output_attentions, output_hidden_states, return_dict
|
||||
|
||||
@unpack_inputs
|
||||
def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
|
||||
return pixel_values, output_attentions, output_hidden_states, return_dict
|
||||
|
||||
dummy_model = DummyModel()
|
||||
input_ids = tf.constant([0, 1, 2, 3])
|
||||
past = tf.constant([4, 5, 6, 7])
|
||||
pixel_values = tf.constant([8, 9, 10, 11])
|
||||
|
||||
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
||||
output = dummy_model.call(input_ids=input_ids, past=past)
|
||||
@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
# test case 7: the decorator is independent from `main_input_name` -- it treats the first argument of the
|
||||
# decorated function as its main input.
|
||||
output = dummy_model.foo(pixel_values=pixel_values)
|
||||
tf.debugging.assert_equal(output[0], pixel_values)
|
||||
self.assertFalse(output[1])
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
|
||||
# Tests whether the stable softmax is stable on CPU, with and without XLA
|
||||
def test_xla_stable_softmax(self):
|
||||
large_penalty = -1e9
|
||||
|
||||
Reference in New Issue
Block a user