[EncoderDecoderConfig] automatically set decoder config to decoder (#4809)
* automatically set decoder config to decoder * add more tests
This commit is contained in:
committed by
GitHub
parent
f1fe18465d
commit
8cca875569
@@ -85,6 +85,9 @@ class EncoderDecoderConfig(PretrainedConfig):
|
|||||||
Returns:
|
Returns:
|
||||||
:class:`EncoderDecoderConfig`: An instance of a configuration object
|
:class:`EncoderDecoderConfig`: An instance of a configuration object
|
||||||
"""
|
"""
|
||||||
|
logger.info("Set `config.is_decoder=True` for decoder_config")
|
||||||
|
decoder_config.is_decoder = True
|
||||||
|
|
||||||
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
|
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from .utils import require_torch, slow, torch_device
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel
|
from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -74,6 +74,36 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||||||
"labels": decoder_token_labels,
|
"labels": decoder_token_labels,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def create_and_check_bert_encoder_decoder_model_from_pretrained_configs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||||
|
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||||
|
|
||||||
|
enc_dec_model = EncoderDecoderModel(encoder_decoder_config)
|
||||||
|
enc_dec_model.to(torch_device)
|
||||||
|
enc_dec_model.eval()
|
||||||
|
|
||||||
|
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||||
|
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||||
|
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
|
||||||
|
|
||||||
def create_and_check_bert_encoder_decoder_model(
|
def create_and_check_bert_encoder_decoder_model(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -88,6 +118,8 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||||||
encoder_model = BertModel(config)
|
encoder_model = BertModel(config)
|
||||||
decoder_model = BertForMaskedLM(decoder_config)
|
decoder_model = BertForMaskedLM(decoder_config)
|
||||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||||
|
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -304,6 +336,10 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||||
self.create_and_check_bert_encoder_decoder_model(**input_ids_dict)
|
self.create_and_check_bert_encoder_decoder_model(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_bert_encoder_decoder_model_from_pretrained_configs(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||||
|
self.create_and_check_bert_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||||
|
|
||||||
def test_bert_encoder_decoder_model_from_pretrained(self):
|
def test_bert_encoder_decoder_model_from_pretrained(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||||
self.create_and_check_bert_encoder_decoder_model_from_pretrained(**input_ids_dict)
|
self.create_and_check_bert_encoder_decoder_model_from_pretrained(**input_ids_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user