[Reformer] Add Masked LM Reformer (#5426)
* fix conflicts * fix * happy rebasing
This commit is contained in:
committed by
GitHub
parent
f4323dbf8c
commit
d16e36c7e5
@@ -25,6 +25,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
if is_torch_available():
|
||||
from transformers import (
|
||||
ReformerConfig,
|
||||
ReformerForMaskedLM,
|
||||
ReformerModel,
|
||||
ReformerModelWithLMHead,
|
||||
ReformerTokenizer,
|
||||
@@ -209,7 +210,24 @@ class ReformerModelTester:
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
|
||||
def create_and_check_reformer_with_mlm(self, config, input_ids, input_mask, choice_labels):
|
||||
config.is_decoder = False
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_reformer_model_with_attn_mask(
|
||||
self, config, input_ids, input_mask, choice_labels, is_decoder=False
|
||||
):
|
||||
# no special position embeddings
|
||||
config.axial_pos_embds = False
|
||||
config.is_decoder = is_decoder
|
||||
@@ -250,7 +268,9 @@ class ReformerModelTester:
|
||||
|
||||
self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
|
||||
|
||||
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
|
||||
def create_and_check_reformer_layer_dropout_seed(
|
||||
self, config, input_ids, input_mask, choice_labels, is_decoder=False
|
||||
):
|
||||
config.is_decoder = is_decoder
|
||||
layer = ReformerLayer(config).to(torch_device)
|
||||
layer.train()
|
||||
@@ -441,17 +461,21 @@ class ReformerTesterMixin:
|
||||
|
||||
def test_reformer_model_attn_masking(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True)
|
||||
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False)
|
||||
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=True)
|
||||
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=False)
|
||||
|
||||
def test_reformer_with_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs)
|
||||
|
||||
def test_reformer_with_mlm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_with_mlm(*config_and_inputs)
|
||||
|
||||
def test_reformer_layer_training_dropout(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True)
|
||||
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False)
|
||||
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
|
||||
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
|
||||
|
||||
def test_reformer_chunking_forward_equality(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -501,7 +525,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
||||
"batch_size": 13,
|
||||
"seq_length": 32,
|
||||
"is_training": True,
|
||||
"is_decoder": False,
|
||||
"is_decoder": True,
|
||||
"use_input_mask": True,
|
||||
"use_labels": True,
|
||||
"vocab_size": 32,
|
||||
@@ -560,7 +584,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
||||
"use_input_mask": True,
|
||||
"use_labels": True,
|
||||
"is_training": False,
|
||||
"is_decoder": False,
|
||||
"is_decoder": True,
|
||||
"vocab_size": 32,
|
||||
"attention_head_size": 16,
|
||||
"hidden_size": 64,
|
||||
@@ -910,7 +934,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
config["num_buckets"] = [2, 4]
|
||||
config["is_decoder"] = False
|
||||
torch.manual_seed(0)
|
||||
model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
|
||||
model = ReformerForMaskedLM(ReformerConfig(**config)).to(torch_device)
|
||||
model.eval()
|
||||
input_ids, attn_mask = self._get_input_ids_and_mask()
|
||||
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
||||
|
||||
Reference in New Issue
Block a user