Feed forward chunking (#6024)

* Chunked feed forward for Bert

This is an initial implementation to test applying feed forward chunking for BERT.
Will need additional modifications based on output and benchmark results.

* Black and cleanup

* Feed forward chunking in BertLayer class.

* Isort

* add chunking for all models

* fix docs

* Fix typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Pradhy729
2020-08-11 00:12:45 -07:00
committed by GitHub
parent 8a3db6b303
commit b25cec13c5
6 changed files with 50 additions and 32 deletions

View File

@@ -291,24 +291,6 @@ class ReformerModelTester:
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
)
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]
config.chunk_size_lm_head = 1
config.chunk_size_feed_forward = 1
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
if not self.is_training:
return
@@ -517,10 +499,6 @@ class ReformerTesterMixin:
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
def test_reformer_chunking_forward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
def test_reformer_chunking_backward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
@@ -577,6 +555,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {
@@ -637,6 +616,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {