[Reformer] Add Masked LM Reformer (#5426)

* fix conflicts

* fix

* happy rebasing
This commit is contained in:
Patrick von Platen
2020-07-01 22:43:18 +02:00
committed by GitHub
parent f4323dbf8c
commit d16e36c7e5
5 changed files with 130 additions and 11 deletions

View File

@@ -25,6 +25,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
from transformers import (
ReformerConfig,
ReformerForMaskedLM,
ReformerModel,
ReformerModelWithLMHead,
ReformerTokenizer,
@@ -209,7 +210,24 @@ class ReformerModelTester:
)
self.check_loss_output(result)
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
def create_and_check_reformer_with_mlm(self, config, input_ids, input_mask, choice_labels):
config.is_decoder = False
model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.check_loss_output(result)
def create_and_check_reformer_model_with_attn_mask(
self, config, input_ids, input_mask, choice_labels, is_decoder=False
):
# no special position embeddings
config.axial_pos_embds = False
config.is_decoder = is_decoder
@@ -250,7 +268,9 @@ class ReformerModelTester:
self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
def create_and_check_reformer_layer_dropout_seed(
self, config, input_ids, input_mask, choice_labels, is_decoder=False
):
config.is_decoder = is_decoder
layer = ReformerLayer(config).to(torch_device)
layer.train()
@@ -441,17 +461,21 @@ class ReformerTesterMixin:
def test_reformer_model_attn_masking(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True)
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False)
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=True)
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=False)
def test_reformer_with_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs)
def test_reformer_with_mlm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_with_mlm(*config_and_inputs)
def test_reformer_layer_training_dropout(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
def test_reformer_chunking_forward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -501,7 +525,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
"batch_size": 13,
"seq_length": 32,
"is_training": True,
"is_decoder": False,
"is_decoder": True,
"use_input_mask": True,
"use_labels": True,
"vocab_size": 32,
@@ -560,7 +584,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
"use_input_mask": True,
"use_labels": True,
"is_training": False,
"is_decoder": False,
"is_decoder": True,
"vocab_size": 32,
"attention_head_size": 16,
"hidden_size": 64,
@@ -910,7 +934,7 @@ class ReformerIntegrationTests(unittest.TestCase):
config["num_buckets"] = [2, 4]
config["is_decoder"] = False
torch.manual_seed(0)
model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
model = ReformerForMaskedLM(ReformerConfig(**config)).to(torch_device)
model.eval()
input_ids, attn_mask = self._get_input_ids_and_mask()
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]