Refactor MambaCache to modeling_mamba.py (#38086)
* Refactor MambaCache to modeling_mamba.py (parity with Zamba) * ruff * fix dummies * update * update * remove mamba ref in cache tests * remove cache_implementation from tests * update * ruff * ruff * sneaky regression * model consistency * fix test_multi_gpu_data_parallel_forward * fix falcon slow tests * ruff * ruff * add sample false * try to fix slow tests * Revert "fix test_multi_gpu_data_parallel_forward" This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33. * fix tests on nvidia t4, remove dataparallel tests from mamba * ruff * remove DDP tests from mamba and falcon_mamba * add explicit error for MambaCache * mamba2 also needs to init cache in prepare_inputs_for_generation * ruff * ruff * move MambaCache to its own file * ruff * unprotected import fix * another attempt to fix unprotected imports * Revert "another attempt to fix unprotected imports" This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2. * fixing unprotected import, attempt 3 * Update src/transformers/cache_utils.py * ruff's fault * fix arthur review * modular falcon mamba * found a hack * fix config docs * fix docs * add export info * merge modular falcon branch * oopsie * fix fast path failing * new approach * oopsie * fix types * Revert new pragma in modular This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6. * trying another modular workaround * review & fix ci * oopsie * clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
committed by
GitHub
parent
a419a40234
commit
1aa7256f01
@@ -141,7 +141,12 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
return updated_node
|
||||
|
||||
def leave_ImportFrom(self, original_node, updated_node):
|
||||
"""The imports from other file types (configuration, processing etc) should use original model name."""
|
||||
"""
|
||||
The imports from other file types (configuration, processing etc) should use original model name.
|
||||
Also, no replaces on absolute imports (e.g. `from mamba_ssm import ...`)
|
||||
"""
|
||||
if len(original_node.relative) == 0: # no replaces on absolute imports
|
||||
return original_node
|
||||
if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
|
||||
patterns = "|".join(ALL_FILE_TYPES)
|
||||
regex = rf"({patterns})_{self.new_name}"
|
||||
|
||||
Reference in New Issue
Block a user