[cleanup] Hoist ModelTester objects to top level (#4939)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Amil Khare
2020-06-16 17:33:43 +05:30
committed by GitHub
parent 0c55a384f8
commit c852036b4a
25 changed files with 4721 additions and 5212 deletions

View File

@@ -34,27 +34,6 @@ if is_torch_available():
DistilBertForSequenceClassification,
)
@require_torch
class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
DistilBertModel,
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DistilBertForTokenClassification,
)
if is_torch_available()
else None
)
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
class DistilBertModelTester(object):
def __init__(
self,
@@ -245,8 +224,29 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
DistilBertModel,
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DistilBertForTokenClassification,
)
if is_torch_available()
else None
)
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
def setUp(self):
self.model_tester = DistilBertModelTest.DistilBertModelTester(self)
self.model_tester = DistilBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=DistilBertConfig, dim=37)
def test_config(self):