[Use cache] Align logic of use_cache with output_attentions and output_hidden_states (#5194)

* fix use cache

* add bart use cache

* fix bart

* finish bart
This commit is contained in:
Patrick von Platen
2020-06-24 16:09:17 +02:00
committed by GitHub
parent 64c393ee74
commit c2a26ec8a6
13 changed files with 90 additions and 21 deletions

View File

@@ -168,7 +168,14 @@ class GPT2ModelTester:
model.eval()
# first forward pass
output, past = model(input_ids, token_type_ids=token_type_ids)
outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past = outputs
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)