add test for initialization of Bert2Rnd
This commit is contained in:
@@ -259,12 +259,12 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
config.num_choices = self.num_choices
|
||||
model = Bert2Rnd(config=config)
|
||||
model.eval()
|
||||
bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
bert2bert_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
_ = model(bert2bert_inputs_ids,
|
||||
attention_mask=bert2bert_input_mask,
|
||||
token_type_ids=bert2bert_token_type_ids)
|
||||
bert2rnd_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
bert2rnd_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
bert2rnd_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
_ = model(bert2rnd_inputs_ids,
|
||||
attention_mask=bert2rnd_input_mask,
|
||||
token_type_ids=bert2rnd_token_type_ids)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user