Feed forward chunking (#6024)
* Chunked feed forward for Bert This is an initial implementation to test applying feed forward chunking for BERT. Will need additional modifications based on output and benchmark results. * Black and cleanup * Feed forward chunking in BertLayer class. * Isort * add chunking for all models * fix docs * Fix typo Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -60,6 +60,7 @@ class ModelTesterMixin:
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
test_missing_keys = True
|
||||
test_chunking = False
|
||||
is_encoder_decoder = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class):
|
||||
@@ -519,6 +520,29 @@ class ModelTesterMixin:
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_chunking:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
torch.manual_seed(0)
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
config.chunk_size_feed_forward = 1
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
|
||||
Reference in New Issue
Block a user