[Reformer] Improved memory if input is shorter than chunk length (#4720)
* improve handling of short inputs for reformer * correct typo in assert statement * fix other tests
This commit is contained in:
committed by
GitHub
parent
b231a413f5
commit
9ca485734a
@@ -388,6 +388,16 @@ class ReformerModelTester:
|
||||
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
|
||||
# force chunk length to be bigger than input_ids
|
||||
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
output_logits = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask,) = config_and_inputs
|
||||
@@ -433,6 +443,10 @@ class ReformerTesterMixin:
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
|
||||
|
||||
def test_reformer_no_chunking(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_dropout_random_seed_is_changing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_lsh_layer_forward(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
config["lsh_num_chunks_before"] = 0
|
||||
config["attn_layers"] = ["lsh"]
|
||||
config["is_decoder"] = False
|
||||
hidden_states = self._get_hidden_states()
|
||||
@@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_lsh_layer_forward_complex(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
config["lsh_num_chunks_before"] = 0
|
||||
config["attn_layers"] = ["lsh"]
|
||||
config["num_buckets"] = [2, 4]
|
||||
attn_mask = self._get_attn_mask()
|
||||
@@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_local_layer_forward(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
config["local_num_chunks_before"] = 0
|
||||
config["attn_layers"] = ["local"]
|
||||
config["is_decoder"] = False
|
||||
hidden_states = self._get_hidden_states()
|
||||
@@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_local_layer_forward_complex(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
config["local_num_chunks_before"] = 0
|
||||
config["attn_layers"] = ["local"]
|
||||
attn_mask = self._get_attn_mask()
|
||||
hidden_states = self._get_hidden_states()
|
||||
@@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
|
||||
output_slice = reformer_output.hidden_states[0, 0, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
|
||||
[1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user