diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 55e12b02ab..dec1d726e3 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -64,11 +64,6 @@ class ReformerConfig(PretrainedConfig): A chunk size of 0 means that the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time. For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ . - chunk_size_feed_forward (:obj:`int`, optional, defaults to 0): - The chunk size of all feed forward layers in the residual attention blocks. - A chunk size of 0 means that the feed forward layer is not chunked. - A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time. - For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ . eos_token_id (:obj:`int`, optional, defaults to 2): The token id for the token. feed_forward_size (:obj:`int`, optional, defaults to 512): @@ -147,7 +142,6 @@ class ReformerConfig(PretrainedConfig): axial_pos_shape=[64, 64], axial_pos_embds_dim=[64, 192], chunk_size_lm_head=0, - chunk_size_feed_forward=0, eos_token_id=2, feed_forward_size=512, hash_seed=None, @@ -202,5 +196,4 @@ class ReformerConfig(PretrainedConfig): self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) self.axial_norm_std = axial_norm_std self.chunk_size_lm_head = chunk_size_lm_head - self.chunk_size_feed_forward = chunk_size_feed_forward self.attn_layers = attn_layers diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py old mode 100644 new mode 100755 index 3e1d4bcbf4..f71e2c2c05 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -66,6 +66,11 @@ class PretrainedConfig(object): 2. xla_device (:obj:`bool`, `optional`): A flag to indicate if TPU are available or not. + chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`): + The chunk size of all feed forward layers in the residual attention blocks. + A chunk size of :obj:`0` means that the feed forward layer is not chunked. + A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time. + For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ . Parameters for sequence generation - **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by @@ -163,6 +168,7 @@ class PretrainedConfig(object): self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.bad_words_ids = kwargs.pop("bad_words_ids", None) self.num_return_sequences = kwargs.pop("num_return_sequences", 1) + self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0) # Fine-tuning task arguments self.architectures = kwargs.pop("architectures", None) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py old mode 100644 new mode 100755 index 9605d29cb2..a40d54f611 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -48,7 +48,12 @@ from .modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from .modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) logger = logging.getLogger(__name__) @@ -88,6 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): """ try: import re + import numpy as np import tensorflow as tf except ImportError: @@ -376,6 +382,8 @@ class BertOutput(nn.Module): class BertLayer(nn.Module): def __init__(self, config): super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 self.attention = BertAttention(config) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention @@ -415,11 +423,17 @@ class BertLayer(nn.Module): attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) + layer_output = apply_chunking_to_forward( + self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output + ) outputs = (layer_output,) + outputs return outputs + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + class BertEncoder(nn.Module): def __init__(self, config): diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 0ec9a0b874..87382337d5 100644 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -370,6 +370,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + test_chunking = True def setUp(self): self.model_tester = BertModelTester(self) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9dde829d74..cb297008bf 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -60,6 +60,7 @@ class ModelTesterMixin: test_resize_embeddings = True test_head_masking = True test_missing_keys = True + test_chunking = False is_encoder_decoder = False def _prepare_for_class(self, inputs_dict, model_class): @@ -519,6 +520,29 @@ class ModelTesterMixin: check_hidden_states_output(inputs_dict, config, model_class) + def test_feed_forward_chunking(self): + (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common() + if not self.test_chunking: + return + + for model_class in self.all_model_classes: + torch.manual_seed(0) + config = copy.deepcopy(original_config) + model = model_class(config) + model.to(torch_device) + model.eval() + + hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0] + + torch.manual_seed(0) + config.chunk_size_feed_forward = 1 + model = model_class(config) + model.to(torch_device) + model.eval() + + hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0] + self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) + def test_resize_tokens_embeddings(self): (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common() if not self.test_resize_embeddings: diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index e5d07d8eb6..e878c310b1 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -291,24 +291,6 @@ class ReformerModelTester: torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) ) - def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels): - torch.manual_seed(0) - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0] - - config.chunk_size_lm_head = 1 - config.chunk_size_feed_forward = 1 - - torch.manual_seed(0) - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - - hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"] - self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) - def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels): if not self.is_training: return @@ -517,10 +499,6 @@ class ReformerTesterMixin: self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True) self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False) - def test_reformer_chunking_forward_equality(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs) - def test_reformer_chunking_backward_equality(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs) @@ -577,6 +555,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest test_pruning = False test_headmasking = False test_torchscript = False + test_chunking = True def prepare_kwargs(self): return { @@ -637,6 +616,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T test_pruning = False test_headmasking = False test_torchscript = False + test_chunking = True def prepare_kwargs(self): return {