[Reformer] Adapt Reformer MaskedLM Attn mask (#5560)
* fix attention mask * fix slow test * refactor attn masks * fix fp16 generate test
This commit is contained in:
committed by
GitHub
parent
3dcb748e31
commit
989ae326b5
@@ -407,7 +407,8 @@ class ReformerModelTester:
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
|
||||
# only use last 10 inputs for generation
|
||||
output = model.generate(input_ids[:, -10:], attention_mask=input_mask, do_sample=False)
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
@@ -623,7 +624,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
||||
@require_torch
|
||||
class ReformerIntegrationTests(unittest.TestCase):
|
||||
"""
|
||||
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `branch_to_save_trax_integration_tests`.
|
||||
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
|
||||
"""
|
||||
|
||||
def _get_basic_config_and_input(self):
|
||||
@@ -940,7 +941,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
||||
output_slice = hidden_states[1, -1, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[0.0324, -0.0121, 0.0615, 0.0031, -0.0297], dtype=torch.float, device=torch_device,
|
||||
[0.0256, -0.0121, 0.0636, 0.0024, -0.0393], dtype=torch.float, device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user