[Use cache] Align logic of use_cache with output_attentions and output_hidden_states (#5194)
* fix use cache * add bart use cache * fix bart * finish bart
This commit is contained in:
committed by
GitHub
parent
64c393ee74
commit
c2a26ec8a6
@@ -153,6 +153,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def test_advanced_inputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
inputs_dict["input_ids"][:, -2:] = config.pad_token_id
|
||||
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
|
||||
config, inputs_dict["input_ids"]
|
||||
|
||||
@@ -168,7 +168,14 @@ class GPT2ModelTester:
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
output, past = model(input_ids, token_type_ids=token_type_ids)
|
||||
outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
|
||||
outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
|
||||
outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past = outputs
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
@@ -193,7 +193,14 @@ class T5ModelTester:
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
output, past_key_value_states = model(input_ids, use_cache=True)
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
outputs_use_cache_conf = model(input_ids)
|
||||
outputs_no_past = model(input_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past_key_value_states = outputs
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
@@ -126,6 +126,7 @@ class TFModelTesterMixin:
|
||||
if "T5" in main_layer_class.__name__:
|
||||
# Take the same values than in TFT5ModelTester for this shared layer
|
||||
shared = TFSharedEmbeddings(99, 32, name="shared")
|
||||
config.use_cache = False
|
||||
main_layer = main_layer_class(config, embed_tokens=shared)
|
||||
else:
|
||||
main_layer = main_layer_class(config)
|
||||
|
||||
@@ -143,7 +143,14 @@ class TFGPT2ModelTester:
|
||||
model = TFGPT2Model(config=config)
|
||||
|
||||
# first forward pass
|
||||
output, past = model(input_ids, token_type_ids=token_type_ids)
|
||||
outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
|
||||
outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
|
||||
outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past = outputs
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
@@ -135,7 +135,15 @@ class TFT5ModelTester:
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
_, past_key_value_states = model(input_ids, use_cache=True)
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
|
||||
outputs_use_cache_conf = model(input_ids)
|
||||
outputs_no_past = model(input_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past_key_value_states = outputs
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
Reference in New Issue
Block a user