rename Bert2Bert -> Bert2Rnd

This commit is contained in:
Rémi Louf
2019-10-08 16:07:25 +02:00
parent 82628b0fc9
commit 8abfee9ec3
3 changed files with 7 additions and 6 deletions

View File

@@ -29,7 +29,7 @@ if is_torch_available():
from transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice, Bert2Bert)
BertForTokenClassification, BertForMultipleChoice, Bert2Rnd)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
@@ -257,7 +257,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
def create_and_check_bert2bert(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_choices = self.num_choices
model = Bert2Bert(config=config)
model = Bert2Rnd(config=config)
model.eval()
bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()