From 3f07cd419ce973b4ca8fd4e12fe664d08408b343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 30 Oct 2019 15:09:53 +0100 Subject: [PATCH] update test on Bert to include decoder mode --- transformers/tests/modeling_bert_test.py | 50 +++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/transformers/tests/modeling_bert_test.py b/transformers/tests/modeling_bert_test.py index 6c39c4e4db..67be910a7e 100644 --- a/transformers/tests/modeling_bert_test.py +++ b/transformers/tests/modeling_bert_test.py @@ -22,7 +22,7 @@ import pytest from transformers import is_torch_available -from .modeling_common_test import (CommonTestCases, ids_tensor) +from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor) from .configuration_common_test import ConfigTester if is_torch_available(): @@ -120,10 +120,20 @@ class BertModelTest(CommonTestCases.CommonModelTester): attention_probs_dropout_prob=self.attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, + is_decoder=False, initializer_range=self.initializer_range) return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + def prepare_config_and_inputs_for_decoder(self): + config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask + def check_loss_output(self, result): self.parent.assertListEqual( list(result["loss"].size()), @@ -145,6 +155,22 @@ class BertModelTest(CommonTestCases.CommonModelTester): [self.batch_size, self.seq_length, self.hidden_size]) self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) + def create_and_check_bert_model_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask): + model = BertModel(config) + model.eval() + sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask) + sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + + result = { + "sequence_output": sequence_output, + "pooled_output": pooled_output, + } + self.parent.assertListEqual( + list(result["sequence_output"].size()), + [self.batch_size, self.seq_length, self.hidden_size]) + self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) + def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertForMaskedLM(config=config) model.eval() @@ -158,6 +184,20 @@ class BertModelTest(CommonTestCases.CommonModelTester): [self.batch_size, self.seq_length, self.vocab_size]) self.check_loss_output(result) + def create_and_check_bert_model_for_masked_lm_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask): + model = BertForMaskedLM(config=config) + model.eval() + loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask) + loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), + [self.batch_size, self.seq_length, self.vocab_size]) + self.check_loss_output(result) + def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertForNextSentencePrediction(config=config) model.eval() @@ -273,10 +313,18 @@ class BertModelTest(CommonTestCases.CommonModelTester): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_bert_model(*config_and_inputs) + def test_bert_model_as_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs) + def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs) + def test_for_masked_lm_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs) + def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)