update model to use past

This commit is contained in:
thomwolf
2019-10-08 17:11:58 +02:00
parent bd5363cc83
commit 3edfa1d6aa
2 changed files with 19 additions and 11 deletions

View File

@@ -144,14 +144,16 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
model(input_ids, token_type_ids=token_type_ids)
sequence_output, _ = model(input_ids)
sequence_output, presents = model(input_ids)
result = {
"sequence_output": sequence_output,
"presents": presents,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertEqual(len(result["presents"]), config.n_layer)
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = CTRLLMHeadModel(config)