Feed forward chunking (#6024)
* Chunked feed forward for Bert

  This is an initial implementation to test applying feed forward chunking for BERT.
  Will need additional modifications based on output and benchmark results.

* Black and cleanup
* Feed forward chunking in BertLayer class.
* Isort
* add chunking for all models
* fix docs
* Fix typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
6
src/transformers/configuration_utils.py
Normal file → Executable file
@@ -66,6 +66,11 @@ class PretrainedConfig(object):
             2.
         xla_device (:obj:`bool`, `optional`):
             A flag to indicate if TPU are available or not.
+        chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
+            The chunk size of all feed forward layers in the residual attention blocks.
+            A chunk size of :obj:`0` means that the feed forward layer is not chunked.
+            A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time.
+            For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
 
     Parameters for sequence generation
         - **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
@@ -163,6 +168,7 @@ class PretrainedConfig(object):
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.bad_words_ids = kwargs.pop("bad_words_ids", None)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
+        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
 
         # Fine-tuning task arguments
         self.architectures = kwargs.pop("architectures", None)
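
The `chunk_size_feed_forward` docstring in this diff describes processing `n` embeddings at a time, with `0` meaning no chunking. The following is a minimal pure-Python sketch of that idea, not the code added by this commit. The names `feed_forward` and `chunked_feed_forward` are hypothetical, the toy layer stands in for a real feed forward block, and handling a final partial chunk is an assumption of this sketch.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def feed_forward(chunk: Sequence[Vector]) -> List[Vector]:
    """Toy position-wise layer (hypothetical): each embedding is transformed independently."""
    return [[max(0.0, 2.0 * x + 1.0) for x in vec] for vec in chunk]


def chunked_feed_forward(
    forward_fn: Callable[[Sequence[Vector]], List[Vector]],
    sequence: Sequence[Vector],
    chunk_size: int,
) -> List[Vector]:
    """Apply forward_fn to `chunk_size` embeddings at a time.

    A chunk_size of 0 disables chunking, mirroring the documented default.
    """
    if chunk_size < 0:
        raise ValueError("chunk_size must be non-negative")
    if chunk_size == 0:
        return forward_fn(sequence)
    output: List[Vector] = []
    # Walk the sequence dimension in slices; only one slice is processed at a time.
    for start in range(0, len(sequence), chunk_size):
        output.extend(forward_fn(sequence[start:start + chunk_size]))
    return output


# Five embeddings of width two.
sequence = [[float(i), float(-i)] for i in range(5)]
full = chunked_feed_forward(feed_forward, sequence, chunk_size=0)
chunked = chunked_feed_forward(feed_forward, sequence, chunk_size=2)
assert full == chunked  # chunking changes peak memory use, not the result
```

Because the layer acts on each position independently, the chunked and unchunked outputs match; the trade-off is lower peak memory for intermediate activations in exchange for more sequential calls.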