From 0735def8e1200ed45a2c33a075bc1595b12ef56a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 12 Aug 2020 18:23:30 +0200 Subject: [PATCH] [EncoderDecoder] Add encoder-decoder for roberta/ vanilla longformer (#6411) * add encoder-decoder for roberta * fix headmask * apply Sylvains suggestions * fix typo * Apply suggestions from code review --- docs/source/model_doc/roberta.rst | 7 + src/transformers/__init__.py | 1 + src/transformers/modeling_auto.py | 2 + src/transformers/modeling_bert.py | 39 +- src/transformers/modeling_encoder_decoder.py | 12 - src/transformers/modeling_roberta.py | 133 +++- src/transformers/modeling_tf_bert.py | 4 +- tests/test_modeling_bert.py | 52 +- tests/test_modeling_encoder_decoder.py | 638 ++++++++++--------- tests/test_modeling_roberta.py | 140 +++- 10 files changed, 671 insertions(+), 357 deletions(-) diff --git a/docs/source/model_doc/roberta.rst b/docs/source/model_doc/roberta.rst index fae0d91a29..ac83dde4fc 100644 --- a/docs/source/model_doc/roberta.rst +++ b/docs/source/model_doc/roberta.rst @@ -63,6 +63,13 @@ RobertaModel :members: +RobertaForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RobertaForCausalLM + :members: + + RobertaForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 956a03f283..b9ddf890eb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -302,6 +302,7 @@ if is_torch_available(): from .tokenization_marian import MarianTokenizer from .modeling_roberta import ( RobertaForMaskedLM, + RobertaForCausalLM, RobertaModel, RobertaForSequenceClassification, RobertaForMultipleChoice, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 0d1452ee24..7a45267703 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -135,6 +135,7 @@ from .modeling_reformer import ( ) from .modeling_retribert import RetriBertModel from .modeling_roberta import ( + RobertaForCausalLM, RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForQuestionAnswering, @@ -250,6 +251,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( [ + (RobertaConfig, RobertaForCausalLM), (BertConfig, BertLMHeadModel), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), (GPT2Config, GPT2LMHeadModel), diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index a40d54f611..0c956f82bb 100755 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -683,14 +683,6 @@ BERT_INPUTS_DOCSTRING = r""" Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - if the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask - is used in the cross-attention if the model is configured as a decoder. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): @@ -769,6 +761,16 @@ class BertModel(BertPreTrainedModel): output_hidden_states=None, return_dict=None, ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -956,7 +958,7 @@ class BertLMHeadModel(BertPreTrainedModel): super().__init__(config) if not config.is_decoder: - logger.info("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") + logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") self.bert = BertModel(config) self.cls = BertOnlyMLMHead(config) @@ -976,22 +978,27 @@ class BertLMHeadModel(BertPreTrainedModel): position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, encoder_hidden_states=None, encoder_attention_mask=None, + labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, - **kwargs ): r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): - Used to hide legacy arguments that have been deprecated. Returns: @@ -1061,8 +1068,8 @@ class BertForMaskedLM(BertPreTrainedModel): super().__init__(config) if config.is_decoder: - logger.info( - "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for " + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " "bi-directional self-attention." ) @@ -1089,9 +1096,9 @@ class BertForMaskedLM(BertPreTrainedModel): position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, encoder_hidden_states=None, encoder_attention_mask=None, + labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 6fdb961b65..0ee5f2962b 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -191,11 +191,9 @@ class EncoderDecoderModel(PreTrainedModel): input_ids=None, inputs_embeds=None, attention_mask=None, - head_mask=None, encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, - decoder_head_mask=None, decoder_inputs_embeds=None, labels=None, **kwargs, @@ -216,10 +214,6 @@ class EncoderDecoderModel(PreTrainedModel): Mask to avoid performing attention on padding token indices for the encoder. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): - Mask to nullify selected heads of the self-attention modules for the encoder. - Mask values selected in ``[0, 1]``: - ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`): Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`) `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder. @@ -231,10 +225,6 @@ class EncoderDecoderModel(PreTrainedModel): :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. - decoder_head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): - Mask to nullify selected heads of the self-attention modules for the decoder. - Mask values selected in ``[0, 1]``: - ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors @@ -279,7 +269,6 @@ class EncoderDecoderModel(PreTrainedModel): input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, - head_mask=head_mask, return_dict=False, **kwargs_encoder, ) @@ -293,7 +282,6 @@ class EncoderDecoderModel(PreTrainedModel): attention_mask=decoder_attention_mask, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, labels=labels, return_dict=False, **kwargs_decoder, diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index 7779e81ece..59030ab28b 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -24,9 +24,15 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss, MSELoss from .configuration_roberta import RobertaConfig -from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable +from .file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_callable, + replace_return_docstrings, +) from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu from .modeling_outputs import ( + CausalLMOutput, MaskedLMOutput, MultipleChoiceModelOutput, QuestionAnsweringModelOutput, @@ -175,6 +181,121 @@ class RobertaModel(BertModel): self.embeddings.word_embeddings = value +@add_start_docstrings( + """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING +) +class RobertaForCausalLM(BertPreTrainedModel): + config_class = RobertaConfig + base_model_prefix = "roberta" + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = RobertaModel(config) + self.lm_head = RobertaLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.lm_head.decoder + + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the left-to-right language modeling loss (next word prediction). + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + + Returns: + + Example:: + + >>> from transformers import RobertaTokenizer, RobertaLMHeadModel, RobertaConfig + >>> import torch + + >>> tokenizer = RobertaTokenizer.from_pretrained('roberta-base') + >>> config = RobertaConfig.from_pretrained("roberta-base") + >>> config.is_decoder = True + >>> model = RobertaLMHeadModel.from_pretrained('roberta-base', config=config, return_dict=True) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutput( + loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING) class RobertaForMaskedLM(BertPreTrainedModel): config_class = RobertaConfig @@ -183,6 +304,12 @@ class RobertaForMaskedLM(BertPreTrainedModel): def __init__(self, config): super().__init__(config) + if config.is_decoder: + logger.warning( + "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + self.roberta = RobertaModel(config) self.lm_head = RobertaLMHead(config) @@ -206,6 +333,8 @@ class RobertaForMaskedLM(BertPreTrainedModel): position_ids=None, head_mask=None, inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, labels=None, output_attentions=None, output_hidden_states=None, @@ -237,6 +366,8 @@ class RobertaForMaskedLM(BertPreTrainedModel): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 532997e097..a66724a371 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -862,7 +862,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): super().__init__(config, *inputs, **kwargs) if config.is_decoder: - logger.info( + logger.warning( "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for " "bi-directional self-attention." ) @@ -941,7 +941,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): super().__init__(config, *inputs, **kwargs) if not config.is_decoder: - logger.info("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`") + logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`") self.bert = TFBertMainLayer(config, name="bert") self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls") diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 87382337d5..200b567666 100644 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -152,7 +152,7 @@ class BertModelTester: encoder_attention_mask, ) - def create_and_check_bert_model( + def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = BertModel(config=config) @@ -164,7 +164,7 @@ class BertModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) - def create_and_check_bert_model_as_decoder( + def create_and_check_model_as_decoder( self, config, input_ids, @@ -197,7 +197,7 @@ class BertModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) - def create_and_check_bert_for_causal_lm( + def create_and_check_for_causal_lm( self, config, input_ids, @@ -215,7 +215,7 @@ class BertModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_bert_for_masked_lm( + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = BertForMaskedLM(config=config) @@ -224,7 +224,7 @@ class BertModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_bert_model_for_causal_lm_as_decoder( + def create_and_check_model_for_causal_lm_as_decoder( self, config, input_ids, @@ -257,7 +257,7 @@ class BertModelTester: ) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_bert_for_next_sequence_prediction( + def create_and_check_for_next_sequence_prediction( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = BertForNextSentencePrediction(config=config) @@ -268,7 +268,7 @@ class BertModelTester: ) self.parent.assertEqual(result.logits.shape, (self.batch_size, 2)) - def create_and_check_bert_for_pretraining( + def create_and_check_for_pretraining( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = BertForPreTraining(config=config) @@ -284,7 +284,7 @@ class BertModelTester: self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2)) - def create_and_check_bert_for_question_answering( + def create_and_check_for_question_answering( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = BertForQuestionAnswering(config=config) @@ -300,7 +300,7 @@ class BertModelTester: self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) - def create_and_check_bert_for_sequence_classification( + def create_and_check_for_sequence_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels @@ -310,7 +310,7 @@ class BertModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) - def create_and_check_bert_for_token_classification( + def create_and_check_for_token_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels @@ -320,7 +320,7 @@ class BertModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) - def create_and_check_bert_for_multiple_choice( + def create_and_check_for_multiple_choice( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_choices = self.num_choices @@ -379,15 +379,15 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() - def test_bert_model(self): + def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_model(*config_and_inputs) + self.model_tester.create_and_check_model(*config_and_inputs) - def test_bert_model_as_decoder(self): + def test_model_as_decoder(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() - self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs) + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) - def test_bert_model_as_decoder_with_default_input_mask(self): + def test_model_as_decoder_with_default_input_mask(self): # This regression test was failing with PyTorch < 1.3 ( config, @@ -403,7 +403,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): input_mask = None - self.model_tester.create_and_check_bert_model_as_decoder( + self.model_tester.create_and_check_model_as_decoder( config, input_ids, token_type_ids, @@ -417,39 +417,39 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): def test_for_causal_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() - self.model_tester.create_and_check_bert_for_causal_lm(*config_and_inputs) + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs) + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) def test_for_causal_lm_decoder(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() - self.model_tester.create_and_check_bert_model_for_causal_lm_as_decoder(*config_and_inputs) + self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs) def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs) + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) def test_for_next_sequence_prediction(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs) + self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs) def test_for_pretraining(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs) + self.model_tester.create_and_check_for_pretraining(*config_and_inputs) def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs) + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) def test_for_sequence_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs) + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs) + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) @slow def test_model_from_pretrained(self): diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index e62d2fb563..5bc21ca0c4 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -24,21 +24,317 @@ from transformers.testing_utils import require_torch, slow, torch_device # for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest from .test_modeling_bert import BertModelTester from .test_modeling_common import ids_tensor +from .test_modeling_roberta import RobertaModelTester if is_torch_available(): - from transformers import BertModel, EncoderDecoderModel, EncoderDecoderConfig - from transformers.modeling_bert import BertLMHeadModel + from transformers import ( + BertModel, + BertLMHeadModel, + RobertaModel, + RobertaForCausalLM, + EncoderDecoderModel, + EncoderDecoderConfig, + ) import numpy as np import torch @require_torch -class EncoderDecoderModelTest(unittest.TestCase): - def prepare_config_and_inputs_bert(self): - bert_model_tester = BertModelTester(self) - encoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs() - decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder() +class EncoderDecoderMixin: + def get_encoder_decoder_model(self, config, decoder_config): + pass + + def prepare_config_and_inputs(self): + pass + + def get_pretrained_model(self): + pass + + def check_encoder_decoder_model_from_pretrained_configs( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + self.assertTrue(encoder_decoder_config.decoder.is_decoder) + + enc_dec_model = EncoderDecoderModel(encoder_decoder_config) + enc_dec_model.to(torch_device) + enc_dec_model.eval() + + self.assertTrue(enc_dec_model.config.is_encoder_decoder) + + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) + self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) + + def check_encoder_decoder_model( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + self.assertTrue(enc_dec_model.config.decoder.is_decoder) + self.assertTrue(enc_dec_model.config.decoder.add_cross_attention) + self.assertTrue(enc_dec_model.config.is_encoder_decoder) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) + self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) + encoder_outputs = (encoder_hidden_states,) + outputs_encoder_decoder = enc_dec_model( + encoder_outputs=encoder_outputs, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) + self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) + + def check_encoder_decoder_model_from_pretrained( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model} + enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) + self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) + + def check_save_and_load( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + enc_dec_model.eval() + with torch.no_grad(): + outputs = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + out_2 = outputs[0].cpu().numpy() + out_2[np.isnan(out_2)] = 0 + + with tempfile.TemporaryDirectory() as tmpdirname: + enc_dec_model.save_pretrained(tmpdirname) + EncoderDecoderModel.from_pretrained(tmpdirname) + + after_outputs = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + out_1 = after_outputs[0].cpu().numpy() + out_1[np.isnan(out_1)] = 0 + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + def check_save_and_load_encoder_decoder_model( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + enc_dec_model.eval() + with torch.no_grad(): + outputs = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + out_2 = outputs[0].cpu().numpy() + out_2[np.isnan(out_2)] = 0 + + with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname: + enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname) + enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname) + EncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_pretrained_model_name_or_path=encoder_tmp_dirname, + decoder_pretrained_model_name_or_path=decoder_tmp_dirname, + ) + + after_outputs = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + out_1 = after_outputs[0].cpu().numpy() + out_1[np.isnan(out_1)] = 0 + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + def check_encoder_decoder_model_labels( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + labels, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + labels=labels, + ) + + mlm_loss = outputs_encoder_decoder[0] + # check that backprop works + mlm_loss.backward() + + self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) + self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,))) + + def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + + # Bert does not have a bos token id, so use pad_token_id instead + generated_output = enc_dec_model.generate( + input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id + ) + self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,)) + + def test_encoder_decoder_model(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model(**input_ids_dict) + + def test_encoder_decoder_model_from_pretrained_configs(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict) + + def test_encoder_decoder_model_from_pretrained(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_pretrained(**input_ids_dict) + + def test_save_and_load_from_pretrained(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_save_and_load(**input_ids_dict) + + def test_save_and_load_from_encoder_decoder_pretrained(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_save_and_load_encoder_decoder_model(**input_ids_dict) + + def test_encoder_decoder_model_labels(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_labels(**input_ids_dict) + + def test_encoder_decoder_model_generate(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_generate(**input_ids_dict) + + @slow + def test_real_model_save_load_from_pretrained(self): + model_2 = self.get_pretrained_model() + model_2.to(torch_device) + input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size) + decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size) + attention_mask = ids_tensor([13, 5], vocab_size=2) + with torch.no_grad(): + outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,) + out_2 = outputs[0].cpu().numpy() + out_2[np.isnan(out_2)] = 0 + + with tempfile.TemporaryDirectory() as tmp_dirname: + model_2.save_pretrained(tmp_dirname) + model_1 = EncoderDecoderModel.from_pretrained(tmp_dirname) + model_1.to(torch_device) + + after_outputs = model_1( + input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, + ) + out_1 = after_outputs[0].cpu().numpy() + out_1[np.isnan(out_1)] = 0 + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + +class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): + def get_pretrained_model(self): + return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased") + + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = BertModel(config) + decoder_model = BertLMHeadModel(decoder_config) + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester = BertModelTester(self) + encoder_config_and_inputs = model_tester.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester.prepare_config_and_inputs_for_decoder() ( config, input_ids, @@ -77,288 +373,54 @@ class EncoderDecoderModelTest(unittest.TestCase): "labels": decoder_token_labels, } - def create_and_check_bert_encoder_decoder_model_from_pretrained_configs( - self, - config, - input_ids, - attention_mask, - encoder_hidden_states, - decoder_config, - decoder_input_ids, - decoder_attention_mask, - **kwargs - ): - encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) - self.assertTrue(encoder_decoder_config.decoder.is_decoder) - enc_dec_model = EncoderDecoderModel(encoder_decoder_config) - enc_dec_model.to(torch_device) - enc_dec_model.eval() +class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = RobertaModel(config) + decoder_model = RobertaForCausalLM(decoder_config) + return encoder_model, decoder_model - self.assertTrue(enc_dec_model.config.is_encoder_decoder) + def prepare_config_and_inputs(self): + model_tester = RobertaModelTester(self) + encoder_config_and_inputs = model_tester.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester.prepare_config_and_inputs_for_decoder() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_token_type_ids, + decoder_input_mask, + decoder_sequence_labels, + decoder_token_labels, + decoder_choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = decoder_config_and_inputs - outputs_encoder_decoder = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + return { + "config": config, + "input_ids": input_ids, + "attention_mask": input_mask, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_token_type_ids": decoder_token_type_ids, + "decoder_attention_mask": decoder_input_mask, + "decoder_sequence_labels": decoder_sequence_labels, + "decoder_token_labels": decoder_token_labels, + "decoder_choice_labels": decoder_choice_labels, + "encoder_hidden_states": encoder_hidden_states, + "labels": decoder_token_labels, + } - self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) - self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) - - def create_and_check_bert_encoder_decoder_model( - self, - config, - input_ids, - attention_mask, - encoder_hidden_states, - decoder_config, - decoder_input_ids, - decoder_attention_mask, - **kwargs - ): - encoder_model = BertModel(config) - decoder_model = BertLMHeadModel(decoder_config) - enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - self.assertTrue(enc_dec_model.config.decoder.is_decoder) - self.assertTrue(enc_dec_model.config.decoder.add_cross_attention) - self.assertTrue(enc_dec_model.config.is_encoder_decoder) - enc_dec_model.to(torch_device) - outputs_encoder_decoder = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) - self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) - encoder_outputs = (encoder_hidden_states,) - outputs_encoder_decoder = enc_dec_model( - encoder_outputs=encoder_outputs, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) - self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) - - def create_and_check_bert_encoder_decoder_model_from_pretrained( - self, - config, - input_ids, - attention_mask, - encoder_hidden_states, - decoder_config, - decoder_input_ids, - decoder_attention_mask, - **kwargs - ): - encoder_model = BertModel(config) - decoder_model = BertLMHeadModel(decoder_config) - kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model} - enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs) - enc_dec_model.to(torch_device) - outputs_encoder_decoder = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) - self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) - - def create_and_check_save_and_load( - self, - config, - input_ids, - attention_mask, - encoder_hidden_states, - decoder_config, - decoder_input_ids, - decoder_attention_mask, - **kwargs - ): - encoder_model = BertModel(config) - decoder_model = BertLMHeadModel(decoder_config) - enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - enc_dec_model.to(torch_device) - enc_dec_model.eval() - with torch.no_grad(): - outputs = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - out_2 = outputs[0].cpu().numpy() - out_2[np.isnan(out_2)] = 0 - - with tempfile.TemporaryDirectory() as tmpdirname: - enc_dec_model.save_pretrained(tmpdirname) - EncoderDecoderModel.from_pretrained(tmpdirname) - - after_outputs = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - out_1 = after_outputs[0].cpu().numpy() - out_1[np.isnan(out_1)] = 0 - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) - - def create_and_check_save_and_load_encoder_decoder_model( - self, - config, - input_ids, - attention_mask, - encoder_hidden_states, - decoder_config, - decoder_input_ids, - decoder_attention_mask, - **kwargs - ): - encoder_model = BertModel(config) - decoder_model = BertLMHeadModel(decoder_config) - enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - enc_dec_model.to(torch_device) - enc_dec_model.eval() - with torch.no_grad(): - outputs = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - out_2 = outputs[0].cpu().numpy() - out_2[np.isnan(out_2)] = 0 - - with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname: - enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname) - enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname) - EncoderDecoderModel.from_encoder_decoder_pretrained( - encoder_pretrained_model_name_or_path=encoder_tmp_dirname, - decoder_pretrained_model_name_or_path=decoder_tmp_dirname, - ) - - after_outputs = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - out_1 = after_outputs[0].cpu().numpy() - out_1[np.isnan(out_1)] = 0 - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) - - def create_and_check_bert_encoder_decoder_model_labels( - self, - config, - input_ids, - attention_mask, - encoder_hidden_states, - decoder_config, - decoder_input_ids, - decoder_attention_mask, - labels, - **kwargs - ): - encoder_model = BertModel(config) - decoder_model = BertLMHeadModel(decoder_config) - enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - enc_dec_model.to(torch_device) - outputs_encoder_decoder = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - labels=labels, - ) - - mlm_loss = outputs_encoder_decoder[0] - # check that backprop works - mlm_loss.backward() - - self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) - self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,))) - - def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): - encoder_model = BertModel(config) - decoder_model = BertLMHeadModel(decoder_config) - enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - enc_dec_model.to(torch_device) - - # Bert does not have a bos token id, so use pad_token_id instead - generated_output = enc_dec_model.generate( - input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id - ) - self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,)) - - def test_bert_encoder_decoder_model(self): - input_ids_dict = self.prepare_config_and_inputs_bert() - self.create_and_check_bert_encoder_decoder_model(**input_ids_dict) - - def test_bert_encoder_decoder_model_from_pretrained_configs(self): - input_ids_dict = self.prepare_config_and_inputs_bert() - self.create_and_check_bert_encoder_decoder_model_from_pretrained_configs(**input_ids_dict) - - def test_bert_encoder_decoder_model_from_pretrained(self): - input_ids_dict = self.prepare_config_and_inputs_bert() - self.create_and_check_bert_encoder_decoder_model_from_pretrained(**input_ids_dict) - - def test_save_and_load_from_pretrained(self): - input_ids_dict = self.prepare_config_and_inputs_bert() - self.create_and_check_save_and_load(**input_ids_dict) - - def test_save_and_load_from_encoder_decoder_pretrained(self): - input_ids_dict = self.prepare_config_and_inputs_bert() - self.create_and_check_save_and_load_encoder_decoder_model(**input_ids_dict) - - def test_bert_encoder_decoder_model_labels(self): - input_ids_dict = self.prepare_config_and_inputs_bert() - self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict) - - def test_bert_encoder_decoder_model_generate(self): - input_ids_dict = self.prepare_config_and_inputs_bert() - self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict) - - @slow - def test_real_bert_model_from_pretrained(self): - model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") - self.assertIsNotNone(model) - - @slow - def test_real_bert_model_from_pretrained_add_cross_attention(self): - model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") - self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention")) - - @slow - def test_real_bert_model_save_load_from_pretrained(self): - model_2 = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") - model_2.to(torch_device) - input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size) - decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size) - attention_mask = ids_tensor([13, 5], vocab_size=2) - with torch.no_grad(): - outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,) - out_2 = outputs[0].cpu().numpy() - out_2[np.isnan(out_2)] = 0 - - with tempfile.TemporaryDirectory() as tmp_dirname: - model_2.save_pretrained(tmp_dirname) - model_1 = EncoderDecoderModel.from_pretrained(tmp_dirname) - model_1.to(torch_device) - - after_outputs = model_1( - input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, - ) - out_1 = after_outputs[0].cpu().numpy() - out_1[np.isnan(out_1)] = 0 - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) + def get_pretrained_model(self): + return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base") diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index 00b0b79e54..ddf4695127 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -20,7 +20,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester -from .test_modeling_common import ModelTesterMixin, ids_tensor +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor if is_torch_available(): @@ -28,6 +28,7 @@ if is_torch_available(): from transformers import ( RobertaConfig, RobertaModel, + RobertaForCausalLM, RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForQuestionAnswering, @@ -101,7 +102,34 @@ class RobertaModelTester: return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - def create_and_check_roberta_model( + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = RobertaModel(config=config) @@ -114,7 +142,58 @@ class RobertaModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) - def create_and_check_roberta_for_masked_lm( + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = RobertaModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = RobertaForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = RobertaForMaskedLM(config=config) @@ -123,7 +202,7 @@ class RobertaModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_roberta_for_token_classification( + def create_and_check_for_token_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels @@ -133,7 +212,7 @@ class RobertaModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) - def create_and_check_roberta_for_multiple_choice( + def create_and_check_for_multiple_choice( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_choices = self.num_choices @@ -151,7 +230,7 @@ class RobertaModelTester: ) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) - def create_and_check_roberta_for_question_answering( + def create_and_check_for_question_answering( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = RobertaForQuestionAnswering(config=config) @@ -187,6 +266,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( ( + RobertaForCausalLM, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, @@ -205,25 +285,61 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() - def test_roberta_model(self): + def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_model(*config_and_inputs) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_as_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + + def test_model_as_decoder_with_default_input_mask(self): + # This regression test was failing with PyTorch < 1.3 + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = self.model_tester.prepare_config_and_inputs_for_decoder() + + input_mask = None + + self.model_tester.create_and_check_model_as_decoder( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs) + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs) + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs) + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs) + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) @slow def test_model_from_pretrained(self):