[EncoderDecoder] Fix initialization and save/load bug (#4680)
* fix bug * add more tests
This commit is contained in:
committed by
GitHub
parent
6f82aea66b
commit
0866669e75
@@ -35,6 +35,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
|
class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
|
||||||
"""
|
"""
|
||||||
config_class = EncoderDecoderConfig
|
config_class = EncoderDecoderConfig
|
||||||
|
base_model_prefix = "encoder_decoder"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -158,12 +159,26 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||||
from .modeling_auto import AutoModelWithLMHead
|
from .modeling_auto import AutoModelWithLMHead
|
||||||
|
|
||||||
|
if "config" not in kwargs_decoder:
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||||
|
if decoder_config.is_decoder is False:
|
||||||
|
logger.info(
|
||||||
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||||
|
)
|
||||||
|
decoder_config.is_decoder = True
|
||||||
|
|
||||||
|
kwargs_decoder["config"] = decoder_config
|
||||||
|
|
||||||
|
if kwargs_decoder["config"].is_decoder is False:
|
||||||
|
logger.warning(
|
||||||
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
|
)
|
||||||
|
|
||||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
decoder.config.is_decoder = True
|
|
||||||
|
|
||||||
model = cls(encoder=encoder, decoder=decoder)
|
return cls(encoder=encoder, decoder=decoder)
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from transformers import is_torch_available
|
|||||||
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
|
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
|
||||||
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
|
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
|
||||||
from .test_modeling_bert import BertModelTester
|
from .test_modeling_bert import BertModelTester
|
||||||
|
from .test_modeling_common import ids_tensor
|
||||||
from .utils import require_torch, slow, torch_device
|
from .utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -331,3 +332,33 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||||||
def test_real_bert_model_from_pretrained(self):
|
def test_real_bert_model_from_pretrained(self):
|
||||||
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_real_bert_model_from_pretrained_has_cross_attention(self):
|
||||||
|
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||||
|
self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_real_bert_model_save_load_from_pretrained(self):
|
||||||
|
model_2 = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||||
|
model_2.to(torch_device)
|
||||||
|
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
|
||||||
|
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
|
||||||
|
attention_mask = ids_tensor([13, 5], vocab_size=2)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,)
|
||||||
|
out_2 = outputs[0].cpu().numpy()
|
||||||
|
out_2[np.isnan(out_2)] = 0
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
model_2.save_pretrained(tmp_dirname)
|
||||||
|
model_1 = EncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||||
|
model_1.to(torch_device)
|
||||||
|
|
||||||
|
after_outputs = model_1(
|
||||||
|
input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
out_1 = after_outputs[0].cpu().numpy()
|
||||||
|
out_1[np.isnan(out_1)] = 0
|
||||||
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|||||||
Reference in New Issue
Block a user