Split LMBert model in two (#4874)

* Split LMBert model in two

* Fix example

* Remove lm_labels

* Adapt tests, refactor prepare_for_generation

* Fix merge

* Hide BeartLMHeadModel
This commit is contained in:
Sylvain Gugger
2020-06-10 18:26:42 -04:00
committed by GitHub
parent f6da8b2200
commit 1e2631d6f8
4 changed files with 163 additions and 95 deletions

View File

@@ -27,7 +27,8 @@ from .utils import require_torch, slow, torch_device
if is_torch_available():
from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig
from transformers import BertModel, EncoderDecoderModel, EncoderDecoderConfig
from transformers.modeling_bert import BertLMHeadModel
import numpy as np
import torch
@@ -70,7 +71,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"lm_labels": decoder_token_labels,
"labels": decoder_token_labels,
}
@@ -116,7 +116,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
@@ -153,7 +153,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
enc_dec_model.to(torch_device)
@@ -179,7 +179,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
enc_dec_model.eval()
@@ -220,7 +220,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
enc_dec_model.eval()
@@ -269,7 +269,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
@@ -288,41 +288,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_lm_labels(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
lm_labels,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
lm_labels=lm_labels,
)
lm_loss = outputs_encoder_decoder[0]
self.check_loss_output(lm_loss)
# check that backprop works
lm_loss.backward()
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
@@ -356,10 +324,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_lm_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict)