[EncoderDecoder] Add Cross Attention for GPT2 (#6415)
* add cross attention layers for gpt2
* make gpt2 cross attention work
* finish bert2gpt2
* add explicit comments
* remove attention mask since not yet supported
* revert attn mask in pipeline
* Update src/transformers/modeling_gpt2.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_encoder_decoder.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
eb613b566a
commit
1d6e71e116
@@ -372,11 +372,16 @@ class GenerationMixin:
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
if decoder_start_token_id is None:
|
||||
decoder_start_token_id = bos_token_id
|
||||
# see if BOS token can be used for decoder_start_token_id
|
||||
if bos_token_id is not None:
|
||||
decoder_start_token_id = bos_token_id
|
||||
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
|
||||
decoder_start_token_id = self.config.decoder.bos_token_id
|
||||
else:
|
||||
raise ValueError(
|
||||
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
||||
)
|
||||
|
||||
assert (
|
||||
decoder_start_token_id is not None
|
||||
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
||||
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
||||
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user