[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)
* start adding tie encoder to decoder functionality
* finish model tying
* make style
* Apply suggestions from code review
* fix t5 list including cross attention
* apply Sam's suggestions
* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add max depth break point

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ab42d74850
commit
fe0b85e77a
@@ -87,7 +87,7 @@ class EncoderDecoderConfig(PretrainedConfig):
|
||||
|
||||
@classmethod
|
||||
def from_encoder_decoder_configs(
|
||||
- cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig
|
||||
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
|
||||
) -> PretrainedConfig:
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
|
||||
@@ -99,7 +99,7 @@ class EncoderDecoderConfig(PretrainedConfig):
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
|
||||
- return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
|
||||
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user