This commit is contained in:
Lysandre
2020-02-03 18:28:32 -05:00
committed by Lysandre Debut
parent 950c6a4f09
commit 5f96ebc0be
2 changed files with 77 additions and 77 deletions

View File

@@ -185,7 +185,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
self.check_loss_output(result)
def create_and_check_roberta_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = RobertaForMultipleChoice(config=config)
@@ -208,7 +208,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
self.check_loss_output(result)
def create_and_check_roberta_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = RobertaForQuestionAnswering(config=config)
model.to(torch_device)