[T5, generation] Add decoder caching for T5 (#3682)
* initial commit to add decoder caching for T5
* better naming for caching
* finish T5 decoder caching
* correct test
* added extensive past testing for T5
* clean files
* make tests cleaner
* improve docstring
* improve docstring
* better reorder cache
* make style
* Update src/transformers/modeling_t5.py
  Co-Authored-By: Yacine Jernite <yjernite@users.noreply.github.com>
* make set output past work for all layers
* improve docstring
* improve docstring

Co-authored-by: Yacine Jernite <yjernite@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9384e5f6de
commit
ce2298fb5f
@@ -1417,17 +1417,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
@staticmethod
def _reorder_cache(past, beam_idx):
    """Reorder each layer's cached state so it follows the selected beams.

    Args:
        past: Iterable of per-layer cache tensors. Each tensor is indexed
            along dimension 1, which the original comments describe as the
            batch dimension of `past` and `mems`.
        beam_idx: 1-D integer index tensor naming, for every output slot,
            which entry along dimension 1 to copy. It must be a tensor
            accepted by `Tensor.index_select`.

    Returns:
        A tuple with one tensor per layer, where
        `result[k][:, j] == past[k][:, beam_idx[j]]`.
    """
    # One index_select per layer replaces the per-beam slice / unsqueeze /
    # cat loop; the gathered tensor is newly allocated, not a view of the
    # input cache.
    return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
|
||||
|
||||
|
||||
def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
|
||||
|
||||
Reference in New Issue
Block a user