Add recurrent gemma (#30143)

* Fork.

* RecurrentGemma initial commit.

* Updating __init__.py.

* Minor modification to how we initialize the cache.
Changing how the config specifies the architecture.

* Reformat code to 4 spaces.
Fixed a few typos.

* Fixed the forward pass.
Still unclear on the cache?

* Fixed the RecurrentGemmaForCausalLM

* Minor comment that we might not need attention_mask and output_attention arguments.

* Now cache should work as well.

* Adding a temporary example to check whether the model generation works.

* Adding the tests and updating imports.

* Adding the example file missing in the previous commit.

* First working example.

* Removing .gitignore and reverting parts of __init__.

* Re-add .gitignore.

* Addressing comments for configuration.

* Move mask creation to `_prepare_inputs_for_generation`.

* First try at integration tests:
1. AttributeError: 'GriffinCausalLMOutput' object has no attribute 'attentions'.
2. `cache_position` not passed

* Transfoering between machines.

* Running normal tests.

* Minor fix.

* More fixes.

* Addressing more comments.

* Minor fixes.

* first stab at cleanup

* more refactoring

* fix copies and else

* renaming and get init to work

* fix causal mask creation

* update

* nit

* fix a hell lot of things

* updates

* update conversion script

* make all keys importable

* nits

* add auto mappings

* properly convert ffw_up and down

* add scaling

* fix generations

* for recurrent dtype

* update

* fix going beyong window

* fixup

* add missing files

* current updates to remove last einops

* finish modeling refactor

* TADA

* fix compile

* fix most failing testt ? ?

* update tests

* refactor and update

* update

* nits, fixup and update tests

* more fixup

* nits

* fix imports

* test format

* fixups

* nits

* tuple typing

* fix code quality

* add model card

* fix doc

* skip most generation tests

* nits

* style

* doc fixes

* fix pr and check_copies?

* last nit

* oupsy

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <hi@lysand.re>

* update

* Update src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* update based on review

* doc nit

* fix quality

* quality

* fix slow test model path

* update default dype

* ignore attributes that can be safely ignored in check config attributes

* 0lallalala come on

* save nit

* style

* remove to dict update

* make sure we can also run in float16

* style

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: Aleksandar Botev <botev@google.com>
Co-authored-by: Leonard Berrada <lberrada@users.noreply.github.com>
Co-authored-by: anushanf <anushanf@google.com>
Co-authored-by: botev <botevmg@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Arthur
2024-04-10 16:59:13 +02:00
committed by GitHub
parent 33bca5419c
commit 0fe44059ae
32 changed files with 2001 additions and 1 deletions

View File

@@ -34,6 +34,8 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],
# used to compute the property `self.layers_block_type`
"RecurrentGemmaConfig": ["block_types"],
# used as in the config to define `intermediate_size`
"MambaConfig": ["expand"],
# used as `self.bert_model = BertModel(config, ...)`

View File

@@ -86,6 +86,7 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"RecurrentGemmaModel", # Building part of bigger (tested) model.
"FuyuForCausalLM", # Not tested fort now
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
"UMT5EncoderModel", # Building part of bigger (tested) model.

View File

@@ -768,6 +768,7 @@ src/transformers/models/rag/modeling_tf_rag.py
src/transformers/models/rag/retrieval_rag.py
src/transformers/models/realm/modeling_realm.py
src/transformers/models/realm/retrieval_realm.py
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
src/transformers/models/regnet/configuration_regnet.py
src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py