Feed forward chunking (#6024)
* Chunked feed forward for Bert This is an initial implementation to test applying feed forward chunking for BERT. Will need additional modifications based on output and benchmark results. * Black and cleanup * Feed forward chunking in BertLayer class. * Isort * add chunking for all models * fix docs * Fix typo Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -64,11 +64,6 @@ class ReformerConfig(PretrainedConfig):
|
|||||||
A chunk size of 0 means that the feed forward layer is not chunked.
|
A chunk size of 0 means that the feed forward layer is not chunked.
|
||||||
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
|
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
|
||||||
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
|
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
|
||||||
chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
|
|
||||||
The chunk size of all feed forward layers in the residual attention blocks.
|
|
||||||
A chunk size of 0 means that the feed forward layer is not chunked.
|
|
||||||
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
|
|
||||||
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
|
|
||||||
eos_token_id (:obj:`int`, optional, defaults to 2):
|
eos_token_id (:obj:`int`, optional, defaults to 2):
|
||||||
The token id for the <EOS> token.
|
The token id for the <EOS> token.
|
||||||
feed_forward_size (:obj:`int`, optional, defaults to 512):
|
feed_forward_size (:obj:`int`, optional, defaults to 512):
|
||||||
@@ -147,7 +142,6 @@ class ReformerConfig(PretrainedConfig):
|
|||||||
axial_pos_shape=[64, 64],
|
axial_pos_shape=[64, 64],
|
||||||
axial_pos_embds_dim=[64, 192],
|
axial_pos_embds_dim=[64, 192],
|
||||||
chunk_size_lm_head=0,
|
chunk_size_lm_head=0,
|
||||||
chunk_size_feed_forward=0,
|
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
feed_forward_size=512,
|
feed_forward_size=512,
|
||||||
hash_seed=None,
|
hash_seed=None,
|
||||||
@@ -202,5 +196,4 @@ class ReformerConfig(PretrainedConfig):
|
|||||||
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
|
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
|
||||||
self.axial_norm_std = axial_norm_std
|
self.axial_norm_std = axial_norm_std
|
||||||
self.chunk_size_lm_head = chunk_size_lm_head
|
self.chunk_size_lm_head = chunk_size_lm_head
|
||||||
self.chunk_size_feed_forward = chunk_size_feed_forward
|
|
||||||
self.attn_layers = attn_layers
|
self.attn_layers = attn_layers
|
||||||
|
|||||||
6
src/transformers/configuration_utils.py
Normal file → Executable file
6
src/transformers/configuration_utils.py
Normal file → Executable file
@@ -66,6 +66,11 @@ class PretrainedConfig(object):
|
|||||||
2.
|
2.
|
||||||
xla_device (:obj:`bool`, `optional`):
|
xla_device (:obj:`bool`, `optional`):
|
||||||
A flag to indicate if TPU are available or not.
|
A flag to indicate if TPU are available or not.
|
||||||
|
chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
|
||||||
|
The chunk size of all feed forward layers in the residual attention blocks.
|
||||||
|
A chunk size of :obj:`0` means that the feed forward layer is not chunked.
|
||||||
|
A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time.
|
||||||
|
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
|
||||||
|
|
||||||
Parameters for sequence generation
|
Parameters for sequence generation
|
||||||
- **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
|
- **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
|
||||||
@@ -163,6 +168,7 @@ class PretrainedConfig(object):
|
|||||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||||
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
||||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||||
|
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
||||||
|
|
||||||
# Fine-tuning task arguments
|
# Fine-tuning task arguments
|
||||||
self.architectures = kwargs.pop("architectures", None)
|
self.architectures = kwargs.pop("architectures", None)
|
||||||
|
|||||||
20
src/transformers/modeling_bert.py
Normal file → Executable file
20
src/transformers/modeling_bert.py
Normal file → Executable file
@@ -48,7 +48,12 @@ from .modeling_outputs import (
|
|||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
from .modeling_utils import (
|
||||||
|
PreTrainedModel,
|
||||||
|
apply_chunking_to_forward,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_linear_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -88,6 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -376,6 +382,8 @@ class BertOutput(nn.Module):
|
|||||||
class BertLayer(nn.Module):
|
class BertLayer(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
self.attention = BertAttention(config)
|
self.attention = BertAttention(config)
|
||||||
self.is_decoder = config.is_decoder
|
self.is_decoder = config.is_decoder
|
||||||
self.add_cross_attention = config.add_cross_attention
|
self.add_cross_attention = config.add_cross_attention
|
||||||
@@ -415,11 +423,17 @@ class BertLayer(nn.Module):
|
|||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
intermediate_output = self.intermediate(attention_output)
|
layer_output = apply_chunking_to_forward(
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
|
||||||
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def feed_forward_chunk(self, attention_output):
|
||||||
|
intermediate_output = self.intermediate(attention_output)
|
||||||
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
class BertEncoder(nn.Module):
|
class BertEncoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|||||||
@@ -370,6 +370,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
test_chunking = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = BertModelTester(self)
|
self.model_tester = BertModelTester(self)
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ class ModelTesterMixin:
|
|||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = True
|
test_head_masking = True
|
||||||
test_missing_keys = True
|
test_missing_keys = True
|
||||||
|
test_chunking = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class):
|
def _prepare_for_class(self, inputs_dict, model_class):
|
||||||
@@ -519,6 +520,29 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
def test_feed_forward_chunking(self):
|
||||||
|
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if not self.test_chunking:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
torch.manual_seed(0)
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
config.chunk_size_feed_forward = 1
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||||
|
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||||
|
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if not self.test_resize_embeddings:
|
if not self.test_resize_embeddings:
|
||||||
|
|||||||
@@ -291,24 +291,6 @@ class ReformerModelTester:
|
|||||||
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
|
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
|
|
||||||
torch.manual_seed(0)
|
|
||||||
model = ReformerModel(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]
|
|
||||||
|
|
||||||
config.chunk_size_lm_head = 1
|
|
||||||
config.chunk_size_feed_forward = 1
|
|
||||||
|
|
||||||
torch.manual_seed(0)
|
|
||||||
model = ReformerModel(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
|
|
||||||
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
|
||||||
|
|
||||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
|
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||||
if not self.is_training:
|
if not self.is_training:
|
||||||
return
|
return
|
||||||
@@ -517,10 +499,6 @@ class ReformerTesterMixin:
|
|||||||
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
|
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
|
||||||
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
|
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
|
||||||
|
|
||||||
def test_reformer_chunking_forward_equality(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_reformer_chunking_backward_equality(self):
|
def test_reformer_chunking_backward_equality(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
|
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
|
||||||
@@ -577,6 +555,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
test_chunking = True
|
||||||
|
|
||||||
def prepare_kwargs(self):
|
def prepare_kwargs(self):
|
||||||
return {
|
return {
|
||||||
@@ -637,6 +616,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
test_chunking = True
|
||||||
|
|
||||||
def prepare_kwargs(self):
|
def prepare_kwargs(self):
|
||||||
return {
|
return {
|
||||||
|
|||||||
Reference in New Issue
Block a user