The introduction of a decoder introduces 2 changes:
- We need to be able to specify a separate mask in the cross
attention to mask the positions corresponding to padding tokens in the
encoder state.
- The self-attention in the decoder needs to be causal on top of not
attending to padding tokens.
In Rothe et al.'s "Leveraging Pre-trained Checkpoints for Sequence
Generation Tasks", Bert2Bert is initialized with pre-trained weights for
the encoder, and only pre-trained embeddings for the decoder. The
current version of the code completely randomizes the weights of the
decoder.
We write a custom function to initiliaze the weights of the decoder; we
first initialize the decoder with the weights and then randomize
everything but the embeddings.
Since the preloading of weights relies on the name of the class's
attributes changing the namespace breaks loading pretrained weights on
Bert and all related models. I reverted `self_attention` to `attention`
and us `crossattention` for the decoder instead.
In the seq2seq model we need to both load pretrained weights in the
encoder and initialize the decoder randomly. Because the
`from_pretrained` method defined in the base class relies on module
names to assign weights, it would also initialize the decoder with
pretrained weights. To avoid this we override the method to only
initialize the encoder with pretrained weights.
The modifications that I introduced in a previous commit did break
Bert's internal API. I reverted these changes and added more general
classes to handle the encoder-decoder attention case.
There may be a more elegant way to deal with retro-compatibility (I am
not comfortable with the current state of the code), but I cannot see it
right now.
There is currently no way to specify the quey, key and value separately
in the Attention module. However, the decoder's "encoder-decoder
attention" layers take the decoder's last output as a query, the
encoder's states as key and value. We thus modify the existing code so
query, key and value can be added separately.
This obviously poses some naming conventions; `BertSelfAttention` is not
a self-attention module anymore. The way the residual is forwarded is
now awkard, etc. We will need to do some refacto once the decoder is
fully implemented.