[Tests] fix attention masks in Tests (#6621)
* fix distilbert * fix typo
This commit is contained in:
committed by
GitHub
parent
c9454507cf
commit
505f2d749e
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -71,7 +71,7 @@ class AlbertModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -93,7 +93,7 @@ class BertModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -704,9 +704,6 @@ class ModelTesterMixin:
|
|||||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
elif tuple_object is None:
|
elif tuple_object is None:
|
||||||
return
|
return
|
||||||
elif torch.isinf(tuple_object).any() and torch.isinf(dict_object).any():
|
|
||||||
# TODO: (Lysandre) - maybe take a look if that's ok here
|
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.allclose(tuple_object, dict_object, atol=1e-5),
|
torch.allclose(tuple_object, dict_object, atol=1e-5),
|
||||||
@@ -937,6 +934,13 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
|
|||||||
return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()
|
return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def random_attention_mask(shape, rng=None, name=None):
|
||||||
|
attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None)
|
||||||
|
# make sure that at least one token is attended to for each batch
|
||||||
|
attn_mask[:, -1] = 1
|
||||||
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||||
"""Creates a random float32 tensor"""
|
"""Creates a random float32 tensor"""
|
||||||
if rng is None:
|
if rng is None:
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -60,7 +60,7 @@ class CTRLModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -89,7 +89,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
sequence_labels = None
|
sequence_labels = None
|
||||||
token_labels = None
|
token_labels = None
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -88,7 +88,7 @@ class DPRModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -69,7 +69,7 @@ class ElectraModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -72,7 +72,7 @@ class FlaubertModelTester(object):
|
|||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
input_lengths = None
|
input_lengths = None
|
||||||
if self.use_input_lengths:
|
if self.use_input_lengths:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -92,7 +92,7 @@ class GPT2ModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -82,7 +82,7 @@ class LongformerModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -94,7 +94,7 @@ class MobileBertModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
|
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -133,7 +133,7 @@ class ReformerModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
choice_labels = None
|
choice_labels = None
|
||||||
if self.use_labels:
|
if self.use_labels:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -71,7 +71,7 @@ class RobertaModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -73,7 +73,7 @@ class XLMModelTester:
|
|||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
input_lengths = None
|
input_lengths = None
|
||||||
if self.use_input_lengths:
|
if self.use_input_lengths:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -100,7 +100,7 @@ class XLNetModelTester:
|
|||||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
||||||
perm_mask = torch.zeros(
|
perm_mask = torch.zeros(
|
||||||
|
|||||||
Reference in New Issue
Block a user