[Reformer] Improved memory if input is shorter than chunk length (#4720)

* improve handling of short inputs for reformer

* correct typo in assert statement

* fix other tests
This commit is contained in:
Patrick von Platen
2020-06-02 23:08:39 +02:00
committed by GitHub
parent b231a413f5
commit 9ca485734a
3 changed files with 161 additions and 96 deletions

View File

@@ -388,6 +388,16 @@ class ReformerModelTester:
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
# force chunk length to be bigger than input_ids
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
output_logits = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask,) = config_and_inputs
@@ -433,6 +443,10 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
def test_reformer_no_chunking(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
@slow
def test_dropout_random_seed_is_changing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_lsh_layer_forward(self):
config = self._get_basic_config_and_input()
config["lsh_num_chunks_before"] = 0
config["attn_layers"] = ["lsh"]
config["is_decoder"] = False
hidden_states = self._get_hidden_states()
@@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_lsh_layer_forward_complex(self):
config = self._get_basic_config_and_input()
config["lsh_num_chunks_before"] = 0
config["attn_layers"] = ["lsh"]
config["num_buckets"] = [2, 4]
attn_mask = self._get_attn_mask()
@@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_local_layer_forward(self):
config = self._get_basic_config_and_input()
config["local_num_chunks_before"] = 0
config["attn_layers"] = ["local"]
config["is_decoder"] = False
hidden_states = self._get_hidden_states()
@@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_local_layer_forward_complex(self):
config = self._get_basic_config_and_input()
config["local_num_chunks_before"] = 0
config["attn_layers"] = ["local"]
attn_mask = self._get_attn_mask()
hidden_states = self._get_hidden_states()
@@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase):
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
[1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))