add test for initialization of Bert2Rnd

This commit is contained in:
Rémi Louf
2019-10-10 18:07:11 +02:00
parent fa218e648a
commit 1e68c28670
2 changed files with 55 additions and 6 deletions

View File

@@ -259,12 +259,12 @@ class BertModelTest(CommonTestCases.CommonModelTester):
config.num_choices = self.num_choices
model = Bert2Rnd(config=config)
model.eval()
bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
bert2bert_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
_ = model(bert2bert_inputs_ids,
attention_mask=bert2bert_input_mask,
token_type_ids=bert2bert_token_type_ids)
bert2rnd_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
bert2rnd_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
bert2rnd_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
_ = model(bert2rnd_inputs_ids,
attention_mask=bert2rnd_input_mask,
token_type_ids=bert2rnd_token_type_ids)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()