Refactor MambaCache to modeling_mamba.py (#38086)
* Refactor MambaCache to modeling_mamba.py (parity with Zamba) * ruff * fix dummies * update * update * remove mamba ref in cache tests * remove cache_implementation from tests * update * ruff * ruff * sneaky regression * model consistency * fix test_multi_gpu_data_parallel_forward * fix falcon slow tests * ruff * ruff * add sample false * try to fix slow tests * Revert "fix test_multi_gpu_data_parallel_forward" This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33. * fix tests on nvidia t4, remove dataparallel tests from mamba * ruff * remove DDP tests from mamba and falcon_mamba * add explicit error for MambaCache * mamba2 also needs to init cache in prepare_inputs_for_generation * ruff * ruff * move MambaCache to its own file * ruff * unprotected import fix * another attempt to fix unprotected imports * Revert "another attempt to fix unprotected imports" This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2. * fixing unprotected import, attempt 3 * Update src/transformers/cache_utils.py * ruff's fault * fix arthur review * modular falcon mamba * found a hack * fix config docs * fix docs * add export info * merge modular falcon branch * oopsie * fix fast path failing * new approach * oopsie * fix types * Revert new pragma in modular This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6. * trying another modular workaround * review & fix ci * oopsie * clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
committed by
GitHub
parent
a419a40234
commit
1aa7256f01
@@ -110,6 +110,13 @@ outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## FalconMambaCache
|
||||
|
||||
[[autodoc]] FalconMambaCache
|
||||
- update_conv_state
|
||||
- update_ssm_state
|
||||
- reset
|
||||
|
||||
## FalconMambaConfig
|
||||
|
||||
[[autodoc]] FalconMambaConfig
|
||||
|
||||
@@ -116,6 +116,13 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## MambaCache
|
||||
|
||||
[[autodoc]] MambaCache
|
||||
- update_conv_state
|
||||
- update_ssm_state
|
||||
- reset
|
||||
|
||||
## MambaConfig
|
||||
|
||||
[[autodoc]] MambaConfig
|
||||
|
||||
Reference in New Issue
Block a user