[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)

* start adding tie encoder to decoder functionality

* finish model tying

* make style

* Apply suggestions from code review

* fix t5 list including cross attention

* apply sams suggestions

* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add max depth break point

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-08-19 14:23:45 +02:00
committed by GitHub
parent ab42d74850
commit fe0b85e77a
7 changed files with 288 additions and 26 deletions

View File

@@ -87,7 +87,7 @@ class EncoderDecoderConfig(PretrainedConfig):
@classmethod
def from_encoder_decoder_configs(
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
@@ -99,7 +99,7 @@ class EncoderDecoderConfig(PretrainedConfig):
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
"""