[EncoderDecoderConfig] automatically set decoder config to decoder (#4809)
* automatically set decoder config to decoder
* add more tests
This commit is contained in:
committed by GitHub
parent f1fe18465d
commit 8cca875569
@@ -85,6 +85,9 @@ class EncoderDecoderConfig(PretrainedConfig):
|
||||
Returns:
|
||||
:class:`EncoderDecoderConfig`: An instance of a configuration object
|
||||
"""
|
||||
logger.info("Set `config.is_decoder=True` for decoder_config")
|
||||
decoder_config.is_decoder = True
|
||||
|
||||
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
|
||||
|
||||
def to_dict(self):
|
||||
|
||||
@@ -27,7 +27,7 @@ from .utils import require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
-from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel
|
||||
+from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -74,6 +74,36 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
def create_and_check_bert_encoder_decoder_model_from_pretrained_configs(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
):
|
||||
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||
|
||||
enc_dec_model = EncoderDecoderModel(encoder_decoder_config)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
|
||||
def create_and_check_bert_encoder_decoder_model(
|
||||
self,
|
||||
config,
|
||||
@@ -88,6 +118,8 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
@@ -304,6 +336,10 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model(**input_ids_dict)
|
||||
|
||||
def test_bert_encoder_decoder_model_from_pretrained_configs(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||
|
||||
def test_bert_encoder_decoder_model_from_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_from_pretrained(**input_ids_dict)
|
||||
|
||||
Reference in New Issue
Block a user