TF - Fix interchangeable past/past_key_values and revert output variable name in GPT2 (#16332)

* revert tf gpt2

* add test for unpack_inputs and fix test case

* add changes to vision encoder decoder
This commit is contained in:
Joao Gante
2022-03-23 18:41:18 +00:00
committed by GitHub
parent 12428f0ef1
commit 9e8c37dc82
5 changed files with 77 additions and 6 deletions

View File

@@ -27,6 +27,7 @@ from typing import List, Tuple
from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError
from transformers import is_tf_available
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import get_values
from transformers.testing_utils import tooslow # noqa: F401
from transformers.testing_utils import (
@@ -80,6 +81,7 @@ if is_tf_available():
TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput,
)
from transformers.modeling_tf_utils import unpack_inputs
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
@@ -1553,6 +1555,68 @@ class UtilsFunctionsTest(unittest.TestCase):
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
# tests whether the unpack_inputs function behaves as expected
def test_unpack_inputs(self):
class DummyModel:
def __init__(self):
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
self.config = PretrainedConfig(**config_kwargs)
@unpack_inputs
def call(
self, input_ids=None, past=None, output_attentions=None, output_hidden_states=None, return_dict=None
):
return input_ids, past, output_attentions, output_hidden_states, return_dict
dummy_model = DummyModel()
input_ids = tf.constant([0, 1, 2, 3])
past = tf.constant([4, 5, 6, 7])
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output = dummy_model.call(input_ids=input_ids, past=past)
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertFalse(output[4])
# test case 2: Same as above, but with positional arguments.
output = dummy_model.call(input_ids, past)
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertFalse(output[4])
# test case 3: We can also pack everything in the first input.
output = dummy_model.call(input_ids={"input_ids": input_ids, "past": past})
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertFalse(output[4])
# test case 4: Explicit boolean arguments should override the config.
output = dummy_model.call(input_ids=input_ids, past=past, output_attentions=False, return_dict=True)
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertTrue(output[4])
# test case 5: Unexpected arguments should raise an exception.
with self.assertRaises(ValueError):
output = dummy_model.call(input_ids=input_ids, past=past, foo="bar")
# test case 6: Despite the above, `past_key_values` should be interchangeable with `past`
# (the decorator moves it to `past`, or vice-versa, depending on the signature).
output = dummy_model.call(input_ids=input_ids, past_key_values=past)
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertFalse(output[4])
@require_tf
@is_staging_test