Feed forward chunking (#6024)
* Chunked feed forward for Bert This is an initial implementation to test applying feed forward chunking for BERT. Will need additional modifications based on output and benchmark results. * Black and cleanup * Feed forward chunking in BertLayer class. * Isort * add chunking for all models * fix docs * Fix typo Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -370,6 +370,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_chunking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BertModelTester(self)
|
||||
|
||||
@@ -60,6 +60,7 @@ class ModelTesterMixin:
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
test_missing_keys = True
|
||||
test_chunking = False
|
||||
is_encoder_decoder = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class):
|
||||
@@ -519,6 +520,29 @@ class ModelTesterMixin:
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_chunking:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
torch.manual_seed(0)
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
config.chunk_size_feed_forward = 1
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
|
||||
@@ -291,24 +291,6 @@ class ReformerModelTester:
|
||||
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
|
||||
)
|
||||
|
||||
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
torch.manual_seed(0)
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]
|
||||
|
||||
config.chunk_size_lm_head = 1
|
||||
config.chunk_size_feed_forward = 1
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
|
||||
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
if not self.is_training:
|
||||
return
|
||||
@@ -517,10 +499,6 @@ class ReformerTesterMixin:
|
||||
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
|
||||
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
|
||||
|
||||
def test_reformer_chunking_forward_equality(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
|
||||
|
||||
def test_reformer_chunking_backward_equality(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
|
||||
@@ -577,6 +555,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
test_chunking = True
|
||||
|
||||
def prepare_kwargs(self):
|
||||
return {
|
||||
@@ -637,6 +616,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
test_chunking = True
|
||||
|
||||
def prepare_kwargs(self):
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user