Feed forward chunking (#6024)

* Chunked feed forward for Bert

This is an initial implementation to test applying feed forward chunking for BERT.
Will need additional modifications based on output and benchmark results.

* Black and cleanup

* Feed forward chunking in BertLayer class.

* Isort

* add chunking for all models

* fix docs

* Fix typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Pradhy729
2020-08-11 00:12:45 -07:00
committed by GitHub
parent 8a3db6b303
commit b25cec13c5
6 changed files with 50 additions and 32 deletions

View File

@@ -64,11 +64,6 @@ class ReformerConfig(PretrainedConfig):
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
The chunk size of all feed forward layers in the residual attention blocks.
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
eos_token_id (:obj:`int`, optional, defaults to 2):
The token id for the <EOS> token.
feed_forward_size (:obj:`int`, optional, defaults to 512):
@@ -147,7 +142,6 @@ class ReformerConfig(PretrainedConfig):
axial_pos_shape=[64, 64],
axial_pos_embds_dim=[64, 192],
chunk_size_lm_head=0,
chunk_size_feed_forward=0,
eos_token_id=2,
feed_forward_size=512,
hash_seed=None,
@@ -202,5 +196,4 @@ class ReformerConfig(PretrainedConfig):
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
self.axial_norm_std = axial_norm_std
self.chunk_size_lm_head = chunk_size_lm_head
self.chunk_size_feed_forward = chunk_size_feed_forward
self.attn_layers = attn_layers

6
src/transformers/configuration_utils.py Normal file → Executable file
View File

@@ -66,6 +66,11 @@ class PretrainedConfig(object):
2.
xla_device (:obj:`bool`, `optional`):
A flag to indicate if TPU are available or not.
chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
The chunk size of all feed forward layers in the residual attention blocks.
A chunk size of :obj:`0` means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
Parameters for sequence generation
- **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
@@ -163,6 +168,7 @@ class PretrainedConfig(object):
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)

20
src/transformers/modeling_bert.py Normal file → Executable file
View File

@@ -48,7 +48,12 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
logger = logging.getLogger(__name__)
@@ -88,6 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
@@ -376,6 +382,8 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
@@ -415,11 +423,17 @@ class BertLayer(nn.Module):
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
layer_output = apply_chunking_to_forward(
self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
)
outputs = (layer_output,) + outputs
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):