[TFT5, Cache] Add cache to TFT5 (#3772)
* correct gpt2 test inputs * make style * delete modeling_gpt2 change in test file * translate from pytorch * correct tests * fix conflicts * fix conflicts * fix conflicts * fix conflicts * make tensorflow t5 caching work * make style * clean reorder cache * remove unnecessary spaces * fix test
This commit is contained in:
committed by
GitHub
parent
a5b249472e
commit
38f7461df3
@@ -1299,17 +1299,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = []
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
|
||||
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
|
||||
# check that shape matches
|
||||
assert shape_list(reordered_layer_past) == shape_list(layer_past)
|
||||
reordered_past.append(reordered_layer_past)
|
||||
past = tuple(reordered_past)
|
||||
return past
|
||||
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
|
||||
|
||||
|
||||
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
|
||||
|
||||
Reference in New Issue
Block a user