Add Mega: Moving Average Equipped Gated Attention (#21766)

* add mega file structure and plain pytorch version of mega source code

* added config class with old naming conventions

* filled in mega documentation

* added config class and embeddings with optional token types

* updated notes

* starting the conversion process, deleted intermediate and added use_cache back to config

* renamed config attributes in modeling_mega.py

* checkpointing before refactoring incremental decoding functions

* removed stateful incremental key/values for EMA and self-attention

* refactored MovingAverageGatedAttention to remove stateful k/v history and use unified attention mask

* MovingAverageGatedAttention works with incremental decoding + past values, added sequence length enforcement

* more comments in MovingAverageGatedAttention + checkpointing before GatedCrossAttention

* bug fix in attention mask handling in MovingAverageGatedAttention

* removed incremental state from GatedCrossAttention and removed IncrementalState class

* finished gated cross attention and got MegaLayer working

* fixed causal masking in mega decoder

* fixed how padding and causal masks are passed through MegaLayer with and without k/v caching

* finished MegaModel; tested with encoder, decoder-only, and cross-attention type inputs; started work on downstream classes; removed mentions of position_ids

* added optional dense hidden layer for masked and causal LM classes

* docstring updates in MultiHeadEMA and GatedCrossAttention, removed unnecessary inputs in cross-attention

* removed before_attn_fn in Mega class and updated docstrings and comments up to there

* bug fix in MovingAverageGatedAttention masking

* working conversion of MLM checkpoint in scratchpad script -- perfect matches

* moved arg for hidden dense layer in LM head to config; discovered issue where from_pretrained is renaming gamma and beta parameters

* renamed gamma and beta parameters to avoid HF renaming when loading from checkpoint

* finished checkpoint conversion script

* cleanup old class in mega config script

* removed 'copied from' statements and passing integration tests

* added num_attention_heads=1 to config for integration compatibility, decoder tests working, generation tests failing

* fixed tuple output of megamodel

* all common tests passing after fixing issues in decoder, gradient retention, and initialization

* added mega-specific tests, ready for more documentation and style checks

* updated docstrings; checkpoint before style fixes

* style and quality checks, fixed initialization problem in float_tensor, ready for PR

* added mega to toctree

* removed unnecessary arg in megaconfig

* removed unused arg and fixed code samples with leftover roberta models

* Apply suggestions from code review

Applied all suggestions except the one renaming a class, as I'll need to update that througout

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixed issue where .view breaks batch dimension, conversion script fixed with absolute imports, updated readme with Mega->MEGA

* removed asserts in Mega code, renamed sequencenorm, gatedcrossattention, and NFFN, replaced get_activation_fn with ACTFN, and added sequencenorm to layer norms

* reformatted .forward() docstrings to match style and removed unused mask input in cross-attention

* removed all reset_parameters() methods and rolled into MegaPreTrainedModel._init_weights()

* renamed all single-letter variables and improved readability in tensor size comments, Mega->MEGA in 2 documentation files

* variable names in NFFN

* manual Mega->MEGA changes in docs

* Mega->MEGA in config auto

* style and quality fixes

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* renamed parameters and variables with confusing names, added copied from statements, moved fft conv to its own method, other cleanup from PR comments

* commit before dealing with merge conflicts

* made new attention activation functions available in ACT2FN and added generation test from OPT

* style and quality in activations and tests

* documentation fixes, renaming variables in dropout and rotary positions, used built-in causal masking, encoders->layers in MegaModel, moved comments into docstrings

* style and quality fixes after latest updates, before rotary position ids

* causal mask in MegaBlock docstring + added missing device passing

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update README.md

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* added Mega prefixes where missing, reverted MegaSequenceNorm to if-else, other module renaming requested in PR

* style and quality fixes + readme updates pointing to main

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Mitch Naylor
2023-03-24 07:17:27 -05:00
committed by GitHub
parent 0fa46524ac
commit 57f25f4b7f
30 changed files with 3790 additions and 6 deletions

View File

@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
<!--End of the generated tip-->