[T5, generation] Add decoder caching for T5 (#3682)

* initial commit to add decoder caching for T5

* better naming for caching

* finish T5 decoder caching

* correct test

* added extensive past testing for T5

* clean files

* make tests cleaner

* improve docstring

* improve docstring

* better reorder cache

* make style

* Update src/transformers/modeling_t5.py

Co-Authored-By: Yacine Jernite <yjernite@users.noreply.github.com>

* make set output past work for all layers

* improve docstring

* improve docstring

Co-authored-by: Yacine Jernite <yjernite@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-04-10 01:02:50 +02:00
committed by GitHub
parent 9384e5f6de
commit ce2298fb5f
4 changed files with 386 additions and 82 deletions

View File

@@ -128,6 +128,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config.output_attentions = True
config.output_hidden_states = False
config.output_past = False
model = model_class(config)
model.to(torch_device)
model.eval()
@@ -144,10 +145,9 @@ class ModelTesterMixin:
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = (
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
)
correct_outlen = 4
decoder_attention_idx = 1
if "lm_labels" in inputs_dict: # loss will come first
correct_outlen += 1 # compute loss
decoder_attention_idx += 1