[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)
* start adding tie encoder to decoder functionality
* finish model tying
* make style
* Apply suggestions from code review
* fix t5 list including cross attention
* apply sams suggestions
* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add max depth break point

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
committed by GitHub
parent ab42d74850
commit fe0b85e77a
@@ -58,6 +58,8 @@ class PretrainedConfig(object):
                 Whether the model is used as decoder or not (in which case it's used as an encoder).
             add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether cross-attention layers should be added to the model. Note, this option is only relevant for models that can be used as decoder models within the `:class:~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
+            tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`)
+                Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder and decoder model to have the exact same parameter names.
             prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
                 Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
                 of heads to prune in said layer.
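Note on the docstring above: "tied" weights are shared objects, not copies, which is why matching parameter names matter. The following is only a rough sketch of that idea in plain Python, not the code added by this commit. `Param`, `tie_by_name`, and the parameter names are hypothetical stand-ins chosen for illustration, and the decoder-only entry is an assumption about what a weight without an encoder equivalent might look like.

```python
class Param:
    """Hypothetical stand-in for a weight tensor."""

    def __init__(self, values):
        self.values = list(values)


def tie_by_name(encoder_params, decoder_params):
    """Point each decoder entry at the encoder object with the same name.

    Returns the decoder names that have no encoder equivalent and
    were therefore left untied.
    """
    untied = []
    for name in decoder_params:
        if name in encoder_params:
            # Share the same object instead of copying its values.
            decoder_params[name] = encoder_params[name]
        else:
            untied.append(name)
    return untied


encoder = {
    "layer.0.query": Param([0.1, 0.2]),
    "layer.0.key": Param([0.3, 0.4]),
}
decoder = {
    "layer.0.query": Param([9.0, 9.0]),
    "layer.0.key": Param([9.0, 9.0]),
    # Assumed decoder-only weight with no encoder counterpart.
    "layer.0.crossattention.query": Param([5.0, 5.0]),
}

untied = tie_by_name(encoder, decoder)

# Because the objects are shared, an update made through the encoder
# is visible through the decoder as well.
encoder["layer.0.query"].values[0] = 42.0
```

In this sketch a decoder entry whose name does not appear in the encoder is simply reported back rather than tied, which is one possible way to handle parameters that exist on only one side.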
@@ -153,6 +155,7 @@ class PretrainedConfig(object):
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
         self.is_decoder = kwargs.pop("is_decoder", False)
         self.add_cross_attention = kwargs.pop("add_cross_attention", False)
+        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

         # Parameters for sequence generation
         self.max_length = kwargs.pop("max_length", 20)
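Note on the hunk above: the new flag follows the same `kwargs.pop(name, default)` pattern as its neighbours, so it defaults to `False` unless the caller passes it. The snippet below is a minimal sketch of that pattern only. `SketchConfig` and its `unused_kwargs` attribute are hypothetical and much smaller than the real `PretrainedConfig`.

```python
class SketchConfig:
    """Hypothetical, reduced config class; not the real PretrainedConfig."""

    def __init__(self, **kwargs):
        # pop() reads the value and removes the key, so anything left in
        # kwargs afterwards was not recognised by the lines below.
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
        self.max_length = kwargs.pop("max_length", 20)
        # Assumption for this sketch: keep the leftovers for inspection.
        self.unused_kwargs = kwargs


default_config = SketchConfig()
tied_config = SketchConfig(tie_encoder_decoder=True, custom_flag=1)
```

Here `default_config.tie_encoder_decoder` stays at the default, while `tied_config` picks up the explicit value and leaves `custom_flag` in the leftover dictionary.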