From 33c01368b19701bc6e5ea886f108663752d31d86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 16 Oct 2019 18:13:05 +0200 Subject: [PATCH] remove Bert2Rnd test --- transformers/tests/modeling_bert_test.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/transformers/tests/modeling_bert_test.py b/transformers/tests/modeling_bert_test.py index e649cd8ce8..6c39c4e4db 100644 --- a/transformers/tests/modeling_bert_test.py +++ b/transformers/tests/modeling_bert_test.py @@ -29,7 +29,7 @@ if is_torch_available(): from transformers import (BertConfig, BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, - BertForTokenClassification, BertForMultipleChoice, Bert2Rnd) + BertForTokenClassification, BertForMultipleChoice) from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP else: pytestmark = pytest.mark.skip("Require Torch") @@ -255,17 +255,6 @@ class BertModelTest(CommonTestCases.CommonModelTester): [self.batch_size, self.num_choices]) self.check_loss_output(result) - def create_and_check_bert2rnd(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - config.num_choices = self.num_choices - model = Bert2Rnd(config=config) - model.eval() - bert2rnd_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - bert2rnd_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - bert2rnd_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - _ = model(bert2rnd_inputs_ids, - attention_mask=bert2rnd_input_mask, - token_type_ids=bert2rnd_token_type_ids) - def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, token_type_ids, input_mask,