TF: BART compatible with XLA generation (#17479)

* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
This commit is contained in:
Joao Gante
2022-06-20 11:07:46 +01:00
committed by GitHub
parent 6589e510fa
commit 132402d752
18 changed files with 421 additions and 86 deletions

View File

@@ -1716,7 +1716,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""