[EncoderDecoderModel] add a add_cross_attention boolean to config (#6377)
* correct encoder decoder model * Apply suggestions from code review * apply sylvains suggestions
This commit is contained in:
committed by
GitHub
parent
06bc347c97
commit
3425936643
@@ -58,6 +58,7 @@ class EncoderDecoderConfig(PretrainedConfig):
|
||||
>>> config_decoder = model.config.decoder
|
||||
>>> # set decoder config to causal lm
|
||||
>>> config_decoder.is_decoder = True
|
||||
>>> config_decoder.add_cross_attention = True
|
||||
|
||||
>>> # Saving the model, including its configuration
|
||||
>>> model.save_pretrained('my-model')
|
||||
@@ -94,8 +95,9 @@ class EncoderDecoderConfig(PretrainedConfig):
|
||||
Returns:
|
||||
:class:`EncoderDecoderConfig`: An instance of a configuration object
|
||||
"""
|
||||
logger.info("Set `config.is_decoder=True` for decoder_config")
|
||||
logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
|
||||
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
|
||||
|
||||
|
||||
@@ -56,6 +56,8 @@ class PretrainedConfig(object):
|
||||
Whether the model is used as an encoder/decoder or not.
|
||||
is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether the model is used as decoder or not (in which case it's used as an encoder).
|
||||
add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether cross-attention layers should be added to the model. Note, this option is only relevant for models that can be used as decoder models within the `:class:~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
|
||||
prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
|
||||
Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
|
||||
of heads to prune in said layer.
|
||||
@@ -145,6 +147,7 @@ class PretrainedConfig(object):
|
||||
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
|
||||
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
||||
self.is_decoder = kwargs.pop("is_decoder", False)
|
||||
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
|
||||
|
||||
# Parameters for sequence generation
|
||||
self.max_length = kwargs.pop("max_length", 20)
|
||||
|
||||
@@ -378,7 +378,9 @@ class BertLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.attention = BertAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
if self.is_decoder:
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
|
||||
self.crossattention = BertAttention(config)
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
@@ -399,6 +401,9 @@ class BertLayer(nn.Module):
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
assert hasattr(
|
||||
self, "crossattention"
|
||||
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
@@ -695,8 +700,10 @@ class BertModel(BertPreTrainedModel):
|
||||
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
|
||||
To behave as an decoder the model needs to be initialized with the
|
||||
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
|
||||
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
|
||||
:obj:`is_decoder` argument of the configuration set to :obj:`True`.
|
||||
To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
|
||||
argument and :obj:`add_cross_attention` set to :obj:`True`; an
|
||||
:obj:`encoder_hidden_states` is then expected as an input to the forward pass.
|
||||
|
||||
.. _`Attention is all you need`:
|
||||
https://arxiv.org/abs/1706.03762
|
||||
|
||||
@@ -168,17 +168,18 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
from .configuration_auto import AutoConfig
|
||||
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False:
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
|
||||
kwargs_decoder["config"] = decoder_config
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False:
|
||||
if kwargs_decoder["config"].is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
|
||||
@@ -176,6 +176,7 @@ class BertModelTester:
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.add_cross_attention = True
|
||||
model = BertModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -235,6 +236,7 @@ class BertModelTester:
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.add_cross_attention = True
|
||||
model = BertLMHeadModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -59,6 +59,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
@@ -119,6 +122,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
@@ -330,7 +334,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_real_bert_model_from_pretrained_has_cross_attention(self):
|
||||
def test_real_bert_model_from_pretrained_add_cross_attention(self):
|
||||
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||
self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user