[Tests] fix attention masks in Tests (#6621)

* fix distilbert

* fix typo
This commit is contained in:
Patrick von Platen
2020-08-20 19:23:47 +02:00
committed by GitHub
parent c9454507cf
commit 505f2d749e
15 changed files with 35 additions and 31 deletions

View File

@@ -704,9 +704,6 @@ class ModelTesterMixin:
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
elif torch.isinf(tuple_object).any() and torch.isinf(dict_object).any():
# TODO: (Lysandre) - maybe take a look if that's ok here
return
else:
self.assertTrue(
torch.allclose(tuple_object, dict_object, atol=1e-5),
@@ -937,6 +934,13 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()
def random_attention_mask(shape, rng=None, name=None):
attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None)
# make sure that at least one token is attended to for each batch
attn_mask[:, -1] = 1
return attn_mask
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None: