Use random_attention_mask for TF tests (#16517)
* use random_attention_mask for TF tests * Fix for TFCLIP test (for now). Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1440,7 +1440,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||
def random_attention_mask(shape, rng=None, name=None, dtype=None):
|
||||
attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None, dtype=dtype)
|
||||
# make sure that at least one token is attended to for each batch
|
||||
attn_mask = tf.concat([tf.constant(value=1, shape=(shape[0], 1), dtype=dtype), attn_mask[:, 1:]], axis=1)
|
||||
attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1)
|
||||
return attn_mask
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user