[EncoderDecoder] Make tests more aggressive (#9256)
* add tests * make style and fix bart bug * fix bart past key value edge case * correct tf bart test * fix gpt2 tf * fix t5 test
This commit is contained in:
committed by
GitHub
parent
ec07da65e2
commit
e9d77ccd5a
@@ -201,19 +201,24 @@ class TFT5ModelTester:
|
||||
model = TFT5Model(config=config).get_decoder()
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = attention_mask[:1, :]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
|
||||
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
|
||||
output_from_past = model(
|
||||
next_tokens, attention_mask=next_attention_mask, past_key_values=outputs.past_key_values
|
||||
)[0]
|
||||
|
||||
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user