Feed forward chunking (#6024)
* Chunked feed forward for Bert This is an initial implementation to test applying feed forward chunking for BERT. Will need additional modifications based on output and benchmark results. * Black and cleanup * Feed forward chunking in BertLayer class. * Isort * add chunking for all models * fix docs * Fix typo Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -64,11 +64,6 @@ class ReformerConfig(PretrainedConfig):
|
||||
A chunk size of 0 means that the feed forward layer is not chunked.
|
||||
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
|
||||
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
|
||||
chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
|
||||
The chunk size of all feed forward layers in the residual attention blocks.
|
||||
A chunk size of 0 means that the feed forward layer is not chunked.
|
||||
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
|
||||
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
|
||||
eos_token_id (:obj:`int`, optional, defaults to 2):
|
||||
The token id for the <EOS> token.
|
||||
feed_forward_size (:obj:`int`, optional, defaults to 512):
|
||||
@@ -147,7 +142,6 @@ class ReformerConfig(PretrainedConfig):
|
||||
axial_pos_shape=[64, 64],
|
||||
axial_pos_embds_dim=[64, 192],
|
||||
chunk_size_lm_head=0,
|
||||
chunk_size_feed_forward=0,
|
||||
eos_token_id=2,
|
||||
feed_forward_size=512,
|
||||
hash_seed=None,
|
||||
@@ -202,5 +196,4 @@ class ReformerConfig(PretrainedConfig):
|
||||
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
|
||||
self.axial_norm_std = axial_norm_std
|
||||
self.chunk_size_lm_head = chunk_size_lm_head
|
||||
self.chunk_size_feed_forward = chunk_size_feed_forward
|
||||
self.attn_layers = attn_layers
|
||||
|
||||
6
src/transformers/configuration_utils.py
Normal file → Executable file
6
src/transformers/configuration_utils.py
Normal file → Executable file
@@ -66,6 +66,11 @@ class PretrainedConfig(object):
|
||||
2.
|
||||
xla_device (:obj:`bool`, `optional`):
|
||||
A flag to indicate if TPU are available or not.
|
||||
chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
|
||||
The chunk size of all feed forward layers in the residual attention blocks.
|
||||
A chunk size of :obj:`0` means that the feed forward layer is not chunked.
|
||||
A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time.
|
||||
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
|
||||
|
||||
Parameters for sequence generation
|
||||
- **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
|
||||
@@ -163,6 +168,7 @@ class PretrainedConfig(object):
|
||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
||||
|
||||
# Fine-tuning task arguments
|
||||
self.architectures = kwargs.pop("architectures", None)
|
||||
|
||||
20
src/transformers/modeling_bert.py
Normal file → Executable file
20
src/transformers/modeling_bert.py
Normal file → Executable file
@@ -48,7 +48,12 @@ from .modeling_outputs import (
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from .modeling_utils import (
|
||||
PreTrainedModel,
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -88,6 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
@@ -376,6 +382,8 @@ class BertOutput(nn.Module):
|
||||
class BertLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = BertAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
@@ -415,11 +423,17 @@ class BertLayer(nn.Module):
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
||||
Reference in New Issue
Block a user