[TFT5, Cache] Add cache to TFT5 (#3772)

* correct gpt2 test inputs

* make style

* delete modeling_gpt2 change in test file

* translate from pytorch

* correct tests

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* make tensorflow t5 caching work

* make style

* clean reorder cache

* remove unnecessary spaces

* fix test
This commit is contained in:
Patrick von Platen
2020-04-16 16:14:52 +02:00
committed by GitHub
parent a5b249472e
commit 38f7461df3
6 changed files with 384 additions and 86 deletions

View File

@@ -1299,17 +1299,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
# check that shape matches
assert shape_list(reordered_layer_past) == shape_list(layer_past)
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):