[EncoderDecoder] Fix initialization and save/load bug (#4680)
* fix bug * add more tests
This commit is contained in:
committed by
GitHub
parent
6f82aea66b
commit
0866669e75
@@ -35,6 +35,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
|
||||
"""
|
||||
config_class = EncoderDecoderConfig
|
||||
base_model_prefix = "encoder_decoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -158,12 +159,26 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||
from .modeling_auto import AutoModelWithLMHead
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
from transformers import AutoConfig
|
||||
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
|
||||
kwargs_decoder["config"] = decoder_config
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
decoder.config.is_decoder = True
|
||||
|
||||
model = cls(encoder=encoder, decoder=decoder)
|
||||
|
||||
return model
|
||||
return cls(encoder=encoder, decoder=decoder)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user