Add more models to common tests (#4910)
This commit is contained in:
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
XLNetConfig,
|
||||
XLNetModel,
|
||||
XLNetLMHeadModel,
|
||||
XLNetForMultipleChoice,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetForTokenClassification,
|
||||
XLNetForQuestionAnswering,
|
||||
@@ -48,6 +49,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
XLNetForTokenClassification,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetForMultipleChoice,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@@ -84,6 +86,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
pad_token_id=5,
|
||||
num_choices=4,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -110,6 +113,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.bos_token_id = bos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.num_choices = num_choices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
Reference in New Issue
Block a user