From 501307b58bdc2db1b6a25271a3f60975130f1c6c Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Mon, 27 Dec 2021 12:37:52 +0100 Subject: [PATCH] Add `ElectraForCausalLM` -> Enable Electra encoder-decoder model (#14729) * Add ElectraForCausalLM and cover some basic tests & need to fix a few tests * Fix bugs * make style * make fix-copies * Update doc * Change docstring to markdown format * Remove redundant update_keys_to_ignore --- docs/source/model_doc/electra.mdx | 5 + src/transformers/__init__.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/electra/__init__.py | 2 + .../models/electra/modeling_electra.py | 243 +++++++++++++++--- src/transformers/utils/dummy_pt_objects.py | 12 + tests/test_modeling_electra.py | 92 ++++++- 7 files changed, 323 insertions(+), 34 deletions(-) diff --git a/docs/source/model_doc/electra.mdx b/docs/source/model_doc/electra.mdx index 8e429891a6..aa1df2d503 100644 --- a/docs/source/model_doc/electra.mdx +++ b/docs/source/model_doc/electra.mdx @@ -83,6 +83,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o [[autodoc]] ElectraForPreTraining - forward +## ElectraForCausalLM + +[[autodoc]] ElectraForCausalLM + - forward + ## ElectraForMaskedLM [[autodoc]] ElectraForMaskedLM diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a0044f849c..a691ad7c8e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -885,6 +885,7 @@ if is_torch_available(): _import_structure["models.electra"].extend( [ "ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST", + "ElectraForCausalLM", "ElectraForMaskedLM", "ElectraForMultipleChoice", "ElectraForPreTraining", @@ -2830,6 +2831,7 @@ if TYPE_CHECKING: ) from .models.electra import ( ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, + ElectraForCausalLM, ElectraForMaskedLM, ElectraForMultipleChoice, ElectraForPreTraining, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 003cc088b6..73d5089ed7 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -218,6 +218,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("transfo-xl", "TransfoXLLMHeadModel"), ("xlnet", "XLNetLMHeadModel"), ("xlm", "XLMWithLMHeadModel"), + ("electra", "ElectraForCausalLM"), ("ctrl", "CTRLLMHeadModel"), ("reformer", "ReformerModelWithLMHead"), ("bert-generation", "BertGenerationDecoder"), diff --git a/src/transformers/models/electra/__init__.py b/src/transformers/models/electra/__init__.py index 8cff3d2236..794abb2154 100644 --- a/src/transformers/models/electra/__init__.py +++ b/src/transformers/models/electra/__init__.py @@ -32,6 +32,7 @@ if is_tokenizers_available(): if is_torch_available(): _import_structure["modeling_electra"] = [ "ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST", + "ElectraForCausalLM", "ElectraForMaskedLM", "ElectraForMultipleChoice", "ElectraForPreTraining", @@ -79,6 +80,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_electra import ( ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, + ElectraForCausalLM, ElectraForMaskedLM, ElectraForMultipleChoice, ElectraForPreTraining, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 0c4d0e626f..8c343e4c68 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -36,6 +36,7 @@ from ...file_utils import ( from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput, QuestionAnsweringModelOutput, @@ -846,6 +847,10 @@ class ElectraModel(ElectraPreTrainedModel): position_ids=None, head_mask=None, inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -868,6 +873,9 @@ class ElectraModel(ElectraPreTrainedModel): batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) if token_type_ids is None: @@ -879,10 +887,26 @@ class ElectraModel(ElectraPreTrainedModel): token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) hidden_states = self.embeddings( - input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, ) if hasattr(self, "embeddings_project"): @@ -892,6 +916,10 @@ class ElectraModel(ElectraPreTrainedModel): hidden_states, attention_mask=extended_attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -969,14 +997,14 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel): discriminator_hidden_states = self.electra( input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) sequence_output = discriminator_hidden_states[0] @@ -1075,14 +1103,14 @@ class ElectraForPreTraining(ElectraPreTrainedModel): discriminator_hidden_states = self.electra( input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) discriminator_sequence_output = discriminator_hidden_states[0] @@ -1166,14 +1194,14 @@ class ElectraForMaskedLM(ElectraPreTrainedModel): generator_hidden_states = self.electra( input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) generator_sequence_output = generator_hidden_states[0] @@ -1247,14 +1275,14 @@ class ElectraForTokenClassification(ElectraPreTrainedModel): discriminator_hidden_states = self.electra( input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) discriminator_sequence_output = discriminator_hidden_states[0] @@ -1481,3 +1509,152 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel): hidden_states=discriminator_hidden_states.hidden_states, attentions=discriminator_hidden_states.attentions, ) + + +@add_start_docstrings( + """ELECTRA Model with a `language modeling` head on top for CLM fine-tuning. """, ELECTRA_START_DOCSTRING +) +class ElectraForCausalLM(ElectraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `ElectraLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.electra = ElectraModel(config) + self.generator_predictions = ElectraGeneratorPredictions(config) + self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size) + + self.init_weights() + + def get_output_embeddings(self): + return self.generator_lm_head + + def set_output_embeddings(self, new_embeddings): + self.generator_lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` + (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` + instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up + decoding (see `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import ElectraTokenizer, ElectraForCausalLM, ElectraConfig + >>> import torch + + >>> tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-generator") + >>> config = ElectraConfig.from_pretrained("google/electra-base-generator") + >>> config.is_decoder = True + >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output)) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 02afebeca0..18abff2d48 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1924,6 +1924,18 @@ class DPRReader: ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None +class ElectraForCausalLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ElectraForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index 933e542025..be19f8d610 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -21,7 +21,7 @@ from transformers.models.auto import get_values from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester -from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask if is_torch_available(): @@ -29,6 +29,7 @@ if is_torch_available(): from transformers import ( MODEL_FOR_PRETRAINING_MAPPING, + ElectraForCausalLM, ElectraForMaskedLM, ElectraForMultipleChoice, ElectraForPreTraining, @@ -117,6 +118,34 @@ class ElectraModelTester: initializer_range=self.initializer_range, ) + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + _, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + def create_and_check_electra_model( self, config, @@ -136,6 +165,38 @@ class ElectraModelTester: result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + def create_and_check_electra_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = ElectraModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + def create_and_check_electra_for_masked_lm( self, config, @@ -153,6 +214,24 @@ class ElectraModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_electra_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = ElectraForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_electra_for_token_classification( self, config, @@ -281,6 +360,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): ElectraModel, ElectraForPreTraining, ElectraForMaskedLM, + ElectraForCausalLM, ElectraForMultipleChoice, ElectraForTokenClassification, ElectraForSequenceClassification, @@ -289,6 +369,8 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else () + fx_ready_model_classes = all_model_classes fx_dynamic_ready_model_classes = all_model_classes @@ -314,6 +396,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_electra_model(*config_and_inputs) + def test_electra_model_as_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_electra_model_as_decoder(*config_and_inputs) + def test_electra_model_various_embeddings(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() for type in ["absolute", "relative_key", "relative_key_query"]: @@ -350,6 +436,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): model = ElectraModel.from_pretrained(model_name) self.assertIsNotNone(model) + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_electra_for_causal_lm(*config_and_inputs) + @require_torch class ElectraModelIntegrationTest(unittest.TestCase):