diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst index 40539c3d68..f82d972b72 100644 --- a/docs/source/model_doc/reformer.rst +++ b/docs/source/model_doc/reformer.rst @@ -114,9 +114,15 @@ ReformerModelWithLMHead :members: +ReformerForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ReformerForMaskedLM + :members: + + ReformerForQuestionAnswering ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.ReformerForQuestionAnswering :members: - diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 763ba15729..e7aaed6718 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -366,6 +366,7 @@ if is_torch_available(): ReformerAttention, ReformerLayer, ReformerModel, + ReformerForMaskedLM, ReformerModelWithLMHead, ReformerForQuestionAnswering, REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index c33c00b38e..a6c0688f73 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -122,7 +122,12 @@ from .modeling_mobilebert import ( MobileBertModel, ) from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel -from .modeling_reformer import ReformerForQuestionAnswering, ReformerModel, ReformerModelWithLMHead +from .modeling_reformer import ( + ReformerForMaskedLM, + ReformerForQuestionAnswering, + ReformerModel, + ReformerModelWithLMHead, +) from .modeling_retribert import RetriBertModel from .modeling_roberta import ( RobertaForMaskedLM, @@ -266,6 +271,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( (FlaubertConfig, FlaubertWithLMHeadModel), (XLMConfig, XLMWithLMHeadModel), (ElectraConfig, ElectraForMaskedLM), + (ReformerConfig, ReformerForMaskedLM), ] ) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 8190c14fbf..89cbf14403 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1704,6 +1704,7 @@ class ReformerModel(ReformerPreTrainedModel): class ReformerModelWithLMHead(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) + assert config.is_decoder, "If you want to use `ReformerLMHeadModel` make sure that `is_decoder=True`." self.reformer = ReformerModel(config) self.lm_head = ReformerOnlyLMHead(config) @@ -1791,6 +1792,87 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): return inputs_dict +@add_start_docstrings("""Reformer Model with a `language modeling` head on top. """, REFORMER_START_DOCSTRING) +class ReformerForMaskedLM(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + assert ( + not config.is_decoder + ), "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention." + self.reformer = ReformerModel(config) + self.lm_head = ReformerOnlyLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def tie_weights(self): + # word embeddings are not tied in Reformer + pass + + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment") + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + num_hashes=None, + labels=None, + output_hidden_states=None, + output_attentions=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss (cross entropy). + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + reformer_outputs = self.reformer( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + num_hashes=num_hashes, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + + sequence_output = reformer_outputs[0] + logits = self.lm_head(sequence_output) + outputs = (logits,) + reformer_outputs[1:] + + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + outputs = (masked_lm_loss,) + outputs + + return outputs # (mlm_loss), lm_logits, (hidden_states), (attentions) + + @add_start_docstrings( """Reformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA ( a linear layer on diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index e3a2667dd6..7faa43346e 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -25,6 +25,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor if is_torch_available(): from transformers import ( ReformerConfig, + ReformerForMaskedLM, ReformerModel, ReformerModelWithLMHead, ReformerTokenizer, @@ -209,7 +210,24 @@ class ReformerModelTester: ) self.check_loss_output(result) - def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder): + def create_and_check_reformer_with_mlm(self, config, input_ids, input_mask, choice_labels): + config.is_decoder = False + model = ReformerForMaskedLM(config=config) + model.to(torch_device) + model.eval() + loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], + ) + self.check_loss_output(result) + + def create_and_check_reformer_model_with_attn_mask( + self, config, input_ids, input_mask, choice_labels, is_decoder=False + ): # no special position embeddings config.axial_pos_embds = False config.is_decoder = is_decoder @@ -250,7 +268,9 @@ class ReformerModelTester: self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) - def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder): + def create_and_check_reformer_layer_dropout_seed( + self, config, input_ids, input_mask, choice_labels, is_decoder=False + ): config.is_decoder = is_decoder layer = ReformerLayer(config).to(torch_device) layer.train() @@ -441,17 +461,21 @@ class ReformerTesterMixin: def test_reformer_model_attn_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) - self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=True) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=False) def test_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) + def test_reformer_with_mlm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_with_mlm(*config_and_inputs) + def test_reformer_layer_training_dropout(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) - self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False) def test_reformer_chunking_forward_equality(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -501,7 +525,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest "batch_size": 13, "seq_length": 32, "is_training": True, - "is_decoder": False, + "is_decoder": True, "use_input_mask": True, "use_labels": True, "vocab_size": 32, @@ -560,7 +584,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T "use_input_mask": True, "use_labels": True, "is_training": False, - "is_decoder": False, + "is_decoder": True, "vocab_size": 32, "attention_head_size": 16, "hidden_size": 64, @@ -910,7 +934,7 @@ class ReformerIntegrationTests(unittest.TestCase): config["num_buckets"] = [2, 4] config["is_decoder"] = False torch.manual_seed(0) - model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device) + model = ReformerForMaskedLM(ReformerConfig(**config)).to(torch_device) model.eval() input_ids, attn_mask = self._get_input_ids_and_mask() hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]