remove Bert2Rnd test
This commit is contained in:
@@ -29,7 +29,7 @@ if is_torch_available():
|
|||||||
from transformers import (BertConfig, BertModel, BertForMaskedLM,
|
from transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||||
BertForNextSentencePrediction, BertForPreTraining,
|
BertForNextSentencePrediction, BertForPreTraining,
|
||||||
BertForQuestionAnswering, BertForSequenceClassification,
|
BertForQuestionAnswering, BertForSequenceClassification,
|
||||||
BertForTokenClassification, BertForMultipleChoice, Bert2Rnd)
|
BertForTokenClassification, BertForMultipleChoice)
|
||||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
else:
|
else:
|
||||||
pytestmark = pytest.mark.skip("Require Torch")
|
pytestmark = pytest.mark.skip("Require Torch")
|
||||||
@@ -255,17 +255,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, self.num_choices])
|
[self.batch_size, self.num_choices])
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
def create_and_check_bert2rnd(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
|
||||||
config.num_choices = self.num_choices
|
|
||||||
model = Bert2Rnd(config=config)
|
|
||||||
model.eval()
|
|
||||||
bert2rnd_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
|
||||||
bert2rnd_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
|
||||||
bert2rnd_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
|
||||||
_ = model(bert2rnd_inputs_ids,
|
|
||||||
attention_mask=bert2rnd_input_mask,
|
|
||||||
token_type_ids=bert2rnd_token_type_ids)
|
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(config, input_ids, token_type_ids, input_mask,
|
(config, input_ids, token_type_ids, input_mask,
|
||||||
|
|||||||
Reference in New Issue
Block a user