[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)

* start adding tie encoder to decoder functionality

* finish model tying

* make style

* Apply suggestions from code review

* fix t5 list including cross attention

* apply sams suggestions

* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add max depth break point

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-08-19 14:23:45 +02:00
committed by GitHub
parent ab42d74850
commit fe0b85e77a
7 changed files with 288 additions and 26 deletions

View File

@@ -58,6 +58,8 @@ class PretrainedConfig(object):
Whether the model is used as decoder or not (in which case it's used as an encoder).
add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether cross-attention layers should be added to the model. Note, this option is only relevant for models that can be used as decoder models within the `:class:~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`)
Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder and decoder model to have the exact same parameter names.
prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
of heads to prune in said layer.
@@ -153,6 +155,7 @@ class PretrainedConfig(object):
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
self.is_decoder = kwargs.pop("is_decoder", False)
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
# Parameters for sequence generation
self.max_length = kwargs.pop("max_length", 20)