[T5, generation] Add decoder caching for T5 (#3682)
* initial commit to add decoder caching for T5
* better naming for caching
* finish T5 decoder caching
* correct test
* added extensive past testing for T5
* clean files
* make tests cleaner
* improve docstring
* improve docstring
* better reorder cache
* make style
* Update src/transformers/modeling_t5.py

  Co-Authored-By: Yacine Jernite <yjernite@users.noreply.github.com>
* make set output past work for all layers
* improve docstring
* improve docstring

Co-authored-by: Yacine Jernite <yjernite@users.noreply.github.com>
This commit is contained in:
committed by GitHub
parent 9384e5f6de
commit ce2298fb5f
@@ -128,6 +128,7 @@ class ModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = False
|
||||
config.output_past = False
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -144,10 +145,9 @@ class ModelTesterMixin:
|
||||
out_len = len(outputs)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
correct_outlen = (
|
||||
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
|
||||
)
|
||||
correct_outlen = 4
|
||||
decoder_attention_idx = 1
|
||||
|
||||
if "lm_labels" in inputs_dict: # loss will come first
|
||||
correct_outlen += 1 # compute loss
|
||||
decoder_attention_idx += 1
|
||||
|
||||
Reference in New Issue
Block a user