diff --git a/src/transformers/configuration_encoder_decoder.py b/src/transformers/configuration_encoder_decoder.py index 785a4c654e..95cabaa82e 100644 --- a/src/transformers/configuration_encoder_decoder.py +++ b/src/transformers/configuration_encoder_decoder.py @@ -58,6 +58,7 @@ class EncoderDecoderConfig(PretrainedConfig): >>> config_decoder = model.config.decoder >>> # set decoder config to causal lm >>> config_decoder.is_decoder = True + >>> config_decoder.add_cross_attention = True >>> # Saving the model, including its configuration >>> model.save_pretrained('my-model') @@ -94,8 +95,9 @@ class EncoderDecoderConfig(PretrainedConfig): Returns: :class:`EncoderDecoderConfig`: An instance of a configuration object """ - logger.info("Set `config.is_decoder=True` for decoder_config") + logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") decoder_config.is_decoder = True + decoder_config.add_cross_attention = True return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict()) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index af31087697..3e1d4bcbf4 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -56,6 +56,8 @@ class PretrainedConfig(object): Whether the model is used as an encoder/decoder or not. is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether the model is used as decoder or not (in which case it's used as an encoder). + add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether cross-attention layers should be added to the model. Note, this option is only relevant for models that can be used as decoder models within the `:class:~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``. prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`): Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of heads to prune in said layer. @@ -145,6 +147,7 @@ class PretrainedConfig(object): # Is decoder is used in encoder-decoder models to differentiate encoder from decoder self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) self.is_decoder = kwargs.pop("is_decoder", False) + self.add_cross_attention = kwargs.pop("add_cross_attention", False) # Parameters for sequence generation self.max_length = kwargs.pop("max_length", 20) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 9d46495fbd..9605d29cb2 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -378,7 +378,9 @@ class BertLayer(nn.Module): super().__init__() self.attention = BertAttention(config) self.is_decoder = config.is_decoder - if self.is_decoder: + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added" self.crossattention = BertAttention(config) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) @@ -399,6 +401,9 @@ class BertLayer(nn.Module): outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: + assert hasattr( + self, "crossattention" + ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" cross_attention_outputs = self.crossattention( attention_output, attention_mask, @@ -695,8 +700,10 @@ class BertModel(BertPreTrainedModel): Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as an decoder the model needs to be initialized with the - :obj:`is_decoder` argument of the configuration set to :obj:`True`; an - :obj:`encoder_hidden_states` is expected as an input to the forward pass. + :obj:`is_decoder` argument of the configuration set to :obj:`True`. + To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` + argument and :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward pass. .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762 diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 772ae74e22..6fdb961b65 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -168,17 +168,18 @@ class EncoderDecoderModel(PreTrainedModel): from .configuration_auto import AutoConfig decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) - if decoder_config.is_decoder is False: + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." ) decoder_config.is_decoder = True + decoder_config.add_cross_attention = True kwargs_decoder["config"] = decoder_config - if kwargs_decoder["config"].is_decoder is False: + if kwargs_decoder["config"].is_decoder is False or decoder_config.add_cross_attention is False: logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 60460aeb94..0ec9a0b874 100644 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -176,6 +176,7 @@ class BertModelTester: encoder_hidden_states, encoder_attention_mask, ): + config.add_cross_attention = True model = BertModel(config) model.to(torch_device) model.eval() @@ -235,6 +236,7 @@ class BertModelTester: encoder_hidden_states, encoder_attention_mask, ): + config.add_cross_attention = True model = BertLMHeadModel(config=config) model.to(torch_device) model.eval() diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index f46fbeb82a..e62d2fb563 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -59,6 +59,9 @@ class EncoderDecoderModelTest(unittest.TestCase): encoder_hidden_states, encoder_attention_mask, ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True return { "config": config, "input_ids": input_ids, @@ -119,6 +122,7 @@ class EncoderDecoderModelTest(unittest.TestCase): decoder_model = BertLMHeadModel(decoder_config) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) self.assertTrue(enc_dec_model.config.decoder.is_decoder) + self.assertTrue(enc_dec_model.config.decoder.add_cross_attention) self.assertTrue(enc_dec_model.config.is_encoder_decoder) enc_dec_model.to(torch_device) outputs_encoder_decoder = enc_dec_model( @@ -330,7 +334,7 @@ class EncoderDecoderModelTest(unittest.TestCase): self.assertIsNotNone(model) @slow - def test_real_bert_model_from_pretrained_has_cross_attention(self): + def test_real_bert_model_from_pretrained_add_cross_attention(self): model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))