Split LMBert model in two (#4874)
* Split LMBert model in two * Fix example * Remove lm_labels * Adapt tests, refactor prepare_for_generation * Fix merge * Hide BeartLMHeadModel
This commit is contained in:
@@ -27,7 +27,8 @@ from .utils import require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig
|
||||
from transformers import BertModel, EncoderDecoderModel, EncoderDecoderConfig
|
||||
from transformers.modeling_bert import BertLMHeadModel
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -70,7 +71,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"decoder_choice_labels": decoder_choice_labels,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"lm_labels": decoder_token_labels,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
@@ -153,7 +153,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
|
||||
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||
enc_dec_model.to(torch_device)
|
||||
@@ -179,7 +179,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
@@ -220,7 +220,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
@@ -269,7 +269,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
@@ -288,41 +288,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
|
||||
def create_and_check_bert_encoder_decoder_model_lm_labels(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
lm_labels=lm_labels,
|
||||
)
|
||||
|
||||
lm_loss = outputs_encoder_decoder[0]
|
||||
self.check_loss_output(lm_loss)
|
||||
# check that backprop works
|
||||
lm_loss.backward()
|
||||
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
|
||||
def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
@@ -356,10 +324,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict)
|
||||
|
||||
def test_bert_encoder_decoder_model_lm_labels(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict)
|
||||
|
||||
def test_bert_encoder_decoder_model_generate(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
Reference in New Issue
Block a user