Use random_attention_mask for TF tests (#16517)

* use random_attention_mask for TF tests

* Fix for TFCLIP test (for now).

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-04-01 16:53:07 +02:00
committed by GitHub
parent 823dbf8a41
commit 2199382dfd
29 changed files with 62 additions and 57 deletions

View File

@@ -1440,7 +1440,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
def random_attention_mask(shape, rng=None, name=None, dtype=None):
attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None, dtype=dtype)
# make sure that at least one token is attended to for each batch
attn_mask = tf.concat([tf.constant(value=1, shape=(shape[0], 1), dtype=dtype), attn_mask[:, 1:]], axis=1)
attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1)
return attn_mask