TF - Fix interchangeable past/past_key_values and revert output variable name in GPT2 (#16332)
* revert tf gpt2 * add test for unpack_inputs and fix test case * add changes to vision encoder decoder
This commit is contained in:
@@ -27,6 +27,7 @@ from typing import List, Tuple
|
||||
from huggingface_hub import delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_tf_available
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import tooslow # noqa: F401
|
||||
from transformers.testing_utils import (
|
||||
@@ -80,6 +81,7 @@ if is_tf_available():
|
||||
TFSampleDecoderOnlyOutput,
|
||||
TFSampleEncoderDecoderOutput,
|
||||
)
|
||||
from transformers.modeling_tf_utils import unpack_inputs
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
@@ -1553,6 +1555,68 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
|
||||
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
|
||||
|
||||
# tests whether the unpack_inputs function behaves as expected
|
||||
def test_unpack_inputs(self):
|
||||
class DummyModel:
|
||||
def __init__(self):
|
||||
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
|
||||
self.config = PretrainedConfig(**config_kwargs)
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self, input_ids=None, past=None, output_attentions=None, output_hidden_states=None, return_dict=None
|
||||
):
|
||||
return input_ids, past, output_attentions, output_hidden_states, return_dict
|
||||
|
||||
dummy_model = DummyModel()
|
||||
input_ids = tf.constant([0, 1, 2, 3])
|
||||
past = tf.constant([4, 5, 6, 7])
|
||||
|
||||
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
||||
output = dummy_model.call(input_ids=input_ids, past=past)
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
# test case 2: Same as above, but with positional arguments.
|
||||
output = dummy_model.call(input_ids, past)
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
# test case 3: We can also pack everything in the first input.
|
||||
output = dummy_model.call(input_ids={"input_ids": input_ids, "past": past})
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
# test case 4: Explicit boolean arguments should override the config.
|
||||
output = dummy_model.call(input_ids=input_ids, past=past, output_attentions=False, return_dict=True)
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertTrue(output[4])
|
||||
|
||||
# test case 5: Unexpected arguments should raise an exception.
|
||||
with self.assertRaises(ValueError):
|
||||
output = dummy_model.call(input_ids=input_ids, past=past, foo="bar")
|
||||
|
||||
# test case 6: Despite the above, `past_key_values` should be interchangeable with `past`
|
||||
# (the decorator moves it to `past`, or vice-versa, depending on the signature).
|
||||
output = dummy_model.call(input_ids=input_ids, past_key_values=past)
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user