Feed forward chunking (#6024)

* Chunked feed forward for Bert

This is an initial implementation to test applying feed forward chunking for BERT.
Will need additional modifications based on output and benchmark results.

* Black and cleanup

* Feed forward chunking in BertLayer class.

* Isort

* add chunking for all models

* fix docs

* Fix typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Pradhy729
2020-08-11 00:12:45 -07:00
committed by GitHub
parent 8a3db6b303
commit b25cec13c5
6 changed files with 50 additions and 32 deletions

View File

@@ -370,6 +370,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_chunking = True
def setUp(self):
self.model_tester = BertModelTester(self)

View File

@@ -60,6 +60,7 @@ class ModelTesterMixin:
test_resize_embeddings = True
test_head_masking = True
test_missing_keys = True
test_chunking = False
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class):
@@ -519,6 +520,29 @@ class ModelTesterMixin:
check_hidden_states_output(inputs_dict, config, model_class)
def test_feed_forward_chunking(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_chunking:
return
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
torch.manual_seed(0)
config.chunk_size_feed_forward = 1
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:

View File

@@ -291,24 +291,6 @@ class ReformerModelTester:
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
)
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]
config.chunk_size_lm_head = 1
config.chunk_size_feed_forward = 1
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
if not self.is_training:
return
@@ -517,10 +499,6 @@ class ReformerTesterMixin:
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
def test_reformer_chunking_forward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
def test_reformer_chunking_backward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
@@ -577,6 +555,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {
@@ -637,6 +616,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {