[Tests] fix attention masks in Tests (#6621)
* fix distilbert * fix typo
This commit is contained in:
committed by
GitHub
parent
c9454507cf
commit
505f2d749e
@@ -704,9 +704,6 @@ class ModelTesterMixin:
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
elif torch.isinf(tuple_object).any() and torch.isinf(dict_object).any():
|
||||
# TODO: (Lysandre) - maybe take a look if that's ok here
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
torch.allclose(tuple_object, dict_object, atol=1e-5),
|
||||
@@ -937,6 +934,13 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
|
||||
return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()
|
||||
|
||||
|
||||
def random_attention_mask(shape, rng=None, name=None):
|
||||
attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None)
|
||||
# make sure that at least one token is attended to for each batch
|
||||
attn_mask[:, -1] = 1
|
||||
return attn_mask
|
||||
|
||||
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
|
||||
Reference in New Issue
Block a user