[T5, generation] Add decoder caching for T5 (#3682)

* initial commit to add decoder caching for T5

* better naming for caching

* finish T5 decoder caching

* correct test

* added extensive past testing for T5

* clean files

* make tests cleaner

* improve docstring

* improve docstring

* better reorder cache

* make style

* Update src/transformers/modeling_t5.py

Co-Authored-By: Yacine Jernite <yjernite@users.noreply.github.com>

* make set output past work for all layers

* improve docstring

* improve docstring

Co-authored-by: Yacine Jernite <yjernite@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-04-10 01:02:50 +02:00
committed by GitHub
parent 9384e5f6de
commit ce2298fb5f
4 changed files with 386 additions and 82 deletions

View File

@@ -1417,17 +1417,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):