Black 20 release
This commit is contained in:
@@ -217,7 +217,10 @@ class MobileBertModelTester:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
next_sentence_label=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
|
||||
|
||||
@@ -397,7 +400,11 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device,)
|
||||
return torch.tensor(
|
||||
tok_lst,
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
|
||||
TOLERANCE = 1e-3
|
||||
|
||||
Reference in New Issue
Block a user