[Tests] fix attention masks in Tests (#6621)
* fix distilbert * fix typo
This commit is contained in:
committed by
GitHub
parent
c9454507cf
commit
505f2d749e
@@ -21,7 +21,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -100,7 +100,7 @@ class XLNetModelTester:
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
||||
perm_mask = torch.zeros(
|
||||
|
||||
Reference in New Issue
Block a user