TF: BART compatible with XLA generation (#17479)
* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
This commit is contained in:
@@ -1716,7 +1716,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
|
||||
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
|
||||
|
||||
|
||||
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
|
||||
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user