Add more models to common tests (#4910)
This commit is contained in:
@@ -29,10 +29,12 @@ if is_torch_available():
|
||||
RobertaConfig,
|
||||
RobertaModel,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForMultipleChoice,
|
||||
RobertaForQuestionAnswering,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaForTokenClassification,
|
||||
)
|
||||
from transformers.modeling_roberta import RobertaEmbeddings, RobertaForMultipleChoice, RobertaForQuestionAnswering
|
||||
from transformers.modeling_roberta import RobertaEmbeddings
|
||||
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_utils import create_position_ids_from_input_ids
|
||||
|
||||
@@ -40,7 +42,18 @@ if is_torch_available():
|
||||
@require_torch
|
||||
class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
RobertaForMaskedLM,
|
||||
RobertaModel,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaForTokenClassification,
|
||||
RobertaForMultipleChoice,
|
||||
RobertaForQuestionAnswering,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
class RobertaModelTester(object):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user