Split LMBert model in two (#4874)

* Split LMBert model in two

* Fix example

* Remove lm_labels

* Adapt tests, refactor prepare_for_generation

* Fix merge

* Hide BeartLMHeadModel
This commit is contained in:
Sylvain Gugger
2020-06-10 18:26:42 -04:00
committed by GitHub
parent f6da8b2200
commit 1e2631d6f8
4 changed files with 163 additions and 95 deletions

View File

@@ -35,7 +35,7 @@ if is_torch_available():
BertForTokenClassification,
BertForMultipleChoice,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
class BertModelTester:
@@ -211,6 +211,33 @@ class BertModelTester:
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
def create_and_check_bert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -229,7 +256,7 @@ class BertModelTester:
)
self.check_loss_output(result)
def create_and_check_bert_model_for_masked_lm_as_decoder(
def create_and_check_bert_model_for_causal_lm_as_decoder(
self,
config,
input_ids,
@@ -241,7 +268,7 @@ class BertModelTester:
encoder_hidden_states,
encoder_attention_mask,
):
model = BertForMaskedLM(config=config)
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
@@ -461,13 +488,17 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
encoder_attention_mask,
)
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_for_causal_lm(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
def test_for_masked_lm_decoder(self):
def test_for_causal_lm_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs)
self.model_tester.create_and_check_bert_model_for_causal_lm_as_decoder(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()