Add Llama Flax Implementation (#24587)
* Copies `modeling_flax_gpt_neo.py` to start * MLP Block. WIP Attention and Block * Adds Flax implementation of `LlamaMLP` Validated with in-file test. Some slight numeric differences, but assuming it isn't an issue * Adds `FlaxLlamaRMSNorm` layer `flax.linen` includes `RMSNorm` layer but not necessarily in all versions. Hence, we add in-file. * Adds FlaxLlamaAttention Copied from GPT-J as it has efficient caching implementation as well as rotary embeddings. Notice numerically different, but not by a huge amount. Needs investigating * Adds `FlaxLlamaDecoderLayer` numerically inaccurate, debugging.. * debugging rotary mismatch gptj uses interleaved whilst llama uses contiguous i think they match now but still final result is wrong. maybe drop back to just debugging attention layer? * fixes bug with decoder layer still somewhat numerically inaccurate, but close enough for now * adds markers for what to implement next the structure here diverges a lot from the PT version. not a big fan of it, but just get something working for now * implements `FlaxLlamaBlockCollection`] tolerance must be higher than expected, kinda disconcerting * Adds `FlaxLlamaModule` equivalent PyTorch model is `LlamaModel` yay! a language model🤗 * adds `FlaxLlamaForCausalLMModule` equivalent to `LlamaForCausalLM` still missing returning dict or tuple, will add later * start porting pretrained wrappers realised it probably needs return dict as a prereq * cleanup, quality, style * readds `return_dict` and model output named tuples * (tentatively) pretrained wrappers work 🔥 * fixes numerical mismatch in `FlaxLlamaRMSNorm` seems `jax.lax.rsqrt` does not match `torch.sqrt`. manually computing `1 / jax.numpy.sqrt` results in matching values. * [WIP] debugging numerics * numerical match I think issue was accidental change of backend. forcing CPU fixes test. We expect some mismatch on GPU. * adds in model and integration tests for Flax Llama summary of failing: - mul invalid combination of dimensions - one numerical mismatch - bf16 conversion (maybe my local backend issue) - params are not FrozenDict * adds missing TYPE_CHECKING import and `make fixup` * adds back missing docstrings needs review on quality of docstrings, not sure what is required. Furthermore, need to check if `CHECKPOINT_FOR_DOC` is valid. See TODO * commenting out equivalence test as can just use common * debugging * Fixes bug where mask and pos_ids were swapped in pretrained models This results in all tests passing now 🔥 * cleanup of modeling file * cleanup of test file * Resolving simpler review comments * addresses more minor review comments * fixing introduced pytest errors from review * wip additional slow tests * wip tests need to grab a GPU machine to get real logits for comparison otherwise, slow tests should be okay * `make quality`, `make style` * adds slow integration tests - checking logits - checking hidden states - checking generation outputs * `make fix-copies` * fix mangled function following `make fix-copies` * adds missing type checking imports * fixes missing parameter checkpoint warning * more finegrained 'Copied from' tags avoids issue of overwriting `LLAMA_INPUTS_DOCSTRING` * swaps import guards ??? how did these get swapped initially? * removing `inv_freq` again as pytorch version has now removed * attempting to get CI to pass * adds doc entries for llama flax models * fixes typo in __init__.py imports * adds back special equivalence tests these come from the gpt neo flax tests. there is special behaviour for these models that needs to override the common version * overrides tests with dummy to see if CI passes need to fill in these tests later * adds my contribution to docs * `make style; make quality` * replaces random masking with fixed to work with flax version * `make quality; make style` * Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * updates `x`->`tensor` in `rotate_half` * addresses smaller review comments * Update docs/source/en/model_doc/llama.md Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * adds integration test class * adds `dtype` to rotary embedding to cast outputs * adds type to flax llama rotary layer * `make style` * `make fix-copies` * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * applies suggestions from review * Update modeling_flax_llama.py * `make fix-copies` * Update tests/models/llama/test_modeling_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * fixes shape mismatch in FlaxLlamaMLP * applies some suggestions from reviews * casts attn output logits to f32 regardless of dtype * adds attn bias using `LlamaConfig.attention_bias` * adds Copied From comments to Flax Llama test * mistral and persimmon test change -copy from llama * updates docs index * removes Copied from in tests it was preventing `make fix-copies` from succeeding * quality and style * ignores FlaxLlama input docstring * adds revision to `_CHECKPOINT_FOR_DOC` * repo consistency and quality * removes unused import * removes copied from from Phi test now diverges from llama tests following FlaxLlama changes * adds `_REAL_CHECKPOINT_FOR_DOC` * removes refs from pr tests * reformat to make ruff happy --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -50,6 +50,9 @@ come in several checkpoints they each contain a part of each weight of the model
|
||||
|
||||
- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
|
||||
|
||||
This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama). The Flax version of the implementation was contributed by [afmck](https://huggingface.co/afmck) with the code in the implementation based on Hugging Face's Flax GPT-Neo.
|
||||
|
||||
|
||||
Based on the original LLaMA model, Meta AI has released some follow-up works:
|
||||
|
||||
- **Llama2**: Llama2 is an improved version of Llama with some architectural tweaks (Grouped Query Attention), and is pre-trained on 2Trillion tokens. Refer to the documentation of Llama2 which can be found [here](llama2).
|
||||
@@ -112,3 +115,13 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
|
||||
[[autodoc]] LlamaForSequenceClassification
|
||||
- forward
|
||||
|
||||
## FlaxLlamaModel
|
||||
|
||||
[[autodoc]] FlaxLlamaModel
|
||||
- __call__
|
||||
|
||||
## FlaxLlamaForCausalLM
|
||||
|
||||
[[autodoc]] FlaxLlamaForCausalLM
|
||||
- __call__
|
||||
|
||||
Reference in New Issue
Block a user